Support conv-emformer (#25)

* Support conv-emformer

* small fixes

* small fixes

* Refactor lstm

* Refactor conv-emformer-model

* Requires c++11

* small fixes
Fangjun Kuang 2 years ago
parent
commit
c98d2999e2

+ 3 - 0
CMakeLists.txt

@@ -38,6 +38,9 @@ set(CMAKE_CXX_EXTENSIONS OFF)
 list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules)
 list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake)
 
+set(CMAKE_CXX_STANDARD 11 CACHE STRING "The C++ version to be used.")
+set(CMAKE_CXX_EXTENSIONS OFF)
+
 include(kaldi-native-fbank)
 include(ncnn)
 include(portaudio)

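The commit message notes "Requires c++11": the new conv-emformer-model.cc below relies on std::regex, which entered the standard library in C++11, hence the explicit CMAKE_CXX_STANDARD setting above. A minimal, self-contained sketch of the kind of blob-name matching that motivates the requirement (the blob name "in12" is made up for illustration):

```cpp
// Minimal C++11 sketch of the std::regex blob-name matching used by the
// new conv-emformer-model.cc; "in12" is an illustrative blob name.
#include <cstdio>
#include <cstdlib>
#include <regex>
#include <string>

int main() {
  std::regex in_regex("in(\\d+)");
  std::smatch match;
  std::string name = "in12";
  if (std::regex_match(name, match, in_regex)) {
    // match[1] captures the digits after "in"
    std::printf("parsed index: %d\n", std::atoi(match[1].str().c_str()));
  }
  return 0;
}
```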
+ 20 - 11
cmake/ncnn.cmake

@@ -1,8 +1,16 @@
 function(download_ncnn)
   include(FetchContent)
 
-  set(ncnn_URL  "http://github.com/csukuangfj/ncnn/archive/refs/tags/sherpa-0.6.tar.gz")
-  set(ncnn_HASH "SHA256=aac5298f00ae9ce447c2aefa6c46579dcb3a284b9ce17687c182ecf4d499b3c8")
+  # We use a modified version of NCNN.
+  # The changed code is in
+  # https://github.com/csukuangfj/ncnn/pull/7
+
+  # If you don't have access to the internet, please download it to your
+  # local drive and modify the following line according to your needs.
+  # set(ncnn_URL  "file:///ceph-fj/fangjun/372e5f3d0e8b4024e377388b0f336bc4397a2f06.zip")
+
+  set(ncnn_URL  "https://github.com/csukuangfj/ncnn/archive/372e5f3d0e8b4024e377388b0f336bc4397a2f06.zip")
+  set(ncnn_HASH "SHA256=1b1bcd510085c5173a1fb1f7d1459690b8919dd2fa527b1140e39d2a820e0ae0")
 
   FetchContent_Declare(ncnn
     URL               ${ncnn_URL}
@@ -16,9 +24,6 @@ function(download_ncnn)
   set(NCNN_PIXEL_DRAWING OFF CACHE BOOL "" FORCE)
   set(NCNN_BUILD_BENCHMARK OFF CACHE BOOL "" FORCE)
 
-  set(NCNN_INT8 OFF CACHE BOOL "" FORCE) # TODO(fangjun): enable it
-  set(NCNN_BF16 OFF CACHE BOOL "" FORCE) # TODO(fangjun): enable it
-
   set(NCNN_BUILD_TOOLS OFF CACHE BOOL "" FORCE)
   set(NCNN_BUILD_EXAMPLES OFF CACHE BOOL "" FORCE)
   set(NCNN_BUILD_TESTS OFF CACHE BOOL "" FORCE)
@@ -61,8 +66,8 @@ function(download_ncnn)
     ROIPooling
     Scale
     # Sigmoid
-    Slice
-    Softmax
+    # Slice
+    # Softmax
     # Split
     SPP
     # TanH
@@ -111,15 +116,15 @@ function(download_ncnn)
     GRU
     MultiHeadAttention
     GELU
-    Convolution1D
+    # Convolution1D
     Pooling1D
     # ConvolutionDepthWise1D
     Convolution3D
     ConvolutionDepthWise3D
     Pooling3D
-    MatMul
+    # MatMul
     Deconvolution1D
-    DeconvolutionDepthWise1D
+    # DeconvolutionDepthWise1D
     Deconvolution3D
     DeconvolutionDepthWise3D
     Einsum
@@ -127,8 +132,12 @@ function(download_ncnn)
     RelPositionalEncoding
     MakePadMask
     RelShift
-    GLU
+    # GLU
+    Fold
+    Unfold
+    GridSample
   )
+
   foreach(layer IN LISTS disabled_layers)
     string(TOLOWER ${layer} name)
     set(WITH_LAYER_${name} OFF CACHE BOOL "" FORCE)

+ 1 - 0
sherpa-ncnn/csrc/CMakeLists.txt

@@ -1,6 +1,7 @@
 include_directories(${CMAKE_SOURCE_DIR})
 
 set(sherpa_ncnn_core_srcs
+  conv-emformer-model.cc
   decode.cc
   features.cc
   lstm-model.cc

+ 202 - 0
sherpa-ncnn/csrc/conv-emformer-model.cc

@@ -0,0 +1,202 @@
+// sherpa-ncnn/csrc/conv-emformer-model.cc
+//
+// Copyright (c)  2022  Xiaomi Corporation
+
+#include "sherpa-ncnn/csrc/conv-emformer-model.h"
+
+#include <regex>
+
+#include "net.h"  // NOLINT
+
+namespace sherpa_ncnn {
+
+ConvEmformerModel::ConvEmformerModel(const ModelConfig &config)
+    : num_threads_(config.num_threads) {
+  InitEncoder(config.encoder_param, config.encoder_bin);
+  InitDecoder(config.decoder_param, config.decoder_bin);
+  InitJoiner(config.joiner_param, config.joiner_bin);
+
+  InitEncoderInputOutputIndexes();
+  InitDecoderInputOutputIndexes();
+  InitJoinerInputOutputIndexes();
+}
+
+std::pair<ncnn::Mat, std::vector<ncnn::Mat>> ConvEmformerModel::RunEncoder(
+    ncnn::Mat &features, const std::vector<ncnn::Mat> &states) {
+  std::vector<ncnn::Mat> _states;
+
+  const ncnn::Mat *p;
+  if (states.empty()) {
+    _states = GetEncoderInitStates();
+    p = _states.data();
+  } else {
+    p = states.data();
+  }
+
+  ncnn::Extractor encoder_ex = encoder_.create_extractor();
+  encoder_ex.set_num_threads(num_threads_);
+
+  // Note: We ignore error checks here
+  encoder_ex.input(encoder_input_indexes_[0], features);
+  for (int32_t i = 1; i != encoder_input_indexes_.size(); ++i) {
+    encoder_ex.input(encoder_input_indexes_[i], p[i - 1]);
+  }
+
+  ncnn::Mat encoder_out;
+  encoder_ex.extract(encoder_output_indexes_[0], encoder_out);
+
+  std::vector<ncnn::Mat> next_states(num_layers_ * 4);
+  for (int32_t i = 1; i != encoder_output_indexes_.size(); ++i) {
+    encoder_ex.extract(encoder_output_indexes_[i], next_states[i - 1]);
+  }
+
+  return {encoder_out, next_states};
+}
+
+ncnn::Mat ConvEmformerModel::RunDecoder(ncnn::Mat &decoder_input) {
+  ncnn::Extractor decoder_ex = decoder_.create_extractor();
+  decoder_ex.set_num_threads(num_threads_);
+
+  ncnn::Mat decoder_out;
+  decoder_ex.input(decoder_input_indexes_[0], decoder_input);
+  decoder_ex.extract(decoder_output_indexes_[0], decoder_out);
+  decoder_out = decoder_out.reshape(decoder_out.w);
+
+  return decoder_out;
+}
+
+ncnn::Mat ConvEmformerModel::RunJoiner(ncnn::Mat &encoder_out,
+                                       ncnn::Mat &decoder_out) {
+  auto joiner_ex = joiner_.create_extractor();
+  joiner_ex.set_num_threads(num_threads_);
+  joiner_ex.input(joiner_input_indexes_[0], encoder_out);
+  joiner_ex.input(joiner_input_indexes_[1], decoder_out);
+
+  ncnn::Mat joiner_out;
+  joiner_ex.extract("out0", joiner_out);
+  return joiner_out;
+}
+
+void ConvEmformerModel::InitEncoder(const std::string &encoder_param,
+                                    const std::string &encoder_bin) {
+  InitNet(encoder_, encoder_param, encoder_bin);
+}
+
+void ConvEmformerModel::InitDecoder(const std::string &decoder_param,
+                                    const std::string &decoder_bin) {
+  InitNet(decoder_, decoder_param, decoder_bin);
+}
+
+void ConvEmformerModel::InitJoiner(const std::string &joiner_param,
+                                   const std::string &joiner_bin) {
+  InitNet(joiner_, joiner_param, joiner_bin);
+}
+
+std::vector<ncnn::Mat> ConvEmformerModel::GetEncoderInitStates() const {
+  std::vector<ncnn::Mat> states;
+  states.reserve(num_layers_ * 4);
+
+  for (int32_t i = 0; i != num_layers_; ++i) {
+    auto s0 = ncnn::Mat(d_model_, memory_size_);
+    auto s1 = ncnn::Mat(d_model_, left_context_length_);
+    auto s2 = ncnn::Mat(d_model_, left_context_length_);
+    auto s3 = ncnn::Mat(cnn_module_kernel_ - 1, d_model_);
+
+    s0.fill(0);
+    s1.fill(0);
+    s2.fill(0);
+    s3.fill(0);
+
+    states.push_back(s0);
+    states.push_back(s1);
+    states.push_back(s2);
+    states.push_back(s3);
+  }
+
+  return states;
+}
+
+void ConvEmformerModel::InitEncoderInputOutputIndexes() {
+  // input indexes map
+  // [0] -> in0, features,
+  // [1] -> in1, layer0, s0
+  // [2] -> in2, layer0, s1
+  // [3] -> in3, layer0, s2
+  // [4] -> in4, layer0, s3
+  //
+  // [5] -> in5, layer1, s0
+  // [6] -> in6, layer1, s1
+  // [7] -> in7, layer1, s2
+  // [8] -> in8, layer1, s3
+  //
+  // until layer 11
+  encoder_input_indexes_.resize(1 + num_layers_ * 4);
+
+  // output indexes map
+  // [0] -> out0, encoder_out
+  //
+  // [1] -> out1, layer0, s0
+  // [2] -> out2, layer0, s1
+  // [3] -> out3, layer0, s2
+  // [4] -> out4, layer0, s3
+  //
+  // [5] -> out5, layer1, s0
+  // [6] -> out6, layer1, s1
+  // [7] -> out7, layer1, s2
+  // [8] -> out8, layer1, s3
+  encoder_output_indexes_.resize(1 + num_layers_ * 4);
+  const auto &blobs = encoder_.blobs();
+
+  std::regex in_regex("in(\\d+)");
+  std::regex out_regex("out(\\d+)");
+
+  std::smatch match;
+  for (int32_t i = 0; i != blobs.size(); ++i) {
+    const auto &b = blobs[i];
+    if (std::regex_match(b.name, match, in_regex)) {
+      auto index = std::atoi(match[1].str().c_str());
+      encoder_input_indexes_[index] = i;
+    } else if (std::regex_match(b.name, match, out_regex)) {
+      auto index = std::atoi(match[1].str().c_str());
+      encoder_output_indexes_[index] = i;
+    }
+  }
+}
+
+void ConvEmformerModel::InitDecoderInputOutputIndexes() {
+  // input indexes map
+  // [0] -> in0, decoder_input,
+  decoder_input_indexes_.resize(1);
+
+  // output indexes map
+  // [0] -> out0, decoder_out,
+  decoder_output_indexes_.resize(1);
+
+  const auto &blobs = decoder_.blobs();
+  for (int32_t i = 0; i != blobs.size(); ++i) {
+    const auto &b = blobs[i];
+    if (b.name == "in0") decoder_input_indexes_[0] = i;
+    if (b.name == "out0") decoder_output_indexes_[0] = i;
+  }
+}
+
+void ConvEmformerModel::InitJoinerInputOutputIndexes() {
+  // input indexes map
+  // [0] -> in0, encoder_input,
+  // [1] -> in1, decoder_input,
+  joiner_input_indexes_.resize(2);
+
+  // output indexes map
+  // [0] -> out0, joiner_out,
+  joiner_output_indexes_.resize(1);
+
+  const auto &blobs = joiner_.blobs();
+  for (int32_t i = 0; i != blobs.size(); ++i) {
+    const auto &b = blobs[i];
+    if (b.name == "in0") joiner_input_indexes_[0] = i;
+    if (b.name == "in1") joiner_input_indexes_[1] = i;
+    if (b.name == "out0") joiner_output_indexes_[0] = i;
+  }
+}
+
+}  // namespace sherpa_ncnn

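RunEncoder above treats an empty states vector as the first chunk and substitutes the zero-filled states from GetEncoderInitStates(); callers are expected to feed the returned next_states back in on the following call. A minimal driver sketch under that contract (GetNextChunk is a hypothetical helper, not part of this commit):

```cpp
// Hypothetical streaming loop over ConvEmformerModel::RunEncoder.
// GetNextChunk() is a made-up helper yielding one feature chunk per call.
std::vector<ncnn::Mat> states;  // empty -> zeroed init states on first call
ncnn::Mat features;
while (GetNextChunk(&features)) {
  auto result = model.RunEncoder(features, states);
  ncnn::Mat encoder_out = result.first;  // encoder output for this chunk
  states = result.second;  // 4 states per layer, fed back on the next call
  // ... run the decoder and joiner on encoder_out as in the code above ...
}
```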
+ 74 - 0
sherpa-ncnn/csrc/conv-emformer-model.h

@@ -0,0 +1,74 @@
+// sherpa-ncnn/csrc/conv-emformer-model.h
+//
+// Copyright (c)  2022  Xiaomi Corporation
+
+#include "net.h"  // NOLINT
+#include "sherpa-ncnn/csrc/model.h"
+
+namespace sherpa_ncnn {
+// Please refer to https://github.com/k2-fsa/icefall/pull/717
+// for how the model is converted from icefall to ncnn
+class ConvEmformerModel : public Model {
+ public:
+  explicit ConvEmformerModel(const ModelConfig &config);
+
+  std::pair<ncnn::Mat, std::vector<ncnn::Mat>> RunEncoder(
+      ncnn::Mat &features, const std::vector<ncnn::Mat> &states) override;
+
+  ncnn::Mat RunDecoder(ncnn::Mat &decoder_input) override;
+
+  ncnn::Mat RunJoiner(ncnn::Mat &encoder_out, ncnn::Mat &decoder_out) override;
+
+  int32_t Segment() const override {
+    // chunk_length 32
+    // right_context 8
+    // subsampling factor 4
+    //
+    // segment = 32 + (8 + 2 * 4 + 3) = 32 + 19 = 51
+    return 51;
+  }
+
+  // Advance the feature extractor by this number of frames after
+  // running the encoder network
+  int32_t Offset() const override { return chunk_length_; }
+
+ private:
+  void InitEncoder(const std::string &encoder_param,
+                   const std::string &encoder_bin);
+  void InitDecoder(const std::string &decoder_param,
+                   const std::string &decoder_bin);
+  void InitJoiner(const std::string &joiner_param,
+                  const std::string &joiner_bin);
+
+  std::vector<ncnn::Mat> GetEncoderInitStates() const;
+
+  void InitEncoderInputOutputIndexes();
+  void InitDecoderInputOutputIndexes();
+  void InitJoinerInputOutputIndexes();
+
+ private:
+  ncnn::Net encoder_;
+  ncnn::Net decoder_;
+  ncnn::Net joiner_;
+
+  int32_t num_threads_;
+
+  int32_t num_layers_ = 12;
+  int32_t memory_size_ = 32;
+  int32_t cnn_module_kernel_ = 31;
+  int32_t left_context_length_ = 32 / 4;
+  int32_t chunk_length_ = 32;
+  int32_t right_context_length_ = 8;
+  int32_t d_model_ = 512;
+
+  std::vector<int32_t> encoder_input_indexes_;
+  std::vector<int32_t> encoder_output_indexes_;
+
+  std::vector<int32_t> decoder_input_indexes_;
+  std::vector<int32_t> decoder_output_indexes_;
+
+  std::vector<int32_t> joiner_input_indexes_;
+  std::vector<int32_t> joiner_output_indexes_;
+};
+
+}  // namespace sherpa_ncnn

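Segment() returns the hard-coded value 51; its comment derives that from chunk_length 32, right_context 8, and subsampling factor 4. A worked restatement of the arithmetic (the pad term 2 * subsampling_factor + 3 is taken directly from that comment, not re-derived here):

```cpp
// Worked form of the Segment() comment above.
int32_t chunk_length = 32;
int32_t right_context = 8;
int32_t subsampling_factor = 4;
int32_t pad = 2 * subsampling_factor + 3;                // 11
int32_t segment = chunk_length + (right_context + pad);  // 32 + 19 = 51
```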
+ 97 - 36
sherpa-ncnn/csrc/lstm-model.cc

@@ -17,67 +17,52 @@
  */
 #include "sherpa-ncnn/csrc/lstm-model.h"
 
-#include <iostream>
 #include <utility>
 #include <vector>
 
 namespace sherpa_ncnn {
 
-static void InitNet(ncnn::Net &net, const std::string &param,
-                    const std::string &bin) {
-  if (net.load_param(param.c_str())) {
-    std::cerr << "failed to load " << param << "\n";
-    exit(-1);
-  }
-
-  if (net.load_model(bin.c_str())) {
-    std::cerr << "failed to load " << bin << "\n";
-    exit(-1);
-  }
-}
-
 LstmModel::LstmModel(const ModelConfig &config)
     : num_threads_(config.num_threads) {
   InitEncoder(config.encoder_param, config.encoder_bin);
   InitDecoder(config.decoder_param, config.decoder_bin);
   InitJoiner(config.joiner_param, config.joiner_bin);
+
+  InitEncoderInputOutputIndexes();
+  InitDecoderInputOutputIndexes();
+  InitJoinerInputOutputIndexes();
 }
 
 std::pair<ncnn::Mat, std::vector<ncnn::Mat>> LstmModel::RunEncoder(
     ncnn::Mat &features, const std::vector<ncnn::Mat> &states) {
-  int32_t num_encoder_layers = 12;
-  int32_t d_model = 512;
-  int32_t rnn_hidden_size = 1024;
   ncnn::Mat hx;
   ncnn::Mat cx;
 
   if (states.empty()) {
-    hx.create(d_model, num_encoder_layers);
-    cx.create(rnn_hidden_size, num_encoder_layers);
-
-    hx.fill(0);
-    cx.fill(0);
+    auto s = GetEncoderInitStates();
+    hx = s[0];
+    cx = s[1];
   } else {
     hx = states[0];
     cx = states[1];
   }
 
-  ncnn::Mat feature_lengths(1);
-  feature_lengths[0] = features.h;
+  ncnn::Mat feature_length(1);
+  feature_length[0] = features.h;
 
   ncnn::Extractor encoder_ex = encoder_.create_extractor();
   encoder_ex.set_num_threads(num_threads_);
 
-  encoder_ex.input("in0", features);
-  encoder_ex.input("in1", feature_lengths);
-  encoder_ex.input("in2", hx);
-  encoder_ex.input("in3", cx);
+  encoder_ex.input(encoder_input_indexes_[0], features);
+  encoder_ex.input(encoder_input_indexes_[1], feature_length);
+  encoder_ex.input(encoder_input_indexes_[2], hx);
+  encoder_ex.input(encoder_input_indexes_[3], cx);
 
   ncnn::Mat encoder_out;
-  encoder_ex.extract("out0", encoder_out);
+  encoder_ex.extract(encoder_output_indexes_[0], encoder_out);
 
-  encoder_ex.extract("out2", hx);
-  encoder_ex.extract("out3", cx);
+  encoder_ex.extract(encoder_output_indexes_[1], hx);
+  encoder_ex.extract(encoder_output_indexes_[2], cx);
 
   std::vector<ncnn::Mat> next_states = {hx, cx};
 
@@ -89,8 +74,8 @@ ncnn::Mat LstmModel::RunDecoder(ncnn::Mat &decoder_input) {
   decoder_ex.set_num_threads(num_threads_);
 
   ncnn::Mat decoder_out;
-  decoder_ex.input("in0", decoder_input);
-  decoder_ex.extract("out0", decoder_out);
+  decoder_ex.input(decoder_input_indexes_[0], decoder_input);
+  decoder_ex.extract(decoder_output_indexes_[0], decoder_out);
   decoder_out = decoder_out.reshape(decoder_out.w);
 
   return decoder_out;
@@ -99,11 +84,11 @@ ncnn::Mat LstmModel::RunDecoder(ncnn::Mat &decoder_input) {
 ncnn::Mat LstmModel::RunJoiner(ncnn::Mat &encoder_out, ncnn::Mat &decoder_out) {
   auto joiner_ex = joiner_.create_extractor();
   joiner_ex.set_num_threads(num_threads_);
-  joiner_ex.input("in0", encoder_out);
-  joiner_ex.input("in1", decoder_out);
+  joiner_ex.input(joiner_input_indexes_[0], encoder_out);
+  joiner_ex.input(joiner_input_indexes_[1], decoder_out);
 
   ncnn::Mat joiner_out;
-  joiner_ex.extract("out0", joiner_out);
+  joiner_ex.extract(joiner_output_indexes_[0], joiner_out);
   return joiner_out;
 }
 
@@ -124,4 +109,80 @@ void LstmModel::InitJoiner(const std::string &joiner_param,
   InitNet(joiner_, joiner_param, joiner_bin);
 }
 
+std::vector<ncnn::Mat> LstmModel::GetEncoderInitStates() const {
+  int32_t num_encoder_layers = 12;
+  int32_t d_model = 512;
+  int32_t rnn_hidden_size = 1024;
+
+  auto hx = ncnn::Mat(d_model, num_encoder_layers);
+  auto cx = ncnn::Mat(rnn_hidden_size, num_encoder_layers);
+
+  hx.fill(0);
+  cx.fill(0);
+
+  return {hx, cx};
+}
+
+void LstmModel::InitEncoderInputOutputIndexes() {
+  // input indexes map
+  // [0] -> in0, features,
+  // [1] -> in1, features_length
+  // [2] -> in2, hx
+  // [3] -> in3, cx
+  encoder_input_indexes_.resize(4);
+
+  // output indexes map
+  // [0] -> out0, encoder_out
+  // [1] -> out2, hx
+  // [2] -> out3, cx
+  encoder_output_indexes_.resize(3);
+  const auto &blobs = encoder_.blobs();
+  for (int32_t i = 0; i != blobs.size(); ++i) {
+    const auto &b = blobs[i];
+    if (b.name == "in0") encoder_input_indexes_[0] = i;
+    if (b.name == "in1") encoder_input_indexes_[1] = i;
+    if (b.name == "in2") encoder_input_indexes_[2] = i;
+    if (b.name == "in3") encoder_input_indexes_[3] = i;
+    if (b.name == "out0") encoder_output_indexes_[0] = i;
+    if (b.name == "out2") encoder_output_indexes_[1] = i;
+    if (b.name == "out3") encoder_output_indexes_[2] = i;
+  }
+}
+
+void LstmModel::InitDecoderInputOutputIndexes() {
+  // input indexes map
+  // [0] -> in0, decoder_input,
+  decoder_input_indexes_.resize(1);
+
+  // output indexes map
+  // [0] -> out0, decoder_out,
+  decoder_output_indexes_.resize(1);
+
+  const auto &blobs = decoder_.blobs();
+  for (int32_t i = 0; i != blobs.size(); ++i) {
+    const auto &b = blobs[i];
+    if (b.name == "in0") decoder_input_indexes_[0] = i;
+    if (b.name == "out0") decoder_output_indexes_[0] = i;
+  }
+}
+
+void LstmModel::InitJoinerInputOutputIndexes() {
+  // input indexes map
+  // [0] -> in0, encoder_input,
+  // [1] -> in1, decoder_input,
+  joiner_input_indexes_.resize(2);
+
+  // output indexes map
+  // [0] -> out0, joiner_out,
+  joiner_output_indexes_.resize(1);
+
+  const auto &blobs = joiner_.blobs();
+  for (int32_t i = 0; i != blobs.size(); ++i) {
+    const auto &b = blobs[i];
+    if (b.name == "in0") joiner_input_indexes_[0] = i;
+    if (b.name == "in1") joiner_input_indexes_[1] = i;
+    if (b.name == "out0") joiner_output_indexes_[0] = i;
+  }
+}
+
 }  // namespace sherpa_ncnn

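The refactor in lstm-model.cc above replaces per-call string lookups ("in0", "out0") with integer blob indexes resolved once in the new Init*InputOutputIndexes() methods. ncnn::Extractor accepts either form, so the two calls sketched below bind the same blob; caching the index just avoids a name search on every chunk:

```cpp
// The two equivalent ways to bind an input blob in ncnn (only one is
// needed); this commit switches from the first to the second.
ncnn::Extractor ex = encoder_.create_extractor();
ex.input("in0", features);                      // before: lookup by blob name
ex.input(encoder_input_indexes_[0], features);  // after: cached blob index
```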
+ 14 - 0
sherpa-ncnn/csrc/lstm-model.h

@@ -73,10 +73,24 @@ class LstmModel : public Model {
   void InitJoiner(const std::string &joiner_param,
                   const std::string &joiner_bin);
 
+  std::vector<ncnn::Mat> GetEncoderInitStates() const;
+
+  void InitEncoderInputOutputIndexes();
+  void InitDecoderInputOutputIndexes();
+  void InitJoinerInputOutputIndexes();
+
  private:
   ncnn::Net encoder_;
   ncnn::Net decoder_;
   ncnn::Net joiner_;
+  std::vector<int32_t> encoder_input_indexes_;
+  std::vector<int32_t> encoder_output_indexes_;
+
+  std::vector<int32_t> decoder_input_indexes_;
+  std::vector<int32_t> decoder_output_indexes_;
+
+  std::vector<int32_t> joiner_input_indexes_;
+  std::vector<int32_t> joiner_output_indexes_;
 
   int32_t num_threads_;
 };

+ 38 - 1
sherpa-ncnn/csrc/model.cc

@@ -19,6 +19,7 @@
 
 #include <sstream>
 
+#include "sherpa-ncnn/csrc/conv-emformer-model.h"
 #include "sherpa-ncnn/csrc/lstm-model.h"
 
 namespace sherpa_ncnn {
@@ -41,7 +42,7 @@ std::string ModelConfig::ToString() const {
 
 static bool IsLstmModel(const ncnn::Net &net) {
   for (const auto &layer : net.layers()) {
-    if (layer->type == "LSTM" || layer->type == "LSTM2") {
+    if (layer->type == "LSTM") {
       return true;
     }
   }
@@ -49,6 +50,38 @@ static bool IsLstmModel(const ncnn::Net &net) {
   return false;
 }
 
+static bool IsConvEmformerModel(const ncnn::Net &net) {
+  // Note: We may need to add more constraints if the number of models grows.
+  if (net.input_indexes().size() < 49) {
+    return false;
+  }
+
+  if (net.output_indexes().size() < 49) {
+    return false;
+  }
+
+  for (const auto &layer : net.layers()) {
+    if (layer->type == "GLU") {
+      return true;
+    }
+  }
+
+  return false;
+}
+
+void Model::InitNet(ncnn::Net &net, const std::string &param,
+                    const std::string &bin) {
+  if (net.load_param(param.c_str())) {
+    NCNN_LOGE("failed to load %s", param.c_str());
+    exit(-1);
+  }
+
+  if (net.load_model(bin.c_str())) {
+    NCNN_LOGE("failed to load %s", bin.c_str());
+    exit(-1);
+  }
+}
+
 std::unique_ptr<Model> Model::Create(const ModelConfig &config) {
   // 1. Load the encoder network
   // 2. If the encoder network has LSTM layers, we assume it is a LstmModel
@@ -67,6 +100,10 @@ std::unique_ptr<Model> Model::Create(const ModelConfig &config) {
     return std::make_unique<LstmModel>(config);
   }
 
+  if (IsConvEmformerModel(net)) {
+    return std::make_unique<ConvEmformerModel>(config);
+  }
+
   return nullptr;
 }
 

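Model::Create loads only the encoder network and sniffs it: an LSTM layer selects LstmModel, while at least 49 inputs and 49 outputs (1 + 12 layers * 4 states) plus a GLU layer select ConvEmformerModel; anything else yields nullptr, so callers should check the result. A minimal caller sketch (the file names are placeholders, not paths shipped with this commit):

```cpp
// Hypothetical caller of Model::Create(); all paths are placeholders.
sherpa_ncnn::ModelConfig config;
config.encoder_param = "encoder.ncnn.param";
config.encoder_bin = "encoder.ncnn.bin";
// ... decoder/joiner paths and num_threads set likewise ...
auto model = sherpa_ncnn::Model::Create(config);
if (!model) {
  NCNN_LOGE("Unrecognized model; expected an LSTM or ConvEmformer encoder");
  exit(-1);
}
```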
+ 4 - 1
sherpa-ncnn/csrc/model.h

@@ -39,10 +39,13 @@ struct ModelConfig {
 
 class Model {
  public:
+  virtual ~Model() = default;
+
   /** Create a model from a config. */
   static std::unique_ptr<Model> Create(const ModelConfig &config);
 
-  virtual ~Model() = default;
+  static void InitNet(ncnn::Net &net, const std::string &param,
+                      const std::string &bin);
 
   /** Run the encoder network.
    *