Browse Source

Support streaming zipformer (#102)

* Support streaming zipformer

* Release v1.5.0
Fangjun Kuang 2 years ago
parent
commit
f15a247b57

+ 76 - 0
.github/scripts/run-test.sh

@@ -257,3 +257,79 @@ for wave in ${waves[@]}; do
 done
 
 rm -rf $repo
+
+log "------------------------------------------------------------"
+log "Run Zipformer transducer (English + Chinese, bilingual)"
+log "------------------------------------------------------------"
+repo_url=https://huggingface.co/csukuangfj/sherpa-ncnn-streaming-zipformer-bilingual-zh-en-2023-02-13
+log "Start testing ${repo_url}"
+repo=$(basename $repo_url)
+log "Download pretrained model and test-data from $repo_url"
+
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+pushd $repo
+git lfs pull --include "*.bin"
+popd
+
+waves=(
+$repo/test_wavs/0.wav
+$repo/test_wavs/1.wav
+$repo/test_wavs/2.wav
+$repo/test_wavs/3.wav
+$repo/test_wavs/4.wav
+)
+
+for wave in ${waves[@]}; do
+  for m in greedy_search modified_beam_search; do
+    time $EXE \
+      $repo/tokens.txt \
+      $repo/encoder_jit_trace-pnnx.ncnn.param \
+      $repo/encoder_jit_trace-pnnx.ncnn.bin \
+      $repo/decoder_jit_trace-pnnx.ncnn.param \
+      $repo/decoder_jit_trace-pnnx.ncnn.bin \
+      $repo/joiner_jit_trace-pnnx.ncnn.param \
+      $repo/joiner_jit_trace-pnnx.ncnn.bin \
+      $wave \
+      4 \
+      $m
+  done
+done
+
+rm -rf $repo
+
+log "------------------------------------------------------------"
+log "Run Zipformer transducer (English)"
+log "------------------------------------------------------------"
+repo_url=https://huggingface.co/csukuangfj/sherpa-ncnn-streaming-zipformer-en-2023-02-13
+log "Start testing ${repo_url}"
+repo=$(basename $repo_url)
+log "Download pretrained model and test-data from $repo_url"
+
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+pushd $repo
+git lfs pull --include "*.bin"
+popd
+
+waves=(
+$repo/test_wavs/1089-134686-0001.wav
+$repo/test_wavs/1221-135766-0001.wav
+$repo/test_wavs/1221-135766-0002.wav
+)
+
+for wave in ${waves[@]}; do
+  for m in greedy_search modified_beam_search; do
+    time $EXE \
+      $repo/tokens.txt \
+      $repo/encoder_jit_trace-pnnx.ncnn.param \
+      $repo/encoder_jit_trace-pnnx.ncnn.bin \
+      $repo/decoder_jit_trace-pnnx.ncnn.param \
+      $repo/decoder_jit_trace-pnnx.ncnn.bin \
+      $repo/joiner_jit_trace-pnnx.ncnn.param \
+      $repo/joiner_jit_trace-pnnx.ncnn.bin \
+      $wave \
+      4 \
+      $m
+  done
+done
+
+rm -rf $repo

+ 1 - 1
CMakeLists.txt

@@ -1,7 +1,7 @@
 cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
 project(sherpa-ncnn)
 
-set(SHERPA_NCNN_VERSION "1.4.3")
+set(SHERPA_NCNN_VERSION "1.5.0")
 
 # Disable warning about
 #

+ 12 - 11
cmake/ncnn.cmake

@@ -6,19 +6,19 @@ function(download_ncnn)
   # https://github.com/csukuangfj/ncnn/pull/7
 
   # Please also change ../pack-for-embedded-systems.sh
-  set(ncnn_URL "https://github.com/csukuangfj/ncnn/archive/refs/tags/sherpa-0.8.tar.gz")
-  set(ncnn_HASH "SHA256=f605c48986406800615d00cf14b955e95f73286eadacedb6c3371542540e1df0")
+  set(ncnn_URL "https://github.com/csukuangfj/ncnn/archive/refs/tags/sherpa-0.9.tar.gz")
+  set(ncnn_HASH "SHA256=a5fe1f69c75c06d6de858c7c660c43395b6ed3df9ee59d6e2fe621211e6928cd")
 
   # If you don't have access to the Internet, please download it to your
   # local drive and modify the following line according to your needs.
-  if(EXISTS "/star-fj/fangjun/download/github/ncnn-sherpa-0.8.tar.gz")
-    set(ncnn_URL  "file:///star-fj/fangjun/download/github/ncnn-sherpa-0.8.tar.gz")
-  elseif(EXISTS "/Users/fangjun/Downloads/ncnn-sherpa-0.8.tar.gz")
-    set(ncnn_URL  "file:///Users/fangjun/Downloads/ncnn-sherpa-0.8.tar.gz")
-  elseif(EXISTS "/tmp/ncnn-sherpa-0.8.tar.gz")
-    set(ncnn_URL  "file:///tmp/ncnn-sherpa-0.8.tar.gz")
-  elseif(EXISTS "$ENV{HOME}/asr/ncnn-sherpa-0.8.tar.gz")
-    set(ncnn_URL  "file://$ENV{HOME}/asr/ncnn-sherpa-0.8.tar.gz")
+  if(EXISTS "/star-fj/fangjun/download/github/ncnn-sherpa-0.9.tar.gz")
+    set(ncnn_URL  "file:///star-fj/fangjun/download/github/ncnn-sherpa-0.9.tar.gz")
+  elseif(EXISTS "/Users/fangjun/Downloads/ncnn-sherpa-0.9.tar.gz")
+    set(ncnn_URL  "file:///Users/fangjun/Downloads/ncnn-sherpa-0.9.tar.gz")
+  elseif(EXISTS "/tmp/ncnn-sherpa-0.9.tar.gz")
+    set(ncnn_URL  "file:///tmp/ncnn-sherpa-0.9.tar.gz")
+  elseif(EXISTS "$ENV{HOME}/asr/ncnn-sherpa-0.9.tar.gz")
+    set(ncnn_URL  "file://$ENV{HOME}/asr/ncnn-sherpa-0.9.tar.gz")
   endif()
 
   FetchContent_Declare(ncnn
@@ -64,7 +64,7 @@ function(download_ncnn)
     # Input
     Log
     LRN
-    MemoryData
+    # MemoryData
     MVN
     Pooling
     Power
@@ -146,6 +146,7 @@ function(download_ncnn)
     Fold
     Unfold
     GridSample
+    CumulativeSum
   )
 
   foreach(layer IN LISTS disabled_layers)

+ 1 - 1
sherpa-ncnn/csrc/CMakeLists.txt

@@ -14,6 +14,7 @@ set(sherpa_ncnn_core_srcs
   resample.cc
   symbol-table.cc
   wave-reader.cc
+  zipformer-model.cc
 )
 add_library(sherpa-ncnn-core ${sherpa_ncnn_core_srcs})
 target_link_libraries(sherpa-ncnn-core PUBLIC kaldi-native-fbank-core ncnn)
@@ -77,4 +78,3 @@ if(SHERPA_NCNN_ENABLE_TEST)
   add_executable(test-resample test-resample.cc)
   target_link_libraries(test-resample sherpa-ncnn-core)
 endif()
-

+ 11 - 8
sherpa-ncnn/csrc/meta-data.cc

@@ -22,14 +22,17 @@
 namespace sherpa_ncnn {
 
 int MetaData::load_param(const ncnn::ParamDict &pd) {
-  arg0 = pd.get(0, 0), arg1 = pd.get(1, 0), arg2 = pd.get(2, 0);
-  arg3 = pd.get(3, 0), arg4 = pd.get(4, 0), arg5 = pd.get(5, 0);
-  arg6 = pd.get(6, 0), arg7 = pd.get(7, 0), arg8 = pd.get(8, 0);
-  arg9 = pd.get(9, 0), arg10 = pd.get(10, 0), arg11 = pd.get(11, 0);
-  arg12 = pd.get(12, 0), arg13 = pd.get(13, 0), arg14 = pd.get(14, 0);
-  arg15 = pd.get(15, 0), arg16 = pd.get(16, 0), arg17 = pd.get(17, 0);
-  arg18 = pd.get(18, 0), arg19 = pd.get(19, 0), arg20 = pd.get(20, 0);
-  arg21 = pd.get(21, 0), arg22 = pd.get(22, 0), arg23 = pd.get(23, 0);
+  arg0 = pd.get(0, 0);
+  arg1 = pd.get(1, 0), arg2 = pd.get(2, 0), arg3 = pd.get(3, 0);
+  arg4 = pd.get(4, 0), arg5 = pd.get(5, 0), arg6 = pd.get(6, 0);
+  arg7 = pd.get(7, 0), arg8 = pd.get(8, 0), arg9 = pd.get(9, 0);
+  arg10 = pd.get(10, 0), arg11 = pd.get(11, 0), arg12 = pd.get(12, 0);
+  arg13 = pd.get(13, 0), arg14 = pd.get(14, 0), arg15 = pd.get(15, 0);
+
+  arg16 = pd.get(16, ncnn::Mat()), arg17 = pd.get(17, ncnn::Mat());
+  arg18 = pd.get(18, ncnn::Mat()), arg19 = pd.get(19, ncnn::Mat());
+  arg20 = pd.get(20, ncnn::Mat()), arg21 = pd.get(21, ncnn::Mat());
+  arg22 = pd.get(22, ncnn::Mat()), arg23 = pd.get(23, ncnn::Mat());
 
   // The following 8 attributes are of type float
   arg24 = pd.get(24, 0.f), arg25 = pd.get(25, 0.f), arg26 = pd.get(26, 0.f);

+ 5 - 1
sherpa-ncnn/csrc/meta-data.h

@@ -28,9 +28,13 @@ class MetaData : public ncnn::Layer {
  public:
   int load_param(const ncnn::ParamDict &pd) override;
 
+  // arg0 is the model type:
+  //  0 - ConvEmformer
+  //  1 - Zipformer
   int32_t arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7;
   int32_t arg8, arg9, arg10, arg11, arg12, arg13, arg14, arg15;
-  int32_t arg16, arg17, arg18, arg19, arg20, arg21, arg22, arg23;
+
+  ncnn::Mat arg16, arg17, arg18, arg19, arg20, arg21, arg22, arg23;
 
   float arg24, arg25, arg26, arg27, arg28, arg29, arg30, arg31;
 };

+ 28 - 2
sherpa-ncnn/csrc/model.cc

@@ -22,6 +22,7 @@
 #include "sherpa-ncnn/csrc/conv-emformer-model.h"
 #include "sherpa-ncnn/csrc/lstm-model.h"
 #include "sherpa-ncnn/csrc/meta-data.h"
+#include "sherpa-ncnn/csrc/zipformer-model.h"
 
 namespace sherpa_ncnn {
 
@@ -73,6 +74,27 @@ static bool IsConvEmformerModel(const ncnn::Net &net) {
   return false;
 }
 
+static bool IsZipformerModel(const ncnn::Net &net) {
+  // Note: We may need to add more constraints if number of models gets larger.
+  //
+  // If the net has a layer of type SherpaMetaData and with name
+  // sherpa_meta_data1 and if attribute 0 is 2, we assume the model is
+  // a Zipformer model.
+
+  for (const auto *layer : net.layers()) {
+    if (layer->type == "SherpaMetaData" && layer->name == "sherpa_meta_data1") {
+      // Note: We don't use dynamic_cast<> here since it will throw
+      // the following error
+      //  error: ‘dynamic_cast’ not permitted with -fno-rtti
+      const auto *meta_data = reinterpret_cast<const MetaData *>(layer);
+
+      if (meta_data->arg0 == 2) return true;
+    }
+  }
+
+  return false;
+}
+
 void Model::InitNet(ncnn::Net &net, const std::string &param,
                     const std::string &bin) {
   if (net.load_param(param.c_str())) {
@@ -125,11 +147,15 @@ std::unique_ptr<Model> Model::Create(const ModelConfig &config) {
     return std::make_unique<ConvEmformerModel>(config);
   }
 
+  if (IsZipformerModel(net)) {
+    return std::make_unique<ZipformerModel>(config);
+  }
+
   NCNN_LOGE(
       "Unable to create a model from specified model files.\n"
       "Please check: \n"
-      "  1. If you are using a ConvEmformer model, please make sure you have "
-      "added SherapMetaData to encoder_xxx.ncnn.param "
+      "  1. If you are using a ConvEmformer/Zipformer model, please make sure "
+      "you have added SherapMetaData to encoder_xxx.ncnn.param "
       "(or encoder_xxx.ncnn.int8.param if you are using an int8 model). "
       "You need to add it manually after converting the model with pnnx.\n"
       "  2. (Android) Whether the app requires an int8 model or not\n");

+ 367 - 0
sherpa-ncnn/csrc/zipformer-model.cc

@@ -0,0 +1,367 @@
+// sherpa-ncnn/csrc/zipformer-model.cc
+//
+// Copyright (c)  2023  Xiaomi Corporation
+
+#include "sherpa-ncnn/csrc/zipformer-model.h"
+
+#include <regex>  // NOLINT
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "net.h"       // NOLINT
+#include "platform.h"  // NOLINT
+#include "sherpa-ncnn/csrc/meta-data.h"
+
+namespace sherpa_ncnn {
+
+ZipformerModel::ZipformerModel(const ModelConfig &config) {
+  encoder_.opt = config.encoder_opt;
+  decoder_.opt = config.decoder_opt;
+  joiner_.opt = config.joiner_opt;
+
+  bool has_gpu = false;
+#if NCNN_VULKAN
+  has_gpu = ncnn::get_gpu_count() > 0;
+#endif
+
+  if (has_gpu && config.use_vulkan_compute) {
+    encoder_.opt.use_vulkan_compute = true;
+    decoder_.opt.use_vulkan_compute = true;
+    joiner_.opt.use_vulkan_compute = true;
+    NCNN_LOGE("Use GPU");
+  } else {
+    NCNN_LOGE("Don't Use GPU. has_gpu: %d, config.use_vulkan_compute: %d",
+              static_cast<int32_t>(has_gpu),
+              static_cast<int32_t>(config.use_vulkan_compute));
+  }
+
+  InitEncoder(config.encoder_param, config.encoder_bin);
+  InitDecoder(config.decoder_param, config.decoder_bin);
+  InitJoiner(config.joiner_param, config.joiner_bin);
+
+  InitEncoderInputOutputIndexes();
+  InitDecoderInputOutputIndexes();
+  InitJoinerInputOutputIndexes();
+}
+
+#if __ANDROID_API__ >= 9
+ZipformerModel::ZipformerModel(AAssetManager *mgr, const ModelConfig &config) {
+  InitEncoder(mgr, config.encoder_param, config.encoder_bin);
+  InitDecoder(mgr, config.decoder_param, config.decoder_bin);
+  InitJoiner(mgr, config.joiner_param, config.joiner_bin);
+
+  InitEncoderInputOutputIndexes();
+  InitDecoderInputOutputIndexes();
+  InitJoinerInputOutputIndexes();
+}
+#endif
+
+std::pair<ncnn::Mat, std::vector<ncnn::Mat>> ZipformerModel::RunEncoder(
+    ncnn::Mat &features, const std::vector<ncnn::Mat> &states) {
+  ncnn::Extractor encoder_ex = encoder_.create_extractor();
+  return RunEncoder(features, states, &encoder_ex);
+}
+
+std::pair<ncnn::Mat, std::vector<ncnn::Mat>> ZipformerModel::RunEncoder(
+    ncnn::Mat &features, const std::vector<ncnn::Mat> &states,
+    ncnn::Extractor *encoder_ex) {
+  std::vector<ncnn::Mat> _states;
+
+  const ncnn::Mat *p;
+  if (states.empty()) {
+    _states = GetEncoderInitStates();
+    p = _states.data();
+  } else {
+    p = states.data();
+  }
+
+  // Note: We ignore error check there
+  encoder_ex->input(encoder_input_indexes_[0], features);
+  for (int32_t i = 1; i != encoder_input_indexes_.size(); ++i) {
+    encoder_ex->input(encoder_input_indexes_[i], p[i - 1]);
+  }
+
+  ncnn::Mat encoder_out;
+  encoder_ex->extract(encoder_output_indexes_[0], encoder_out);
+
+  std::vector<ncnn::Mat> next_states(num_encoder_layers_.size() * 7);
+  for (int32_t i = 1; i != encoder_output_indexes_.size(); ++i) {
+    encoder_ex->extract(encoder_output_indexes_[i], next_states[i - 1]);
+  }
+
+  // reshape cached_avg to 1-D tensors; remove the w dim, which is 1
+  for (size_t i = 0; i != num_encoder_layers_.size(); ++i) {
+    next_states[i] = next_states[i].reshape(next_states[i].h);
+  }
+
+  // reshape cached_len to 2-D tensors, remove the h dim, which is 1
+  for (size_t i = num_encoder_layers_.size();
+       i != num_encoder_layers_.size() * 2; ++i) {
+    next_states[i] = next_states[i].reshape(next_states[i].w, next_states[i].c);
+  }
+
+  return {encoder_out, next_states};
+}
+
+ncnn::Mat ZipformerModel::RunDecoder(ncnn::Mat &decoder_input) {
+  ncnn::Extractor decoder_ex = decoder_.create_extractor();
+  return RunDecoder(decoder_input, &decoder_ex);
+}
+
+ncnn::Mat ZipformerModel::RunDecoder(ncnn::Mat &decoder_input,
+                                     ncnn::Extractor *decoder_ex) {
+  ncnn::Mat decoder_out;
+  decoder_ex->input(decoder_input_indexes_[0], decoder_input);
+  decoder_ex->extract(decoder_output_indexes_[0], decoder_out);
+  decoder_out = decoder_out.reshape(decoder_out.w);
+
+  return decoder_out;
+}
+
+ncnn::Mat ZipformerModel::RunJoiner(ncnn::Mat &encoder_out,
+                                    ncnn::Mat &decoder_out) {
+  auto joiner_ex = joiner_.create_extractor();
+  return RunJoiner(encoder_out, decoder_out, &joiner_ex);
+}
+
+ncnn::Mat ZipformerModel::RunJoiner(ncnn::Mat &encoder_out,
+                                    ncnn::Mat &decoder_out,
+                                    ncnn::Extractor *joiner_ex) {
+  joiner_ex->input(joiner_input_indexes_[0], encoder_out);
+  joiner_ex->input(joiner_input_indexes_[1], decoder_out);
+
+  ncnn::Mat joiner_out;
+  joiner_ex->extract("out0", joiner_out);
+  return joiner_out;
+}
+
+void ZipformerModel::InitEncoderPostProcessing() {
+  // Now load parameters for member variables
+  for (const auto *layer : encoder_.layers()) {
+    if (layer->type == "SherpaMetaData" && layer->name == "sherpa_meta_data1") {
+      // Note: We don't use dynamic_cast<> here since it will throw
+      // the following error
+      //  error: ‘dynamic_cast’ not permitted with -fno-rtti
+      const auto *meta_data = reinterpret_cast<const MetaData *>(layer);
+
+      decode_chunk_length_ = meta_data->arg1;
+      num_left_chunks_ = meta_data->arg2;
+      pad_length_ = meta_data->arg3;
+
+      num_encoder_layers_ = std::vector<int32_t>(
+          static_cast<const int32_t *>(meta_data->arg16),
+          static_cast<const int32_t *>(meta_data->arg16) + meta_data->arg16.w);
+
+      encoder_dims_ = std::vector<int32_t>(
+          static_cast<const int32_t *>(meta_data->arg17),
+          static_cast<const int32_t *>(meta_data->arg17) + meta_data->arg17.w);
+
+      attention_dims_ = std::vector<int32_t>(
+          static_cast<const int32_t *>(meta_data->arg18),
+          static_cast<const int32_t *>(meta_data->arg18) + meta_data->arg18.w);
+
+      zipformer_downsampling_factors_ = std::vector<int32_t>(
+          static_cast<const int32_t *>(meta_data->arg19),
+          static_cast<const int32_t *>(meta_data->arg19) + meta_data->arg19.w);
+
+      cnn_module_kernels_ = std::vector<int32_t>(
+          static_cast<const int32_t *>(meta_data->arg20),
+          static_cast<const int32_t *>(meta_data->arg20) + meta_data->arg20.w);
+
+      break;
+    }
+  }
+}
+
+void ZipformerModel::InitEncoder(const std::string &encoder_param,
+                                 const std::string &encoder_bin) {
+  RegisterMetaDataLayer(encoder_);
+  InitNet(encoder_, encoder_param, encoder_bin);
+  InitEncoderPostProcessing();
+}
+
+void ZipformerModel::InitDecoder(const std::string &decoder_param,
+                                 const std::string &decoder_bin) {
+  InitNet(decoder_, decoder_param, decoder_bin);
+}
+
+void ZipformerModel::InitJoiner(const std::string &joiner_param,
+                                const std::string &joiner_bin) {
+  InitNet(joiner_, joiner_param, joiner_bin);
+}
+
+#if __ANDROID_API__ >= 9
+void ZipformerModel::InitEncoder(AAssetManager *mgr,
+                                 const std::string &encoder_param,
+                                 const std::string &encoder_bin) {
+  RegisterMetaDataLayer(encoder_);
+  InitNet(mgr, encoder_, encoder_param, encoder_bin);
+  InitEncoderPostProcessing();
+}
+
+void ZipformerModel::InitDecoder(AAssetManager *mgr,
+                                 const std::string &decoder_param,
+                                 const std::string &decoder_bin) {
+  InitNet(mgr, decoder_, decoder_param, decoder_bin);
+}
+
+void ZipformerModel::InitJoiner(AAssetManager *mgr,
+                                const std::string &joiner_param,
+                                const std::string &joiner_bin) {
+  InitNet(mgr, joiner_, joiner_param, joiner_bin);
+}
+#endif
+
+// see
+// https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py#L673
+std::vector<ncnn::Mat> ZipformerModel::GetEncoderInitStates() const {
+  // each layer has 7 states:
+  // cached_len, (num_layers,)
+  // cached_avg, (num_layers, encoder_dim)
+  // cached_key, (num_layers, left_context_length, attention_dim)
+  // cached_val, (num_layers, left_context_length, attention_dim / 2)
+  // cached_val2, (num_layers, left_context_length, attention_dim / 2)
+  // cached_conv1, (num_layers, encoder_dim, cnn_module_kernel_ - 1)
+  // cached_conv2, (num_layers, encoder_dim, cnn_module_kernel_ - 1)
+
+  std::vector<ncnn::Mat> cached_len_vec;
+  std::vector<ncnn::Mat> cached_avg_vec;
+  std::vector<ncnn::Mat> cached_key_vec;
+  std::vector<ncnn::Mat> cached_val_vec;
+  std::vector<ncnn::Mat> cached_val2_vec;
+  std::vector<ncnn::Mat> cached_conv1_vec;
+  std::vector<ncnn::Mat> cached_conv2_vec;
+
+  cached_len_vec.reserve(num_encoder_layers_.size());
+  cached_avg_vec.reserve(num_encoder_layers_.size());
+  cached_key_vec.reserve(num_encoder_layers_.size());
+  cached_val_vec.reserve(num_encoder_layers_.size());
+  cached_val2_vec.reserve(num_encoder_layers_.size());
+  cached_conv1_vec.reserve(num_encoder_layers_.size());
+  cached_conv2_vec.reserve(num_encoder_layers_.size());
+
+  int32_t left_context_length = decode_chunk_length_ / 2 * num_left_chunks_;
+  for (size_t i = 0; i != num_encoder_layers_.size(); ++i) {
+    int32_t num_layers = num_encoder_layers_[i];
+    int32_t ds = zipformer_downsampling_factors_[i];
+    int32_t attention_dim = attention_dims_[i];
+    int32_t left_context_len = left_context_length / ds;
+    int32_t encoder_dim = encoder_dims_[i];
+    int32_t cnn_module_kernel = cnn_module_kernels_[i];
+
+    auto cached_len = ncnn::Mat(num_layers);
+    auto cached_avg = ncnn::Mat(encoder_dim, num_layers);
+    auto cached_key = ncnn::Mat(attention_dim, left_context_len, num_layers);
+    auto cached_val =
+        ncnn::Mat(attention_dim / 2, left_context_len, num_layers);
+    auto cached_val2 =
+        ncnn::Mat(attention_dim / 2, left_context_len, num_layers);
+    auto cached_conv1 =
+        ncnn::Mat(cnn_module_kernel - 1, encoder_dim, num_layers);
+    auto cached_conv2 =
+        ncnn::Mat(cnn_module_kernel - 1, encoder_dim, num_layers);
+
+    cached_len.fill(0);
+    cached_avg.fill(0);
+    cached_key.fill(0);
+    cached_val.fill(0);
+    cached_val2.fill(0);
+    cached_conv1.fill(0);
+    cached_conv2.fill(0);
+
+    cached_len_vec.push_back(cached_len);
+    cached_avg_vec.push_back(cached_avg);
+    cached_key_vec.push_back(cached_key);
+    cached_val_vec.push_back(cached_val);
+    cached_val2_vec.push_back(cached_val2);
+    cached_conv1_vec.push_back(cached_conv1);
+    cached_conv2_vec.push_back(cached_conv2);
+  }
+
+  std::vector<ncnn::Mat> states;
+
+  states.reserve(num_encoder_layers_.size() * 7);
+  states.insert(states.end(), cached_len_vec.begin(), cached_len_vec.end());
+  states.insert(states.end(), cached_avg_vec.begin(), cached_avg_vec.end());
+  states.insert(states.end(), cached_key_vec.begin(), cached_key_vec.end());
+  states.insert(states.end(), cached_val_vec.begin(), cached_val_vec.end());
+  states.insert(states.end(), cached_val2_vec.begin(), cached_val2_vec.end());
+  states.insert(states.end(), cached_conv1_vec.begin(), cached_conv1_vec.end());
+  states.insert(states.end(), cached_conv2_vec.begin(), cached_conv2_vec.end());
+
+  return states;
+}
+
+void ZipformerModel::InitEncoderInputOutputIndexes() {
+  // input indexes map
+  // [0] -> in0, features,
+  // [1] -> in1, layer0, cached_len
+  // [2] -> in2, layer1, cached_len
+  // [3] -> in3, layer2, cached_len
+  // ... ...
+  encoder_input_indexes_.resize(1 + num_encoder_layers_.size() * 7);
+
+  // output indexes map
+  // [0] -> out0, encoder_out
+  //
+  // [1] -> out1, layer0, cached_len
+  // [2] -> out2, layer1, cached_len
+  // [3] -> out3, layer2, cached_len
+  // ... ...
+  encoder_output_indexes_.resize(1 + num_encoder_layers_.size() * 7);
+  const auto &blobs = encoder_.blobs();
+
+  std::regex in_regex("in(\\d+)");
+  std::regex out_regex("out(\\d+)");
+
+  std::smatch match;
+  for (int32_t i = 0; i != blobs.size(); ++i) {
+    const auto &b = blobs[i];
+    if (std::regex_match(b.name, match, in_regex)) {
+      auto index = std::atoi(match[1].str().c_str());
+      encoder_input_indexes_[index] = i;
+    } else if (std::regex_match(b.name, match, out_regex)) {
+      auto index = std::atoi(match[1].str().c_str());
+      encoder_output_indexes_[index] = i;
+    }
+  }
+}
+
+void ZipformerModel::InitDecoderInputOutputIndexes() {
+  // input indexes map
+  // [0] -> in0, decoder_input,
+  decoder_input_indexes_.resize(1);
+
+  // output indexes map
+  // [0] -> out0, decoder_out,
+  decoder_output_indexes_.resize(1);
+
+  const auto &blobs = decoder_.blobs();
+  for (int32_t i = 0; i != blobs.size(); ++i) {
+    const auto &b = blobs[i];
+    if (b.name == "in0") decoder_input_indexes_[0] = i;
+    if (b.name == "out0") decoder_output_indexes_[0] = i;
+  }
+}
+
+void ZipformerModel::InitJoinerInputOutputIndexes() {
+  // input indexes map
+  // [0] -> in0, encoder_input,
+  // [1] -> in1, decoder_input,
+  joiner_input_indexes_.resize(2);
+
+  // output indexes map
+  // [0] -> out0, joiner_out,
+  joiner_output_indexes_.resize(1);
+
+  const auto &blobs = joiner_.blobs();
+  for (int32_t i = 0; i != blobs.size(); ++i) {
+    const auto &b = blobs[i];
+    if (b.name == "in0") joiner_input_indexes_[0] = i;
+    if (b.name == "in1") joiner_input_indexes_[1] = i;
+    if (b.name == "out0") joiner_output_indexes_[0] = i;
+  }
+}
+
+}  // namespace sherpa_ncnn

+ 111 - 0
sherpa-ncnn/csrc/zipformer-model.h

@@ -0,0 +1,111 @@
+// sherpa-ncnn/csrc/zipformer-model.h
+//
+// Copyright (c)  2023  Xiaomi Corporation
+
+#ifndef SHERPA_NCNN_CSRC_ZIPFORMER_MODEL_H_
+#define SHERPA_NCNN_CSRC_ZIPFORMER_MODEL_H_
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "net.h"  // NOLINT
+#include "sherpa-ncnn/csrc/model.h"
+
+namespace sherpa_ncnn {
+// Please refer to https://github.com/k2-fsa/icefall/pull/906
+// for how the model is converted from icefall to ncnn
+class ZipformerModel : public Model {
+ public:
+  explicit ZipformerModel(const ModelConfig &config);
+#if __ANDROID_API__ >= 9
+  ZipformerModel(AAssetManager *mgr, const ModelConfig &config);
+#endif
+
+  ncnn::Net &GetEncoder() override { return encoder_; }
+  ncnn::Net &GetDecoder() override { return decoder_; }
+  ncnn::Net &GetJoiner() override { return joiner_; }
+
+  std::pair<ncnn::Mat, std::vector<ncnn::Mat>> RunEncoder(
+      ncnn::Mat &features, const std::vector<ncnn::Mat> &states) override;
+
+  std::pair<ncnn::Mat, std::vector<ncnn::Mat>> RunEncoder(
+      ncnn::Mat &features, const std::vector<ncnn::Mat> &states,
+      ncnn::Extractor *extractor) override;
+
+  ncnn::Mat RunDecoder(ncnn::Mat &decoder_input) override;
+
+  ncnn::Mat RunDecoder(ncnn::Mat &decoder_input,
+                       ncnn::Extractor *extractor) override;
+
+  ncnn::Mat RunJoiner(ncnn::Mat &encoder_out, ncnn::Mat &decoder_out) override;
+
+  ncnn::Mat RunJoiner(ncnn::Mat &encoder_out, ncnn::Mat &decoder_out,
+                      ncnn::Extractor *extractor) override;
+
+  int32_t Segment() const override {
+    // pad_length 7, because the subsampling expression is
+    // ((x_len - 7) // 2 + 1)//2, we need to pad 7 frames
+    //
+    // decode chunk length before subsample is 32 frames
+    //
+    // So each segment is pad_length + decode_chunk_length = 7 + 32 = 39
+    return decode_chunk_length_ + pad_length_;
+  }
+
+  // Advance the feature extract by this number of frames after
+  // running the encoder network
+  int32_t Offset() const override { return decode_chunk_length_; }
+
+ private:
+  void InitEncoder(const std::string &encoder_param,
+                   const std::string &encoder_bin);
+  void InitDecoder(const std::string &decoder_param,
+                   const std::string &decoder_bin);
+  void InitJoiner(const std::string &joiner_param,
+                  const std::string &joiner_bin);
+
+  void InitEncoderPostProcessing();
+
+#if __ANDROID_API__ >= 9
+  void InitEncoder(AAssetManager *mgr, const std::string &encoder_param,
+                   const std::string &encoder_bin);
+  void InitDecoder(AAssetManager *mgr, const std::string &decoder_param,
+                   const std::string &decoder_bin);
+  void InitJoiner(AAssetManager *mgr, const std::string &joiner_param,
+                  const std::string &joiner_bin);
+#endif
+
+  std::vector<ncnn::Mat> GetEncoderInitStates() const;
+
+  void InitEncoderInputOutputIndexes();
+  void InitDecoderInputOutputIndexes();
+  void InitJoinerInputOutputIndexes();
+
+ private:
+  ncnn::Net encoder_;
+  ncnn::Net decoder_;
+  ncnn::Net joiner_;
+
+  int32_t decode_chunk_length_ = 32;  // arg1, before subsampling
+  int32_t num_left_chunks_ = 4;       // arg2
+  int32_t pad_length_ = 7;            // arg3
+
+  std::vector<int32_t> num_encoder_layers_;              // arg16
+  std::vector<int32_t> encoder_dims_;                    // arg17
+  std::vector<int32_t> attention_dims_;                  // arg18
+  std::vector<int32_t> zipformer_downsampling_factors_;  // arg19
+  std::vector<int32_t> cnn_module_kernels_;              // arg20
+
+  std::vector<int32_t> encoder_input_indexes_;
+  std::vector<int32_t> encoder_output_indexes_;
+
+  std::vector<int32_t> decoder_input_indexes_;
+  std::vector<int32_t> decoder_output_indexes_;
+
+  std::vector<int32_t> joiner_input_indexes_;
+  std::vector<int32_t> joiner_output_indexes_;
+};
+
+}  // namespace sherpa_ncnn
+
+#endif  // SHERPA_NCNN_CSRC_ZIPFORMER_MODEL_H_