conv-emformer-model.cc 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283
  1. // sherpa-ncnn/csrc/conv-emformer-model.cc
  2. //
  3. // Copyright (c) 2022 Xiaomi Corporation
  4. #include "sherpa-ncnn/csrc/conv-emformer-model.h"
  5. #include <regex> // NOLINT
  6. #include <string>
  7. #include <utility>
  8. #include <vector>
  9. #include "net.h" // NOLINT
  10. #include "platform.h" // NOLINT
  11. #include "sherpa-ncnn/csrc/meta-data.h"
  12. namespace sherpa_ncnn {
  13. ConvEmformerModel::ConvEmformerModel(const ModelConfig &config)
  14. : num_threads_(config.num_threads) {
  15. bool has_gpu = false;
  16. #if NCNN_VULKAN
  17. has_gpu = ncnn::get_gpu_count() > 0;
  18. #endif
  19. if (has_gpu && config.use_vulkan_compute) {
  20. encoder_.opt.use_vulkan_compute = true;
  21. decoder_.opt.use_vulkan_compute = true;
  22. joiner_.opt.use_vulkan_compute = true;
  23. NCNN_LOGE("Use GPU");
  24. } else {
  25. NCNN_LOGE("Don't Use GPU. has_gpu: %d, config.use_vulkan_compute: %d",
  26. static_cast<int32_t>(has_gpu),
  27. static_cast<int32_t>(config.use_vulkan_compute));
  28. }
  29. InitEncoder(config.encoder_param, config.encoder_bin);
  30. InitDecoder(config.decoder_param, config.decoder_bin);
  31. InitJoiner(config.joiner_param, config.joiner_bin);
  32. InitEncoderInputOutputIndexes();
  33. InitDecoderInputOutputIndexes();
  34. InitJoinerInputOutputIndexes();
  35. }
  36. #if __ANDROID_API__ >= 9
  37. ConvEmformerModel::ConvEmformerModel(AAssetManager *mgr,
  38. const ModelConfig &config)
  39. : num_threads_(config.num_threads) {
  40. InitEncoder(mgr, config.encoder_param, config.encoder_bin);
  41. InitDecoder(mgr, config.decoder_param, config.decoder_bin);
  42. InitJoiner(mgr, config.joiner_param, config.joiner_bin);
  43. InitEncoderInputOutputIndexes();
  44. InitDecoderInputOutputIndexes();
  45. InitJoinerInputOutputIndexes();
  46. }
  47. #endif
  48. std::pair<ncnn::Mat, std::vector<ncnn::Mat>> ConvEmformerModel::RunEncoder(
  49. ncnn::Mat &features, const std::vector<ncnn::Mat> &states) {
  50. std::vector<ncnn::Mat> _states;
  51. const ncnn::Mat *p;
  52. if (states.empty()) {
  53. _states = GetEncoderInitStates();
  54. p = _states.data();
  55. } else {
  56. p = states.data();
  57. }
  58. ncnn::Extractor encoder_ex = encoder_.create_extractor();
  59. encoder_ex.set_num_threads(num_threads_);
  60. // Note: We ignore error check there
  61. encoder_ex.input(encoder_input_indexes_[0], features);
  62. for (int32_t i = 1; i != encoder_input_indexes_.size(); ++i) {
  63. encoder_ex.input(encoder_input_indexes_[i], p[i - 1]);
  64. }
  65. ncnn::Mat encoder_out;
  66. encoder_ex.extract(encoder_output_indexes_[0], encoder_out);
  67. std::vector<ncnn::Mat> next_states(num_layers_ * 4);
  68. for (int32_t i = 1; i != encoder_output_indexes_.size(); ++i) {
  69. encoder_ex.extract(encoder_output_indexes_[i], next_states[i - 1]);
  70. }
  71. return {encoder_out, next_states};
  72. }
  73. ncnn::Mat ConvEmformerModel::RunDecoder(ncnn::Mat &decoder_input) {
  74. ncnn::Extractor decoder_ex = decoder_.create_extractor();
  75. decoder_ex.set_num_threads(num_threads_);
  76. ncnn::Mat decoder_out;
  77. decoder_ex.input(decoder_input_indexes_[0], decoder_input);
  78. decoder_ex.extract(decoder_output_indexes_[0], decoder_out);
  79. decoder_out = decoder_out.reshape(decoder_out.w);
  80. return decoder_out;
  81. }
  82. ncnn::Mat ConvEmformerModel::RunJoiner(ncnn::Mat &encoder_out,
  83. ncnn::Mat &decoder_out) {
  84. auto joiner_ex = joiner_.create_extractor();
  85. joiner_ex.set_num_threads(num_threads_);
  86. joiner_ex.input(joiner_input_indexes_[0], encoder_out);
  87. joiner_ex.input(joiner_input_indexes_[1], decoder_out);
  88. ncnn::Mat joiner_out;
  89. joiner_ex.extract("out0", joiner_out);
  90. return joiner_out;
  91. }
  92. void ConvEmformerModel::InitEncoderPostProcessing() {
  93. // Now load parameters for member variables
  94. for (const auto *layer : encoder_.layers()) {
  95. if (layer->type == "SherpaMetaData" && layer->name == "sherpa_meta_data1") {
  96. // Note: We don't use dynamic_cast<> here since it will throw
  97. // the following error
  98. // error: ‘dynamic_cast’ not permitted with -fno-rtti
  99. const auto *meta_data = reinterpret_cast<const MetaData *>(layer);
  100. num_layers_ = meta_data->arg1;
  101. memory_size_ = meta_data->arg2;
  102. cnn_module_kernel_ = meta_data->arg3;
  103. left_context_length_ = meta_data->arg4;
  104. chunk_length_ = meta_data->arg5;
  105. right_context_length_ = meta_data->arg6;
  106. d_model_ = meta_data->arg7;
  107. break;
  108. }
  109. }
  110. }
  111. void ConvEmformerModel::InitEncoder(const std::string &encoder_param,
  112. const std::string &encoder_bin) {
  113. RegisterMetaDataLayer(encoder_);
  114. InitNet(encoder_, encoder_param, encoder_bin);
  115. InitEncoderPostProcessing();
  116. }
  117. void ConvEmformerModel::InitDecoder(const std::string &decoder_param,
  118. const std::string &decoder_bin) {
  119. InitNet(decoder_, decoder_param, decoder_bin);
  120. }
  121. void ConvEmformerModel::InitJoiner(const std::string &joiner_param,
  122. const std::string &joiner_bin) {
  123. InitNet(joiner_, joiner_param, joiner_bin);
  124. }
  125. #if __ANDROID_API__ >= 9
  126. void ConvEmformerModel::InitEncoder(AAssetManager *mgr,
  127. const std::string &encoder_param,
  128. const std::string &encoder_bin) {
  129. RegisterMetaDataLayer(encoder_);
  130. InitNet(mgr, encoder_, encoder_param, encoder_bin);
  131. InitEncoderPostProcessing();
  132. }
  133. void ConvEmformerModel::InitDecoder(AAssetManager *mgr,
  134. const std::string &decoder_param,
  135. const std::string &decoder_bin) {
  136. InitNet(mgr, decoder_, decoder_param, decoder_bin);
  137. }
  138. void ConvEmformerModel::InitJoiner(AAssetManager *mgr,
  139. const std::string &joiner_param,
  140. const std::string &joiner_bin) {
  141. InitNet(mgr, joiner_, joiner_param, joiner_bin);
  142. }
  143. #endif
  144. std::vector<ncnn::Mat> ConvEmformerModel::GetEncoderInitStates() const {
  145. std::vector<ncnn::Mat> states;
  146. states.reserve(num_layers_ * 4);
  147. for (int32_t i = 0; i != num_layers_; ++i) {
  148. auto s0 = ncnn::Mat(d_model_, memory_size_);
  149. auto s1 = ncnn::Mat(d_model_, left_context_length_);
  150. auto s2 = ncnn::Mat(d_model_, left_context_length_);
  151. auto s3 = ncnn::Mat(cnn_module_kernel_ - 1, d_model_);
  152. s0.fill(0);
  153. s1.fill(0);
  154. s2.fill(0);
  155. s3.fill(0);
  156. states.push_back(s0);
  157. states.push_back(s1);
  158. states.push_back(s2);
  159. states.push_back(s3);
  160. }
  161. return states;
  162. }
  163. void ConvEmformerModel::InitEncoderInputOutputIndexes() {
  164. // input indexes map
  165. // [0] -> in0, features,
  166. // [1] -> in1, layer0, s0
  167. // [2] -> in2, layer0, s1
  168. // [3] -> in3, layer0, s2
  169. // [4] -> in4, layer0, s3
  170. //
  171. // [5] -> in5, layer1, s0
  172. // [6] -> in6, layer1, s1
  173. // [7] -> in7, layer1, s2
  174. // [8] -> in8, layer1, s3
  175. //
  176. // until layer 11
  177. encoder_input_indexes_.resize(1 + num_layers_ * 4);
  178. // output indexes map
  179. // [0] -> out0, encoder_out
  180. //
  181. // [1] -> out1, layer0, s0
  182. // [2] -> out2, layer0, s1
  183. // [3] -> out3, layer0, s2
  184. // [4] -> out4, layer0, s3
  185. //
  186. // [5] -> out5, layer1, s0
  187. // [6] -> out6, layer1, s1
  188. // [7] -> out7, layer1, s2
  189. // [8] -> out8, layer1, s3
  190. encoder_output_indexes_.resize(1 + num_layers_ * 4);
  191. const auto &blobs = encoder_.blobs();
  192. std::regex in_regex("in(\\d+)");
  193. std::regex out_regex("out(\\d+)");
  194. std::smatch match;
  195. for (int32_t i = 0; i != blobs.size(); ++i) {
  196. const auto &b = blobs[i];
  197. if (std::regex_match(b.name, match, in_regex)) {
  198. auto index = std::atoi(match[1].str().c_str());
  199. encoder_input_indexes_[index] = i;
  200. } else if (std::regex_match(b.name, match, out_regex)) {
  201. auto index = std::atoi(match[1].str().c_str());
  202. encoder_output_indexes_[index] = i;
  203. }
  204. }
  205. }
  206. void ConvEmformerModel::InitDecoderInputOutputIndexes() {
  207. // input indexes map
  208. // [0] -> in0, decoder_input,
  209. decoder_input_indexes_.resize(1);
  210. // output indexes map
  211. // [0] -> out0, decoder_out,
  212. decoder_output_indexes_.resize(1);
  213. const auto &blobs = decoder_.blobs();
  214. for (int32_t i = 0; i != blobs.size(); ++i) {
  215. const auto &b = blobs[i];
  216. if (b.name == "in0") decoder_input_indexes_[0] = i;
  217. if (b.name == "out0") decoder_output_indexes_[0] = i;
  218. }
  219. }
  220. void ConvEmformerModel::InitJoinerInputOutputIndexes() {
  221. // input indexes map
  222. // [0] -> in0, encoder_input,
  223. // [1] -> in1, decoder_input,
  224. joiner_input_indexes_.resize(2);
  225. // output indexes map
  226. // [0] -> out0, joiner_out,
  227. joiner_output_indexes_.resize(1);
  228. const auto &blobs = joiner_.blobs();
  229. for (int32_t i = 0; i != blobs.size(); ++i) {
  230. const auto &b = blobs[i];
  231. if (b.name == "in0") joiner_input_indexes_[0] = i;
  232. if (b.name == "in1") joiner_input_indexes_[1] = i;
  233. if (b.name == "out0") joiner_output_indexes_[0] = i;
  234. }
  235. }
  236. } // namespace sherpa_ncnn