// sherpa-ncnn/csrc/conv-emformer-model.cc
//
// Copyright (c)  2022  Xiaomi Corporation

#include "sherpa-ncnn/csrc/conv-emformer-model.h"

#include <regex>  // NOLINT
#include <string>
#include <utility>
#include <vector>

#include "net.h"       // NOLINT
#include "platform.h"  // NOLINT
#include "sherpa-ncnn/csrc/meta-data.h"
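
// ConvEmformerModel wraps the three ncnn networks of a streaming
// ConvEmformer transducer: encoder, decoder, and joiner.  A typical
// decoding step looks roughly like the sketch below; the real call sites
// live in the recognizer code, and `features` / `decoder_input` stand in
// for values produced there:
//
//   ConvEmformerModel model(config);
//   std::vector<ncnn::Mat> states;  // empty => initial states are used
//   auto result = model.RunEncoder(features, states);  // {encoder_out, next_states}
//   ncnn::Mat decoder_out = model.RunDecoder(decoder_input);
//   ncnn::Mat joiner_out = model.RunJoiner(result.first, decoder_out);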
namespace sherpa_ncnn {

ConvEmformerModel::ConvEmformerModel(const ModelConfig &config) {
  encoder_.opt = config.encoder_opt;
  decoder_.opt = config.decoder_opt;
  joiner_.opt = config.joiner_opt;

  bool has_gpu = false;
#if NCNN_VULKAN
  has_gpu = ncnn::get_gpu_count() > 0;
#endif

  if (has_gpu && config.use_vulkan_compute) {
    encoder_.opt.use_vulkan_compute = true;
    decoder_.opt.use_vulkan_compute = true;
    joiner_.opt.use_vulkan_compute = true;
    NCNN_LOGE("Use GPU");
  } else {
    // NCNN_LOGE("Don't Use GPU. has_gpu: %d, config.use_vulkan_compute: %d",
    //           static_cast<int32_t>(has_gpu),
    //           static_cast<int32_t>(config.use_vulkan_compute));
  }

  InitEncoder(config.encoder_param, config.encoder_bin);
  InitDecoder(config.decoder_param, config.decoder_bin);
  InitJoiner(config.joiner_param, config.joiner_bin);

  InitEncoderInputOutputIndexes();
  InitDecoderInputOutputIndexes();
  InitJoinerInputOutputIndexes();
}

#if __ANDROID_API__ >= 9
ConvEmformerModel::ConvEmformerModel(AAssetManager *mgr,
                                     const ModelConfig &config) {
  encoder_.opt = config.encoder_opt;
  decoder_.opt = config.decoder_opt;
  joiner_.opt = config.joiner_opt;

  bool has_gpu = false;
#if NCNN_VULKAN
  has_gpu = ncnn::get_gpu_count() > 0;
#endif

  if (has_gpu && config.use_vulkan_compute) {
    encoder_.opt.use_vulkan_compute = true;
    decoder_.opt.use_vulkan_compute = true;
    joiner_.opt.use_vulkan_compute = true;
    NCNN_LOGE("Use GPU");
  } else {
    // NCNN_LOGE("Don't Use GPU. has_gpu: %d, config.use_vulkan_compute: %d",
    //           static_cast<int32_t>(has_gpu),
    //           static_cast<int32_t>(config.use_vulkan_compute));
  }

  InitEncoder(mgr, config.encoder_param, config.encoder_bin);
  InitDecoder(mgr, config.decoder_param, config.decoder_bin);
  InitJoiner(mgr, config.joiner_param, config.joiner_bin);

  InitEncoderInputOutputIndexes();
  InitDecoderInputOutputIndexes();
  InitJoinerInputOutputIndexes();
}
#endif
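
// RunEncoder comes in two flavors: the first overload creates a fresh
// ncnn::Extractor per call; the second reuses a caller-provided extractor,
// e.g., when the caller wants to set extractor options itself.  If `states`
// is empty, zero-initialized states from GetEncoderInitStates() are used,
// i.e., decoding starts from scratch.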
std::pair<ncnn::Mat, std::vector<ncnn::Mat>> ConvEmformerModel::RunEncoder(
    ncnn::Mat &features, const std::vector<ncnn::Mat> &states) {
  ncnn::Extractor encoder_ex = encoder_.create_extractor();
  return RunEncoder(features, states, &encoder_ex);
}

std::pair<ncnn::Mat, std::vector<ncnn::Mat>> ConvEmformerModel::RunEncoder(
    ncnn::Mat &features, const std::vector<ncnn::Mat> &states,
    ncnn::Extractor *encoder_ex) {
  std::vector<ncnn::Mat> _states;

  const ncnn::Mat *p;
  if (states.empty()) {
    _states = GetEncoderInitStates();
    p = _states.data();
  } else {
    p = states.data();
  }

  // Note: We ignore the error check here
  encoder_ex->input(encoder_input_indexes_[0], features);
  for (int32_t i = 1; i != encoder_input_indexes_.size(); ++i) {
    encoder_ex->input(encoder_input_indexes_[i], p[i - 1]);
  }

  ncnn::Mat encoder_out;
  encoder_ex->extract(encoder_output_indexes_[0], encoder_out);

  std::vector<ncnn::Mat> next_states(num_layers_ * 4);
  for (int32_t i = 1; i != encoder_output_indexes_.size(); ++i) {
    encoder_ex->extract(encoder_output_indexes_[i], next_states[i - 1]);
  }

  return {encoder_out, next_states};
}
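
// RunDecoder feeds the decoder network with the current decoder input
// (e.g., the most recent tokens) and flattens the result to a 1-D mat of
// size decoder_out.w so the joiner can consume it directly.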
ncnn::Mat ConvEmformerModel::RunDecoder(ncnn::Mat &decoder_input) {
  ncnn::Extractor decoder_ex = decoder_.create_extractor();
  return RunDecoder(decoder_input, &decoder_ex);
}

ncnn::Mat ConvEmformerModel::RunDecoder(ncnn::Mat &decoder_input,
                                        ncnn::Extractor *decoder_ex) {
  ncnn::Mat decoder_out;
  decoder_ex->input(decoder_input_indexes_[0], decoder_input);
  decoder_ex->extract(decoder_output_indexes_[0], decoder_out);
  decoder_out = decoder_out.reshape(decoder_out.w);

  return decoder_out;
}
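
// RunJoiner combines one frame of encoder output with the current decoder
// output to produce the joiner output used for token prediction.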
ncnn::Mat ConvEmformerModel::RunJoiner(ncnn::Mat &encoder_out,
                                       ncnn::Mat &decoder_out) {
  auto joiner_ex = joiner_.create_extractor();
  return RunJoiner(encoder_out, decoder_out, &joiner_ex);
}

ncnn::Mat ConvEmformerModel::RunJoiner(ncnn::Mat &encoder_out,
                                       ncnn::Mat &decoder_out,
                                       ncnn::Extractor *joiner_ex) {
  joiner_ex->input(joiner_input_indexes_[0], encoder_out);
  joiner_ex->input(joiner_input_indexes_[1], decoder_out);

  ncnn::Mat joiner_out;
  joiner_ex->extract(joiner_output_indexes_[0], joiner_out);
  return joiner_out;
}
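
// The encoder param file contains a custom SherpaMetaData layer whose
// arguments hold the model hyperparameters; after the net is loaded we
// read them into the corresponding member variables.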
void ConvEmformerModel::InitEncoderPostProcessing() {
  // Now load parameters for member variables
  for (const auto *layer : encoder_.layers()) {
    if (layer->type == "SherpaMetaData" && layer->name == "sherpa_meta_data1") {
      // Note: We don't use dynamic_cast<> here since the code is compiled
      // with -fno-rtti, which makes the compiler reject it with:
      //   error: ‘dynamic_cast’ not permitted with -fno-rtti
      const auto *meta_data = reinterpret_cast<const MetaData *>(layer);

      num_layers_ = meta_data->arg1;
      memory_size_ = meta_data->arg2;
      cnn_module_kernel_ = meta_data->arg3;
      left_context_length_ = meta_data->arg4;
      chunk_length_ = meta_data->arg5;
      right_context_length_ = meta_data->arg6;
      d_model_ = meta_data->arg7;
      break;
    }
  }
}

void ConvEmformerModel::InitEncoder(const std::string &encoder_param,
                                    const std::string &encoder_bin) {
  RegisterCustomLayers(encoder_);
  InitNet(encoder_, encoder_param, encoder_bin);
  InitEncoderPostProcessing();
}

void ConvEmformerModel::InitDecoder(const std::string &decoder_param,
                                    const std::string &decoder_bin) {
  InitNet(decoder_, decoder_param, decoder_bin);
}

void ConvEmformerModel::InitJoiner(const std::string &joiner_param,
                                   const std::string &joiner_bin) {
  InitNet(joiner_, joiner_param, joiner_bin);
}

#if __ANDROID_API__ >= 9
void ConvEmformerModel::InitEncoder(AAssetManager *mgr,
                                    const std::string &encoder_param,
                                    const std::string &encoder_bin) {
  RegisterCustomLayers(encoder_);
  InitNet(mgr, encoder_, encoder_param, encoder_bin);
  InitEncoderPostProcessing();
}

void ConvEmformerModel::InitDecoder(AAssetManager *mgr,
                                    const std::string &decoder_param,
                                    const std::string &decoder_bin) {
  InitNet(mgr, decoder_, decoder_param, decoder_bin);
}

void ConvEmformerModel::InitJoiner(AAssetManager *mgr,
                                   const std::string &joiner_param,
                                   const std::string &joiner_bin) {
  InitNet(mgr, joiner_, joiner_param, joiner_bin);
}
#endif
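
// Build zero-initialized encoder states for the very first chunk.  Each
// layer keeps four cached mats; the interpretation below is an assumption
// based on the ConvEmformer architecture (it is not stated in this file):
//   s0: ncnn::Mat(d_model_, memory_size_)          -- memory bank
//   s1: ncnn::Mat(d_model_, left_context_length_)  -- cached left-context keys
//   s2: ncnn::Mat(d_model_, left_context_length_)  -- cached left-context values
//   s3: ncnn::Mat(cnn_module_kernel_ - 1, d_model_) -- conv module cache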
std::vector<ncnn::Mat> ConvEmformerModel::GetEncoderInitStates() const {
  std::vector<ncnn::Mat> states;
  states.reserve(num_layers_ * 4);

  for (int32_t i = 0; i != num_layers_; ++i) {
    auto s0 = ncnn::Mat(d_model_, memory_size_);
    auto s1 = ncnn::Mat(d_model_, left_context_length_);
    auto s2 = ncnn::Mat(d_model_, left_context_length_);
    auto s3 = ncnn::Mat(cnn_module_kernel_ - 1, d_model_);

    s0.fill(0);
    s1.fill(0);
    s2.fill(0);
    s3.fill(0);

    states.push_back(s0);
    states.push_back(s1);
    states.push_back(s2);
    states.push_back(s3);
  }

  return states;
}
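
// The exported encoder names its input blobs in0, in1, ... and its output
// blobs out0, out1, ....  We scan all blobs once and map each inN/outN to
// its blob index so that ncnn::Extractor::input()/extract() can be called
// by index in RunEncoder().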
void ConvEmformerModel::InitEncoderInputOutputIndexes() {
  // input indexes map
  // [0] -> in0, features,
  // [1] -> in1, layer0, s0
  // [2] -> in2, layer0, s1
  // [3] -> in3, layer0, s2
  // [4] -> in4, layer0, s3
  //
  // [5] -> in5, layer1, s0
  // [6] -> in6, layer1, s1
  // [7] -> in7, layer1, s2
  // [8] -> in8, layer1, s3
  //
  // ..., until layer num_layers_ - 1
  encoder_input_indexes_.resize(1 + num_layers_ * 4);

  // output indexes map
  // [0] -> out0, encoder_out
  //
  // [1] -> out1, layer0, s0
  // [2] -> out2, layer0, s1
  // [3] -> out3, layer0, s2
  // [4] -> out4, layer0, s3
  //
  // [5] -> out5, layer1, s0
  // [6] -> out6, layer1, s1
  // [7] -> out7, layer1, s2
  // [8] -> out8, layer1, s3
  encoder_output_indexes_.resize(1 + num_layers_ * 4);

  const auto &blobs = encoder_.blobs();

  std::regex in_regex("in(\\d+)");
  std::regex out_regex("out(\\d+)");

  std::smatch match;
  for (int32_t i = 0; i != blobs.size(); ++i) {
    const auto &b = blobs[i];
    if (std::regex_match(b.name, match, in_regex)) {
      auto index = std::atoi(match[1].str().c_str());
      encoder_input_indexes_[index] = i;
    } else if (std::regex_match(b.name, match, out_regex)) {
      auto index = std::atoi(match[1].str().c_str());
      encoder_output_indexes_[index] = i;
    }
  }
}

void ConvEmformerModel::InitDecoderInputOutputIndexes() {
  // input indexes map
  // [0] -> in0, decoder_input,
  decoder_input_indexes_.resize(1);

  // output indexes map
  // [0] -> out0, decoder_out,
  decoder_output_indexes_.resize(1);

  const auto &blobs = decoder_.blobs();
  for (int32_t i = 0; i != blobs.size(); ++i) {
    const auto &b = blobs[i];
    if (b.name == "in0") decoder_input_indexes_[0] = i;
    if (b.name == "out0") decoder_output_indexes_[0] = i;
  }
}

void ConvEmformerModel::InitJoinerInputOutputIndexes() {
  // input indexes map
  // [0] -> in0, encoder_input,
  // [1] -> in1, decoder_input,
  joiner_input_indexes_.resize(2);

  // output indexes map
  // [0] -> out0, joiner_out,
  joiner_output_indexes_.resize(1);

  const auto &blobs = joiner_.blobs();
  for (int32_t i = 0; i != blobs.size(); ++i) {
    const auto &b = blobs[i];
    if (b.name == "in0") joiner_input_indexes_[0] = i;
    if (b.name == "in1") joiner_input_indexes_[1] = i;
    if (b.name == "out0") joiner_output_indexes_[0] = i;
  }
}

}  // namespace sherpa_ncnn