浏览代码

Refactor the code to make it easier to support other types of models. (#24)

* Refactor LSTM model

* Refactor the code to support other models
Fangjun Kuang 2 年之前
父节点
当前提交
6f355b4211

+ 1 - 0
sherpa-ncnn/csrc/CMakeLists.txt

@@ -4,6 +4,7 @@ set(sherpa_ncnn_core_srcs
   decode.cc
   decode.cc
   features.cc
   features.cc
   lstm-model.cc
   lstm-model.cc
+  model.cc
   symbol-table.cc
   symbol-table.cc
   wave-reader.cc
   wave-reader.cc
 )
 )

+ 4 - 4
sherpa-ncnn/csrc/decode.cc

@@ -20,15 +20,15 @@
 
 
 namespace sherpa_ncnn {
 namespace sherpa_ncnn {
 
 
-void GreedySearch(LstmModel &model, ncnn::Mat &encoder_out,
-                  ncnn::Mat *decoder_out, std::vector<int32_t> *hyp) {
+void GreedySearch(Model *model, ncnn::Mat &encoder_out, ncnn::Mat *decoder_out,
+                  std::vector<int32_t> *hyp) {
   int32_t context_size = 2;
   int32_t context_size = 2;
   int32_t blank_id = 0;  // hard-code it to 0
   int32_t blank_id = 0;  // hard-code it to 0
   ncnn::Mat decoder_input(context_size);
   ncnn::Mat decoder_input(context_size);
 
 
   for (int32_t t = 0; t != encoder_out.h; ++t) {
   for (int32_t t = 0; t != encoder_out.h; ++t) {
     ncnn::Mat encoder_out_t(encoder_out.w, encoder_out.row(t));
     ncnn::Mat encoder_out_t(encoder_out.w, encoder_out.row(t));
-    ncnn::Mat joiner_out = model.RunJoiner(encoder_out_t, *decoder_out);
+    ncnn::Mat joiner_out = model->RunJoiner(encoder_out_t, *decoder_out);
 
 
     auto y = static_cast<int32_t>(std::distance(
     auto y = static_cast<int32_t>(std::distance(
         static_cast<const float *>(joiner_out),
         static_cast<const float *>(joiner_out),
@@ -41,7 +41,7 @@ void GreedySearch(LstmModel &model, ncnn::Mat &encoder_out,
       static_cast<int32_t *>(decoder_input)[1] = y;
       static_cast<int32_t *>(decoder_input)[1] = y;
       hyp->push_back(y);
       hyp->push_back(y);
 
 
-      *decoder_out = model.RunDecoder(decoder_input);
+      *decoder_out = model->RunDecoder(decoder_input);
     }
     }
   }
   }
 }
 }

+ 4 - 4
sherpa-ncnn/csrc/decode.h

@@ -22,13 +22,13 @@
 #include <vector>
 #include <vector>
 
 
 #include "net.h"  // NOLINT
 #include "net.h"  // NOLINT
-#include "sherpa-ncnn/csrc/lstm-model.h"
+#include "sherpa-ncnn/csrc/model.h"
 
 
 namespace sherpa_ncnn {
 namespace sherpa_ncnn {
 
 
 /**
 /**
  *
  *
- * @param model  The LstmModel
+ * @param model  The neural network.
  * @param encoder_out  Its shape is (num_frames, encoder_out_dim).
  * @param encoder_out  Its shape is (num_frames, encoder_out_dim).
  *                     encoder_out.w == encoder_out_dim
  *                     encoder_out.w == encoder_out_dim
  *                     encoder_out.h == num_frames
  *                     encoder_out.h == num_frames
@@ -37,8 +37,8 @@ namespace sherpa_ncnn {
  *                     decoder_out.h == 1
  *                     decoder_out.h == 1
  * @param hyp The recognition result. It is changed in place.
  * @param hyp The recognition result. It is changed in place.
  */
  */
-void GreedySearch(LstmModel &model, ncnn::Mat &encoder_out,
-                  ncnn::Mat *decoder_out, std::vector<int32_t> *hyp);
+void GreedySearch(Model *model, ncnn::Mat &encoder_out, ncnn::Mat *decoder_out,
+                  std::vector<int32_t> *hyp);
 
 
 }  // namespace sherpa_ncnn
 }  // namespace sherpa_ncnn
 
 

+ 28 - 24
sherpa-ncnn/csrc/lstm-model.cc

@@ -18,6 +18,8 @@
 #include "sherpa-ncnn/csrc/lstm-model.h"
 #include "sherpa-ncnn/csrc/lstm-model.h"
 
 
 #include <iostream>
 #include <iostream>
+#include <utility>
+#include <vector>
 
 
 namespace sherpa_ncnn {
 namespace sherpa_ncnn {
 
 
@@ -34,30 +36,30 @@ static void InitNet(ncnn::Net &net, const std::string &param,
   }
   }
 }
 }
 
 
-LstmModel::LstmModel(const std::string &encoder_param,
-                     const std::string &encoder_bin,
-                     const std::string &decoder_param,
-                     const std::string &decoder_bin,
-                     const std::string &joiner_param,
-                     const std::string &joiner_bin, int32_t num_threads)
-    : num_threads_(num_threads) {
-  InitEncoder(encoder_param, encoder_bin);
-  InitDecoder(decoder_param, decoder_bin);
-  InitJoiner(joiner_param, joiner_bin);
+LstmModel::LstmModel(const ModelConfig &config)
+    : num_threads_(config.num_threads) {
+  InitEncoder(config.encoder_param, config.encoder_bin);
+  InitDecoder(config.decoder_param, config.decoder_bin);
+  InitJoiner(config.joiner_param, config.joiner_bin);
 }
 }
 
 
-ncnn::Mat LstmModel::RunEncoder(ncnn::Mat &features, ncnn::Mat *hx,
-                                ncnn::Mat *cx) {
+std::pair<ncnn::Mat, std::vector<ncnn::Mat>> LstmModel::RunEncoder(
+    ncnn::Mat &features, const std::vector<ncnn::Mat> &states) {
   int32_t num_encoder_layers = 12;
   int32_t num_encoder_layers = 12;
   int32_t d_model = 512;
   int32_t d_model = 512;
   int32_t rnn_hidden_size = 1024;
   int32_t rnn_hidden_size = 1024;
-
-  if (hx->empty()) {
-    hx->create(d_model, num_encoder_layers);
-    cx->create(rnn_hidden_size, num_encoder_layers);
-
-    hx->fill(0);
-    cx->fill(0);
+  ncnn::Mat hx;
+  ncnn::Mat cx;
+
+  if (states.empty()) {
+    hx.create(d_model, num_encoder_layers);
+    cx.create(rnn_hidden_size, num_encoder_layers);
+
+    hx.fill(0);
+    cx.fill(0);
+  } else {
+    hx = states[0];
+    cx = states[1];
   }
   }
 
 
   ncnn::Mat feature_lengths(1);
   ncnn::Mat feature_lengths(1);
@@ -68,16 +70,18 @@ ncnn::Mat LstmModel::RunEncoder(ncnn::Mat &features, ncnn::Mat *hx,
 
 
   encoder_ex.input("in0", features);
   encoder_ex.input("in0", features);
   encoder_ex.input("in1", feature_lengths);
   encoder_ex.input("in1", feature_lengths);
-  encoder_ex.input("in2", *hx);
-  encoder_ex.input("in3", *cx);
+  encoder_ex.input("in2", hx);
+  encoder_ex.input("in3", cx);
 
 
   ncnn::Mat encoder_out;
   ncnn::Mat encoder_out;
   encoder_ex.extract("out0", encoder_out);
   encoder_ex.extract("out0", encoder_out);
 
 
-  encoder_ex.extract("out2", *hx);
-  encoder_ex.extract("out3", *cx);
+  encoder_ex.extract("out2", hx);
+  encoder_ex.extract("out3", cx);
+
+  std::vector<ncnn::Mat> next_states = {hx, cx};
 
 
-  return encoder_out;
+  return {encoder_out, next_states};
 }
 }
 
 
 ncnn::Mat LstmModel::RunDecoder(ncnn::Mat &decoder_input) {
 ncnn::Mat LstmModel::RunDecoder(ncnn::Mat &decoder_input) {

+ 25 - 40
sherpa-ncnn/csrc/lstm-model.h

@@ -20,65 +20,50 @@
 #define SHERPA_NCNN_CSRC_LSTM_MODEL_H_
 #define SHERPA_NCNN_CSRC_LSTM_MODEL_H_
 
 
 #include <string>
 #include <string>
+#include <utility>
+#include <vector>
 
 
 #include "net.h"  // NOLINT
 #include "net.h"  // NOLINT
+#include "sherpa-ncnn/csrc/model.h"
 
 
 namespace sherpa_ncnn {
 namespace sherpa_ncnn {
 
 
-class LstmModel {
+class LstmModel : public Model {
  public:
  public:
-  /**
-   * @param encoder_param Path to encoder.ncnn.param
-   * @param encoder_bin Path to encoder.ncnn.bin
-   * @param decoder_param Path to decoder.ncnn.param
-   * @param decoder_bin Path to decoder.ncnn.bin
-   * @param joiner_param Path to joiner.ncnn.param
-   * @param joiner_bin Path to joiner.ncnn.bin
-   * @param num_threads Number of threads to use when running the network
-   */
-  LstmModel(const std::string &encoder_param, const std::string &encoder_bin,
-            const std::string &decoder_param, const std::string &decoder_bin,
-            const std::string &joiner_param, const std::string &joiner_bin,
-            int32_t num_threads);
+  explicit LstmModel(const ModelConfig &config);
 
 
   /** Run the encoder network.
   /** Run the encoder network.
    *
    *
    * @param features  A 2-d mat of shape (num_frames, feature_dim).
    * @param features  A 2-d mat of shape (num_frames, feature_dim).
    *                  Note: features.w = feature_dim.
    *                  Note: features.w = feature_dim.
    *                        features.h = num_frames.
    *                        features.h = num_frames.
-   * @param hx  Hidden state of the LSTM model. You can leave it to empty
-   *            on the first invocation. It is changed in-place.
+   * @param states Contains two tensors:
+   *          - hx  Hidden state of the LSTM model. You can leave it to empty
+   *                on the first invocation. It is changed in-place.
    *
    *
-   * @param cx  Hidden cell state of the LSTM model. You can leave it to empty
-   *            on the first invocation. It is changed in-place.
+   *          - cx  Hidden cell state of the LSTM model. You can leave it to
+   *                empty on the first invocation. It is changed in-place.
    *
    *
-   * @return Return the output of the encoder. Its shape is
-   *  (num_out_frames, encoder_dim).
-   *  Note: ans.w == encoder_dim; ans.h == num_out_frames
-   */
-  ncnn::Mat RunEncoder(ncnn::Mat &features, ncnn::Mat *hx, ncnn::Mat *cx);
-
-  /** Run the decoder network.
+   *          - Note: on the first invocation, you can pass an empty vector.
    *
    *
-   * @param  decoder_input A mat of shape (context_size,). Note: Its underlying
-   *                       content consists of integers, though its type is
-   *                       float.
+   * @return Return a pair containing:
+   *   - the output of the encoder. Its shape is (num_out_frames, encoder_dim).
+   *     Note: ans.w == encoder_dim; ans.h == num_out_frames
    *
    *
-   * @return Return a mat of shape (decoder_dim,)
+   *   - next_states, a vector containing hx and cx for the next invocation
    */
    */
-  ncnn::Mat RunDecoder(ncnn::Mat &decoder_input);
+  std::pair<ncnn::Mat, std::vector<ncnn::Mat>> RunEncoder(
+      ncnn::Mat &features, const std::vector<ncnn::Mat> &states) override;
 
 
-  /** Run the joiner network.
-   *
-   * @param encoder_out  A mat of shape (encoder_dim,)
-   * @param decoder_out  A mat of shape (decoder_dim,)
-   *
-   * @return Return the joiner output which is of shape (vocab_size,)
-   */
-  ncnn::Mat RunJoiner(ncnn::Mat &encoder_out, ncnn::Mat &decoder_out);
+  ncnn::Mat RunDecoder(ncnn::Mat &decoder_input) override;
+
+  ncnn::Mat RunJoiner(ncnn::Mat &encoder_out, ncnn::Mat &decoder_out) override;
+
+  int32_t Segment() const override { return 9; }
 
 
-  int32_t ContextSize() const { return 2; }
-  int32_t BlankId() const { return 0; }
+  // Advance the feature extract by this number of frames after
+  // running the encoder network
+  int32_t Offset() const override { return 4; }
 
 
  private:
  private:
   void InitEncoder(const std::string &encoder_param,
   void InitEncoder(const std::string &encoder_param,

+ 73 - 0
sherpa-ncnn/csrc/model.cc

@@ -0,0 +1,73 @@
+/**
+ * Copyright (c)  2022  Xiaomi Corporation (authors: Fangjun Kuang)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "sherpa-ncnn/csrc/model.h"
+
+#include <sstream>
+
+#include "sherpa-ncnn/csrc/lstm-model.h"
+
+namespace sherpa_ncnn {
+
+std::string ModelConfig::ToString() const {
+  std::ostringstream os;
+  os << "encoder_param: " << encoder_param << "\n";
+  os << "encoder_bin: " << encoder_bin << "\n";
+
+  os << "decoder_param: " << decoder_param << "\n";
+  os << "decoder_bin: " << decoder_bin << "\n";
+
+  os << "joiner_param: " << joiner_param << "\n";
+  os << "joiner_bin: " << joiner_bin << "\n";
+
+  os << "num_threads: " << num_threads << "\n";
+
+  return os.str();
+}
+
+static bool IsLstmModel(const ncnn::Net &net) {
+  for (const auto &layer : net.layers()) {
+    if (layer->type == "LSTM" || layer->type == "LSTM2") {
+      return true;
+    }
+  }
+
+  return false;
+}
+
+std::unique_ptr<Model> Model::Create(const ModelConfig &config) {
+  // 1. Load the encoder network
+  // 2. If the encoder network has LSTM layers, we assume it is a LstmModel
+  // 3. Otherwise, we assume it is a ConvEmformer
+  // 4. TODO(fangjun): We need to change this function to support more models
+  // in the future
+
+  ncnn::Net net;
+  auto ret = net.load_param(config.encoder_param.c_str());
+  if (ret != 0) {
+    NCNN_LOGE("Failed to load %s", config.encoder_param.c_str());
+    return nullptr;
+  }
+
+  if (IsLstmModel(net)) {
+    return std::make_unique<LstmModel>(config);
+  }
+
+  return nullptr;
+}
+
+}  // namespace sherpa_ncnn

+ 96 - 0
sherpa-ncnn/csrc/model.h

@@ -0,0 +1,96 @@
+/**
+ * Copyright (c)  2022  Xiaomi Corporation (authors: Fangjun Kuang)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef SHERPA_NCNN_CSRC_MODEL_H_
+#define SHERPA_NCNN_CSRC_MODEL_H_
+
+#include <memory>
+#include <string>
+
+#include "net.h"  // NOLINT
+
+namespace sherpa_ncnn {
+
+struct ModelConfig {
+  std::string encoder_param;  // path to encoder.ncnn.param
+  std::string encoder_bin;    // path to encoder.ncnn.bin
+  std::string decoder_param;  // path to decoder.ncnn.param
+  std::string decoder_bin;    // path to decoder.ncnn.bin
+  std::string joiner_param;   // path to joiner.ncnn.param
+  std::string joiner_bin;     // path to joiner.ncnn.bin
+  int32_t num_threads;        // number of threads to run the model
+  std::string ToString() const;
+};
+
+class Model {
+ public:
+  /** Create a model from a config. */
+  static std::unique_ptr<Model> Create(const ModelConfig &config);
+
+  virtual ~Model() = default;
+
+  /** Run the encoder network.
+   *
+   * @param features  A 2-d mat of shape (num_frames, feature_dim).
+   *                  Note: features.w = feature_dim.
+   *                        features.h = num_frames.
+   * @param states It contains the states for the encoder network. Its exact
+   *               content is determined by the underlying network.
+   *
+   * @return Return a pair containing:
+   *   - encoder_out
+   *   - next_states
+   */
+  virtual std::pair<ncnn::Mat, std::vector<ncnn::Mat>> RunEncoder(
+      ncnn::Mat &features, const std::vector<ncnn::Mat> &states) = 0;
+
+  /** Run the decoder network.
+   *
+   * @param  decoder_input A mat of shape (context_size,). Note: Its underlying
+   *                       content consists of integers, though its type is
+   *                       float.
+   *
+   * @return Return a mat of shape (decoder_dim,)
+   */
+  virtual ncnn::Mat RunDecoder(ncnn::Mat &decoder_input) = 0;
+
+  /** Run the joiner network.
+   *
+   * @param encoder_out  A mat of shape (encoder_dim,)
+   * @param decoder_out  A mat of shape (decoder_dim,)
+   *
+   * @return Return the joiner output which is of shape (vocab_size,)
+   */
+  virtual ncnn::Mat RunJoiner(ncnn::Mat &encoder_out,
+                              ncnn::Mat &decoder_out) = 0;
+
+  virtual int32_t ContextSize() const { return 2; }
+
+  virtual int32_t BlankId() const { return 0; }
+
+  // The encoder takes this number of frames as input
+  virtual int32_t Segment() const = 0;
+
+  // Advance the feature extractor by this number of frames after
+  // running the encoder network
+  virtual int32_t Offset() const = 0;
+};
+
+}  // namespace sherpa_ncnn
+
+#endif  // SHERPA_NCNN_CSRC_MODEL_H_

+ 28 - 21
sherpa-ncnn/csrc/sherpa-ncnn-microphone.cc

@@ -23,8 +23,8 @@
 #include "portaudio.h"  // NOLINT
 #include "portaudio.h"  // NOLINT
 #include "sherpa-ncnn/csrc/decode.h"
 #include "sherpa-ncnn/csrc/decode.h"
 #include "sherpa-ncnn/csrc/features.h"
 #include "sherpa-ncnn/csrc/features.h"
-#include "sherpa-ncnn/csrc/lstm-model.h"
 #include "sherpa-ncnn/csrc/microphone.h"
 #include "sherpa-ncnn/csrc/microphone.h"
+#include "sherpa-ncnn/csrc/model.h"
 #include "sherpa-ncnn/csrc/symbol-table.h"
 #include "sherpa-ncnn/csrc/symbol-table.h"
 
 
 bool stop = false;
 bool stop = false;
@@ -71,25 +71,29 @@ https://huggingface.co/csukuangfj/sherpa-ncnn-2022-09-05
   }
   }
   signal(SIGINT, Handler);
   signal(SIGINT, Handler);
 
 
+  sherpa_ncnn::ModelConfig config;
+
   std::string tokens = argv[1];
   std::string tokens = argv[1];
-  std::string encoder_param = argv[2];
-  std::string encoder_bin = argv[3];
-  std::string decoder_param = argv[4];
-  std::string decoder_bin = argv[5];
-  std::string joiner_param = argv[6];
-  std::string joiner_bin = argv[7];
-
-  int32_t num_threads = 4;
+  config.encoder_param = argv[2];
+  config.encoder_bin = argv[3];
+  config.decoder_param = argv[4];
+  config.decoder_bin = argv[5];
+  config.joiner_param = argv[6];
+  config.joiner_bin = argv[7];
+
+  config.num_threads = 4;
   if (argc == 9) {
   if (argc == 9) {
-    num_threads = atoi(argv[8]);
+    config.num_threads = atoi(argv[8]);
   }
   }
 
 
   sherpa_ncnn::SymbolTable sym(tokens);
   sherpa_ncnn::SymbolTable sym(tokens);
-  fprintf(stderr, "Number of threads: %d\n", num_threads);
+  fprintf(stderr, "%s\n", config.ToString().c_str());
 
 
-  sherpa_ncnn::LstmModel model(encoder_param, encoder_bin, decoder_param,
-                               decoder_bin, joiner_param, joiner_bin,
-                               num_threads);
+  auto model = sherpa_ncnn::Model::Create(config);
+  if (!model) {
+    fprintf(stderr, "Failed to create a model\n");
+    exit(EXIT_FAILURE);
+  }
 
 
   sherpa_ncnn::Microphone mic;
   sherpa_ncnn::Microphone mic;
 
 
@@ -139,11 +143,11 @@ https://huggingface.co/csukuangfj/sherpa-ncnn-2022-09-05
     exit(EXIT_FAILURE);
     exit(EXIT_FAILURE);
   }
   }
 
 
-  int32_t segment = 9;
-  int32_t offset = 4;
+  int32_t segment = model->Segment();
+  int32_t offset = model->Offset();
 
 
-  int32_t context_size = model.ContextSize();
-  int32_t blank_id = model.BlankId();
+  int32_t context_size = model->ContextSize();
+  int32_t blank_id = model->BlankId();
 
 
   std::vector<int32_t> hyp(context_size, blank_id);
   std::vector<int32_t> hyp(context_size, blank_id);
 
 
@@ -152,7 +156,7 @@ https://huggingface.co/csukuangfj/sherpa-ncnn-2022-09-05
     static_cast<int32_t *>(decoder_input)[i] = blank_id;
     static_cast<int32_t *>(decoder_input)[i] = blank_id;
   }
   }
 
 
-  ncnn::Mat decoder_out = model.RunDecoder(decoder_input);
+  ncnn::Mat decoder_out = model->RunDecoder(decoder_input);
 
 
   ncnn::Mat hx;
   ncnn::Mat hx;
   ncnn::Mat cx;
   ncnn::Mat cx;
@@ -160,14 +164,17 @@ https://huggingface.co/csukuangfj/sherpa-ncnn-2022-09-05
   int32_t num_tokens = hyp.size();
   int32_t num_tokens = hyp.size();
   int32_t num_processed = 0;
   int32_t num_processed = 0;
 
 
+  std::vector<ncnn::Mat> states;
+  ncnn::Mat encoder_out;
+
   while (!stop) {
   while (!stop) {
     while (feature_extractor.NumFramesReady() - num_processed >= segment) {
     while (feature_extractor.NumFramesReady() - num_processed >= segment) {
       ncnn::Mat features = feature_extractor.GetFrames(num_processed, segment);
       ncnn::Mat features = feature_extractor.GetFrames(num_processed, segment);
       num_processed += offset;
       num_processed += offset;
 
 
-      ncnn::Mat encoder_out = model.RunEncoder(features, &hx, &cx);
+      std::tie(encoder_out, states) = model->RunEncoder(features, states);
 
 
-      GreedySearch(model, encoder_out, &decoder_out, &hyp);
+      GreedySearch(model.get(), encoder_out, &decoder_out, &hyp);
     }
     }
 
 
     if (hyp.size() != num_tokens) {
     if (hyp.size() != num_tokens) {

+ 28 - 22
sherpa-ncnn/csrc/sherpa-ncnn.cc

@@ -23,7 +23,7 @@
 #include "net.h"  // NOLINT
 #include "net.h"  // NOLINT
 #include "sherpa-ncnn/csrc/decode.h"
 #include "sherpa-ncnn/csrc/decode.h"
 #include "sherpa-ncnn/csrc/features.h"
 #include "sherpa-ncnn/csrc/features.h"
-#include "sherpa-ncnn/csrc/lstm-model.h"
+#include "sherpa-ncnn/csrc/model.h"
 #include "sherpa-ncnn/csrc/symbol-table.h"
 #include "sherpa-ncnn/csrc/symbol-table.h"
 #include "sherpa-ncnn/csrc/wave-reader.h"
 #include "sherpa-ncnn/csrc/wave-reader.h"
 
 
@@ -89,29 +89,35 @@ https://huggingface.co/csukuangfj/sherpa-ncnn-2022-09-05
 
 
     return 0;
     return 0;
   }
   }
+  sherpa_ncnn::ModelConfig config;
+
   std::string tokens = argv[1];
   std::string tokens = argv[1];
-  std::string encoder_param = argv[2];
-  std::string encoder_bin = argv[3];
-  std::string decoder_param = argv[4];
-  std::string decoder_bin = argv[5];
-  std::string joiner_param = argv[6];
-  std::string joiner_bin = argv[7];
+
+  config.encoder_param = argv[2];
+  config.encoder_bin = argv[3];
+  config.decoder_param = argv[4];
+  config.decoder_bin = argv[5];
+  config.joiner_param = argv[6];
+  config.joiner_bin = argv[7];
+
   std::string wav_filename = argv[8];
   std::string wav_filename = argv[8];
 
 
-  int32_t num_threads = 4;
+  config.num_threads = 4;
   if (argc == 10) {
   if (argc == 10) {
-    num_threads = atoi(argv[9]);
+    config.num_threads = atoi(argv[9]);
   }
   }
 
 
   float expected_sampling_rate = 16000;
   float expected_sampling_rate = 16000;
 
 
   sherpa_ncnn::SymbolTable sym(tokens);
   sherpa_ncnn::SymbolTable sym(tokens);
 
 
-  std::cout << "number of threads: " << num_threads << "\n";
+  std::cout << config.ToString() << "\n";
 
 
-  sherpa_ncnn::LstmModel model(encoder_param, encoder_bin, decoder_param,
-                               decoder_bin, joiner_param, joiner_bin,
-                               num_threads);
+  auto model = sherpa_ncnn::Model::Create(config);
+  if (!model) {
+    std::cout << "Failed to create a model\n";
+    exit(EXIT_FAILURE);
+  }
 
 
   std::vector<float> samples =
   std::vector<float> samples =
       sherpa_ncnn::ReadWave(wav_filename, expected_sampling_rate);
       sherpa_ncnn::ReadWave(wav_filename, expected_sampling_rate);
@@ -132,11 +138,11 @@ https://huggingface.co/csukuangfj/sherpa-ncnn-2022-09-05
 
 
   feature_extractor.InputFinished();
   feature_extractor.InputFinished();
 
 
-  int32_t segment = 9;
-  int32_t offset = 4;
+  int32_t segment = model->Segment();
+  int32_t offset = model->Offset();
 
 
-  int32_t context_size = model.ContextSize();
-  int32_t blank_id = model.BlankId();
+  int32_t context_size = model->ContextSize();
+  int32_t blank_id = model->BlankId();
 
 
   std::vector<int32_t> hyp(context_size, blank_id);
   std::vector<int32_t> hyp(context_size, blank_id);
 
 
@@ -145,19 +151,19 @@ https://huggingface.co/csukuangfj/sherpa-ncnn-2022-09-05
     static_cast<int32_t *>(decoder_input)[i] = blank_id;
     static_cast<int32_t *>(decoder_input)[i] = blank_id;
   }
   }
 
 
-  ncnn::Mat decoder_out = model.RunDecoder(decoder_input);
+  ncnn::Mat decoder_out = model->RunDecoder(decoder_input);
 
 
-  ncnn::Mat hx;
-  ncnn::Mat cx;
+  std::vector<ncnn::Mat> states;
+  ncnn::Mat encoder_out;
 
 
   int32_t num_processed = 0;
   int32_t num_processed = 0;
   while (feature_extractor.NumFramesReady() - num_processed >= segment) {
   while (feature_extractor.NumFramesReady() - num_processed >= segment) {
     ncnn::Mat features = feature_extractor.GetFrames(num_processed, segment);
     ncnn::Mat features = feature_extractor.GetFrames(num_processed, segment);
     num_processed += offset;
     num_processed += offset;
 
 
-    ncnn::Mat encoder_out = model.RunEncoder(features, &hx, &cx);
+    std::tie(encoder_out, states) = model->RunEncoder(features, states);
 
 
-    GreedySearch(model, encoder_out, &decoder_out, &hyp);
+    GreedySearch(model.get(), encoder_out, &decoder_out, &hyp);
   }
   }
 
 
   std::string text;
   std::string text;