
Use tencent/ncnn master (#204)

* Optimize modified beam search to use ncnn/master

* Add PoolingModuleNoProj layer

* Add TensorAsStrided layer

* Add SimpleUpsample layer

* Refactor the code to use the master branch of tencent/ncnn
Fangjun Kuang, 2 years ago
commit 05561bbeea

+ 9 - 8
cmake/ncnn.cmake

@@ -6,18 +6,18 @@ function(download_ncnn)
  # https://github.com/csukuangfj/ncnn/pull/7

  # Please also change ../pack-for-embedded-systems.sh
-  set(ncnn_URL  "https://github.com/csukuangfj/ncnn/archive/refs/tags/sherpa-0.9.tar.gz")
-  set(ncnn_URL2 "https://huggingface.co/csukuangfj/sherpa-ncnn-cmake-deps/resolve/main/ncnn-sherpa-0.9.tar.gz")
-  set(ncnn_HASH "SHA256=a5fe1f69c75c06d6de858c7c660c43395b6ed3df9ee59d6e2fe621211e6928cd")
+  set(ncnn_URL  "https://github.com/csukuangfj/ncnn/archive/refs/tags/sherpa-1.0.tar.gz")
+  set(ncnn_URL2 "https://huggingface.co/csukuangfj/sherpa-ncnn-cmake-deps/resolve/main/ncnn-sherpa-1.0.tar.gz")
+  set(ncnn_HASH "SHA256=7c80c34c7bdfb5ce4ef3c7727f1b24d2217fca590fc87a69b1a185ed62e33c95")

  # If you don't have access to the Internet, please download it to your
  # local drive and modify the following line according to your needs.
  set(possible_file_locations
-    $ENV{HOME}/Downloads/ncnn-sherpa-0.9.tar.gz
-    $ENV{HOME}/asr/ncnn-sherpa-0.9.tar.gz
-    ${PROJECT_SOURCE_DIR}/ncnn-sherpa-0.9.tar.gz
-    ${PROJECT_BINARY_DIR}/ncnn-sherpa-0.9.tar.gz
-    /tmp/ncnn-sherpa-0.9.tar.gz
+    $ENV{HOME}/Downloads/ncnn-sherpa-1.0.tar.gz
+    $ENV{HOME}/asr/ncnn-sherpa-1.0.tar.gz
+    ${PROJECT_SOURCE_DIR}/ncnn-sherpa-1.0.tar.gz
+    ${PROJECT_BINARY_DIR}/ncnn-sherpa-1.0.tar.gz
+    /tmp/ncnn-sherpa-1.0.tar.gz
  )

  foreach(f IN LISTS possible_file_locations)
@@ -168,6 +168,7 @@ function(download_ncnn)
    Unfold
    GridSample
    CumulativeSum
+    CopyTo
  )

  foreach(layer IN LISTS disabled_layers)

+ 5 - 1
sherpa-ncnn/csrc/CMakeLists.txt

@@ -2,6 +2,7 @@ include_directories(${CMAKE_SOURCE_DIR})

set(sherpa_ncnn_core_srcs
  conv-emformer-model.cc
+  decoder.cc
  endpoint.cc
  features.cc
  greedy-search-decoder.cc
@@ -9,12 +10,15 @@ set(sherpa_ncnn_core_srcs
  lstm-model.cc
  meta-data.cc
  model.cc
-  decoder.cc
  modified-beam-search-decoder.cc
+  poolingmodulenoproj.cc
  recognizer.cc
  resample.cc
+  simpleupsample.cc
+  stack.cc
  stream.cc
  symbol-table.cc
+  tensorasstrided.cc
  wave-reader.cc
  zipformer-model.cc
)

+ 2 - 2
sherpa-ncnn/csrc/conv-emformer-model.cc

@@ -170,7 +170,7 @@ void ConvEmformerModel::InitEncoderPostProcessing() {

void ConvEmformerModel::InitEncoder(const std::string &encoder_param,
                                    const std::string &encoder_bin) {
-  RegisterMetaDataLayer(encoder_);
+  RegisterCustomLayers(encoder_);
  InitNet(encoder_, encoder_param, encoder_bin);
  InitEncoderPostProcessing();
}
@@ -189,7 +189,7 @@ void ConvEmformerModel::InitJoiner(const std::string &joiner_param,
void ConvEmformerModel::InitEncoder(AAssetManager *mgr,
                                    const std::string &encoder_param,
                                    const std::string &encoder_bin) {
-  RegisterMetaDataLayer(encoder_);
+  RegisterCustomLayers(encoder_);
  InitNet(mgr, encoder_, encoder_param, encoder_bin);
  InitEncoderPostProcessing();
}

+ 2 - 2
sherpa-ncnn/csrc/lstm-model.cc

@@ -162,7 +162,7 @@ void LstmModel::InitEncoder(const std::string &encoder_param,
  encoder_.opt.use_packing_layout = false;
  encoder_.opt.use_fp16_storage = false;

-  RegisterMetaDataLayer(encoder_);
+  RegisterCustomLayers(encoder_);
  InitNet(encoder_, encoder_param, encoder_bin);

  InitEncoderPostProcessing();
@@ -185,7 +185,7 @@ void LstmModel::InitEncoder(AAssetManager *mgr,
  encoder_.opt.use_packing_layout = false;
  encoder_.opt.use_fp16_storage = false;

-  RegisterMetaDataLayer(encoder_);
+  RegisterCustomLayers(encoder_);
  InitNet(mgr, encoder_, encoder_param, encoder_bin);

  InitEncoderPostProcessing();

+ 2 - 0
sherpa-ncnn/csrc/meta-data.h

@@ -32,6 +32,8 @@ class MetaData : public ncnn::Layer {
  //  1 - ConvEmformer
  //  2 - Zipformer
  //  3 - LSTM
+  //
+  //  arg15 is the model version, defaults to 0
  int32_t arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7;
  int32_t arg8, arg9, arg10, arg11, arg12, arg13, arg14, arg15;

+ 40 - 6
sherpa-ncnn/csrc/model.cc

@@ -22,6 +22,10 @@
#include "sherpa-ncnn/csrc/conv-emformer-model.h"
#include "sherpa-ncnn/csrc/lstm-model.h"
#include "sherpa-ncnn/csrc/meta-data.h"
+#include "sherpa-ncnn/csrc/poolingmodulenoproj.h"
+#include "sherpa-ncnn/csrc/simpleupsample.h"
+#include "sherpa-ncnn/csrc/stack.h"
+#include "sherpa-ncnn/csrc/tensorasstrided.h"
#include "sherpa-ncnn/csrc/zipformer-model.h"

namespace sherpa_ncnn {
@@ -93,10 +97,29 @@ static bool IsZipformerModel(const ncnn::Net &net) {
      //  error: ‘dynamic_cast’ not permitted with -fno-rtti
      const auto *meta_data = reinterpret_cast<const MetaData *>(layer);

-      if (meta_data->arg0 == 2) return true;
+      if (meta_data->arg0 == 2) {
+        // arg15 is the version.
+        // Starting from sherpa-ncnn 2.0, we use the master branch of
+        // tencent/ncnn directly and we have updated the version of
+        // Zipformer from 0 to 1.
+        //
+        // If you are using an older version of Zipformer, please
+        // re-download the model, re-export the model using the latest
+        // icefall, or use sherpa-ncnn < v2.0
+        if (meta_data->arg15 < 1) {
+          NCNN_LOGE(
+              "You are using a too old version of Zipformer. You can "
+              "choose one of the following solutions: \n"
+              "  (1) Re-download the latest model\n"
+              "  (2) Re-export your model using the latest icefall. Remember "
+              "to strictly follow the documentation\n"
+              "      to update the version number to 1.\n"
+              "  (3) Use sherpa-ncnn < v2.0 (not recommended)\n");
+          exit(-1);
+        }
+        return true;
+      }
    }
  }
-
  return false;
}
@@ -128,6 +151,15 @@ void Model::InitNet(AAssetManager *mgr, ncnn::Net &net,
}
#endif

+void Model::RegisterCustomLayers(ncnn::Net &net) {
+  RegisterMetaDataLayer(net);
+
+  RegisterPoolingModuleNoProjLayer(net);   // for zipformer only
+  RegisterTensorAsStridedLayer(net);       // for zipformer only
+  RegisterTensorSimpleUpsampleLayer(net);  // for zipformer only
+  RegisterStackLayer(net);                 // for zipformer only
+}
+
std::unique_ptr<Model> Model::Create(const ModelConfig &config) {
  // 1. Load the encoder network
  // 2. If the encoder network has LSTM layers, we assume it is a LstmModel
@@ -136,7 +168,7 @@ std::unique_ptr<Model> Model::Create(const ModelConfig &config) {
  // in the future

  ncnn::Net net;
-  RegisterMetaDataLayer(net);
+  RegisterCustomLayers(net);

  auto ret = net.load_param(config.encoder_param.c_str());
  if (ret != 0) {
@@ -159,7 +191,8 @@ std::unique_ptr<Model> Model::Create(const ModelConfig &config) {
  NCNN_LOGE(
      "Unable to create a model from specified model files.\n"
      "Please check: \n"
-      "  1. If you are using a ConvEmformer/Zipformer/LSTM model, please make "
+      "  1. If you are using a ConvEmformer/Zipformer/LSTM model, please "
+      "make "
      "sure "
      "you have added SherapMetaData to encoder_xxx.ncnn.param "
      "(or encoder_xxx.ncnn.int8.param if you are using an int8 model). "
@@ -173,7 +206,7 @@ std::unique_ptr<Model> Model::Create(AAssetManager *mgr,
std::unique_ptr<Model> Model::Create(AAssetManager *mgr,
                                     const ModelConfig &config) {
  ncnn::Net net;
-  RegisterMetaDataLayer(net);
+  RegisterCustomLayers(net);

  auto ret = net.load_param(mgr, config.encoder_param.c_str());
  if (ret != 0) {
@@ -196,7 +229,8 @@ std::unique_ptr<Model> Model::Create(AAssetManager *mgr,
  NCNN_LOGE(
      "Unable to create a model from specified model files.\n"
      "Please check: \n"
-      "  1. If you are using a ConvEmformer/Zipformer/LSTM model, please make "
+      "  1. If you are using a ConvEmformer/Zipformer/LSTM model, please "
+      "make "
      "sure "
      "you have added SherapMetaData to encoder_xxx.ncnn.param "
      "(or encoder_xxx.ncnn.int8.param if you are using an int8 model). "

+ 2 - 0
sherpa-ncnn/csrc/model.h

@@ -49,6 +49,8 @@ class Model {
 public:
  virtual ~Model() = default;

+  static void RegisterCustomLayers(ncnn::Net &net);
+
  /** Create a model from a config. */
  static std::unique_ptr<Model> Create(const ModelConfig &config);

+ 6 - 26
sherpa-ncnn/csrc/modified-beam-search-decoder.cc

@@ -27,29 +27,6 @@

namespace sherpa_ncnn {

-// @param in 1-D tensor of shape (encoder_dim,)
-// @param n Number of times to repeat
-// @return Return a 2-d tensor of shape (n, encoder_dim)
-//
-// TODO(fangjun): Remove this function
-// once
-// https://github.com/nihui/ncnn/tree/pnnx-ncnn-binary-broadcast
-// gets merged
-static ncnn::Mat RepeatEncoderOut(ncnn::Mat in, int32_t n) {
-  int32_t w = in.w;
-  ncnn::Mat out(w, n, sizeof(float));
-
-  const float *in_ptr = in;
-  float *out_ptr = out;
-
-  for (int32_t i = 0; i != n; ++i) {
-    std::copy(in_ptr, in_ptr + w, out_ptr);
-    out_ptr += w;
-  }
-
-  return out;
-}
-
DecoderResult ModifiedBeamSearchDecoder::GetEmptyResult() const {
  DecoderResult r;

@@ -159,10 +136,13 @@ void ModifiedBeamSearchDecoder::Decode(ncnn::Mat encoder_out,

    // decoder_out.w == decoder_dim
    // decoder_out.h == num_active_paths
-    ncnn::Mat encoder_out_t(encoder_out.w, encoder_out.row(t));
-    encoder_out_t = RepeatEncoderOut(encoder_out_t, decoder_out.h);
-
+    ncnn::Mat encoder_out_t(encoder_out.w, 1, encoder_out.row(t));
+    // Note: encoder_out_t.h == 1; we rely on the binary-op broadcasting
+    // in ncnn.
+    // See https://github.com/Tencent/ncnn/wiki/binaryop-broadcasting
+    // (broadcast B for outer axis, type 14)
    ncnn::Mat joiner_out = model_->RunJoiner(encoder_out_t, decoder_out);
+
    // joiner_out.w == vocab_size
    // joiner_out.h == num_active_paths
    LogSoftmax(&joiner_out);
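For readers unfamiliar with the broadcasting rule the new code relies on: when one operand of an ncnn binary op has h == 1 and the other has h == N (with equal w), the single row is repeated across all N rows. A scalar sketch of that semantics (an illustration, not ncnn's implementation); it computes exactly what the removed RepeatEncoderOut used to materialize by copying:

#include <cstdint>
#include <vector>

// Adds row a (length c) onto each of the n rows of b (n x c, row-major).
static std::vector<float> BroadcastAddRow(const std::vector<float> &a,
                                          const std::vector<float> &b,
                                          int32_t n, int32_t c) {
  std::vector<float> out(b.size());
  for (int32_t i = 0; i != n; ++i) {
    for (int32_t j = 0; j != c; ++j) {
      out[i * c + j] = a[j] + b[i * c + j];
    }
  }
  return out;
}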

+ 97 - 0
sherpa-ncnn/csrc/poolingmodulenoproj.cc

@@ -0,0 +1,97 @@
+/**
+ * Copyright (c)  2022  Xiaomi Corporation (authors: Fangjun Kuang)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "sherpa-ncnn/csrc/poolingmodulenoproj.h"
+
+namespace sherpa_ncnn {
+
+PoolingModuleNoProj::PoolingModuleNoProj() {
+  one_blob_only = false;
+  support_inplace = false;
+}
+
+int32_t PoolingModuleNoProj::forward(const std::vector<ncnn::Mat> &bottom_blobs,
+                                     std::vector<ncnn::Mat> &top_blobs,
+                                     const ncnn::Option &opt) const {
+  ncnn::Mat x = bottom_blobs[0];
+  ncnn::Mat cached_len = bottom_blobs[1];
+  ncnn::Mat cached_avg = bottom_blobs[2];
+
+  // x.dims = 2, x.w = C, x.h = T
+  // cached_len.dims = 1, cached_len.w = 1
+  // cached_avg.dims = 2, cached_avg.w = C, cached_avg.h = 1
+
+  ncnn::Mat &out_x = top_blobs[0];
+  out_x.create_like(x, opt.blob_allocator);
+
+  ncnn::Mat &out_cached_len = top_blobs[1];
+  out_cached_len.create(cached_len.w, cached_len.elemsize, opt.blob_allocator);
+
+  ncnn::Mat &out_cached_avg = top_blobs[2];
+  out_cached_avg.create_like(cached_avg, opt.blob_allocator);
+
+  int32_t w = x.w;
+  int32_t h = x.h;
+
+  const float *x_ptr = x;
+  const float *cached_avg_ptr = cached_avg;
+  float *out_ptr = out_x;
+
+  float n = cached_len[0];
+
+  // process row 0
+  for (int32_t c = 0; c < w; ++c) {
+    out_ptr[c] = x_ptr[c] + n * cached_avg_ptr[c];
+  }
+
+  for (int32_t r = 1; r < h; ++r) {
+    const float *x_cur = x.row(r);
+
+    float *out_prev = out_x.row(r - 1);
+    float *out_cur = out_x.row(r);
+
+    float scale = 1. / (n + r);  // scale for the previous row
+    for (int32_t c = 0; c < w; ++c) {
+      out_cur[c] = out_prev[c] + x_cur[c];
+      out_prev[c] *= scale;
+    }
+  }
+
+  float *last_row = out_x.row(h - 1);
+  float scale = 1. / (n + h);
+
+  float *out_cached_avg_ptr = out_cached_avg;
+  for (int32_t c = 0; c < w; ++c) {
+    last_row[c] *= scale;
+    out_cached_avg_ptr[c] = last_row[c];
+  }
+
+  out_cached_len[0] = n + h;
+
+  return 0;
+}
+
+static ncnn::Layer *PoolingModuleNoProjCreator(void * /*userdata*/) {
+  return new PoolingModuleNoProj();
+}
+
+void RegisterPoolingModuleNoProjLayer(ncnn::Net &net) {
+  net.register_custom_layer("PoolingModuleNoProj", PoolingModuleNoProjCreator);
+}
+
+}  // namespace sherpa_ncnn

+ 43 - 0
sherpa-ncnn/csrc/poolingmodulenoproj.h

@@ -0,0 +1,43 @@
+/**
+ * Copyright (c)  2022  Xiaomi Corporation (authors: Fangjun Kuang)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef SHERPA_NCNN_CSRC_POOLINGMODULENOPROJ_H_
+#define SHERPA_NCNN_CSRC_POOLINGMODULENOPROJ_H_
+
+#include <utility>
+#include <vector>
+
+#include "layer.h"  // NOLINT
+#include "net.h"    // NOLINT
+
+namespace sherpa_ncnn {
+
+// used only by zipformer
+class PoolingModuleNoProj : public ncnn::Layer {
+ public:
+  PoolingModuleNoProj();
+
+  int32_t forward(const std::vector<ncnn::Mat> &bottom_blobs,
+                  std::vector<ncnn::Mat> &top_blobs,
+                  const ncnn::Option &opt) const override;
+};
+
+void RegisterPoolingModuleNoProjLayer(ncnn::Net &net);
+
+}  // namespace sherpa_ncnn
+
+#endif  // SHERPA_NCNN_CSRC_POOLINGMODULENOPROJ_H_
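In short, PoolingModuleNoProj implements streaming cumulative average pooling: output row t is the mean over the n cached frames (summarized by cached_avg) plus input frames 0..t, and the layer emits updated cached_len/cached_avg so the next chunk continues where this one stopped. A per-channel scalar sketch of the same recurrence (illustrative only; the layer above vectorizes this over all C channels):

#include <cstddef>
#include <vector>

// out[t] = (n * cached_avg + x[0] + ... + x[t]) / (n + t + 1)
static std::vector<float> StreamingAvgPool(const std::vector<float> &x,
                                           float cached_avg, float n) {
  std::vector<float> out(x.size());
  float running_sum = n * cached_avg;
  for (std::size_t t = 0; t != x.size(); ++t) {
    running_sum += x[t];
    out[t] = running_sum / (n + t + 1);
  }
  // The layer also outputs the updated state:
  //   cached_len = n + x.size(), cached_avg = out.back()
  return out;
}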

+ 91 - 0
sherpa-ncnn/csrc/simpleupsample.cc

@@ -0,0 +1,91 @@
+/**
+ * Copyright (c)  2022  Xiaomi Corporation (authors: Fangjun Kuang)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "sherpa-ncnn/csrc/simpleupsample.h"
+
+namespace sherpa_ncnn {
+
+SimpleUpsample::SimpleUpsample() {
+  one_blob_only = true;
+  support_inplace = false;
+}
+
+int32_t SimpleUpsample::load_param(const ncnn::ParamDict &pd) {
+  upsample = pd.get(0, 0);
+  num_channels = pd.get(1, 0);
+  bias_data_size = pd.get(2, 0);
+  if (bias_data_size != upsample * num_channels) {
+    NCNN_LOGE("upsample: %d, num_channels: %d, bias_data_size: %d. %dx%d!=%d",
+              upsample, num_channels, bias_data_size, upsample, num_channels,
+              bias_data_size);
+    return -100;
+  }
+
+  return 0;
+}
+
+int32_t SimpleUpsample::load_model(const ncnn::ModelBin &mb) {
+  bias = mb.load(num_channels, upsample, 0);
+  if (bias.empty()) return -100;
+
+  return 0;
+}
+
+int32_t SimpleUpsample::forward(const ncnn::Mat &bottom_blob,
+                                ncnn::Mat &top_blob,
+                                const ncnn::Option &opt) const {
+  // bottom_blob.dims == 2
+  // bottom_blob.w == num_channels
+  // bottom_blob.h == seq_len
+
+  int32_t outw = bottom_blob.w;
+  int32_t outh = upsample;
+  int32_t outc = bottom_blob.h;
+  size_t elemsize = bottom_blob.elemsize;
+
+  top_blob.create(outw, outh, outc, elemsize, opt.blob_allocator);
+  if (top_blob.empty()) return -100;
+
+#pragma omp parallel for num_threads(opt.num_threads)
+  for (int32_t q = 0; q < outc; ++q) {
+    ncnn::Mat out_m = top_blob.channel(q);
+    const float *a_ptr = bottom_blob.row(q);
+
+    for (int32_t y = 0; y < outh; ++y) {
+      float *out_ptr = out_m.row(y);
+      const float *b_ptr = bias.row(y);
+      for (int32_t x = 0; x < outw; ++x) {
+        out_ptr[x] = a_ptr[x] + b_ptr[x];
+      }
+    }
+  }
+
+  top_blob = top_blob.reshape(outw, outh * outc);
+
+  return 0;
+}
+
+static ncnn::Layer *SimpleUpsampleCreator(void * /*userdata*/) {
+  return new SimpleUpsample();
+}
+
+void RegisterTensorSimpleUpsampleLayer(ncnn::Net &net) {
+  net.register_custom_layer("SimpleUpsample", SimpleUpsampleCreator);
+}
+
+}  // namespace sherpa_ncnn

+ 52 - 0
sherpa-ncnn/csrc/simpleupsample.h

@@ -0,0 +1,52 @@
+/**
+ * Copyright (c)  2022  Xiaomi Corporation (authors: Fangjun Kuang)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef SHERPA_NCNN_CSRC_SIMPLEUPSAMPLE_H_
+#define SHERPA_NCNN_CSRC_SIMPLEUPSAMPLE_H_
+
+#include <utility>
+
+#include "layer.h"  // NOLINT
+#include "net.h"    // NOLINT
+
+namespace sherpa_ncnn {
+
+class SimpleUpsample : public ncnn::Layer {
+ public:
+  SimpleUpsample();
+
+  int32_t load_param(const ncnn::ParamDict &pd) override;
+
+  int32_t load_model(const ncnn::ModelBin &mb) override;
+
+  int32_t forward(const ncnn::Mat &bottom_blob, ncnn::Mat &top_blob,
+                  const ncnn::Option &opt) const override;
+
+ public:
+  int32_t upsample;
+  int32_t num_channels;
+  int32_t bias_data_size;
+
+  ncnn::Mat bias;
+};
+
+void RegisterTensorSimpleUpsampleLayer(ncnn::Net &net);
+
+}  // namespace sherpa_ncnn
+
+#endif  // SHERPA_NCNN_CSRC_SIMPLEUPSAMPLE_H_
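SimpleUpsample expands every input frame into upsample consecutive output frames, each offset by a learned bias row; the final reshape merges (upsample, seq_len) back into a single time axis, so input row t lands at output rows t * upsample + y. A reference sketch of those semantics (assumes row-major (seq_len, C) input and (upsample, C) bias, matching the shapes checked in load_param above):

#include <cstddef>
#include <utility>
#include <vector>

using Matrix = std::vector<std::vector<float>>;  // row-major

// Output row t * upsample + y equals x[t] + bias[y].
static Matrix SimpleUpsampleRef(const Matrix &x, const Matrix &bias) {
  Matrix out;
  out.reserve(x.size() * bias.size());
  for (const auto &row : x) {
    for (const auto &b : bias) {
      std::vector<float> r(row.size());
      for (std::size_t c = 0; c != row.size(); ++c) {
        r[c] = row[c] + b[c];
      }
      out.push_back(std::move(r));
    }
  }
  return out;
}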

+ 98 - 0
sherpa-ncnn/csrc/stack.cc

@@ -0,0 +1,98 @@
+/**
+ * Copyright (c)  2022  Xiaomi Corporation (authors: Fangjun Kuang)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "sherpa-ncnn/csrc/stack.h"
+
+namespace sherpa_ncnn {
+
+Stack::Stack() {
+  one_blob_only = false;
+  support_inplace = false;
+}
+
+int32_t Stack::load_param(const ncnn::ParamDict &pd) {
+  axis = pd.get(0, 0);
+  if (axis != 0) {
+    NCNN_LOGE("Stack: Only axis==0 is implemented. Given %d", axis);
+    return -100;
+  }
+
+  return 0;
+}
+
+int32_t Stack::forward(const std::vector<ncnn::Mat> &bottom_blobs,
+                       std::vector<ncnn::Mat> &top_blobs,
+                       const ncnn::Option &opt) const {
+  int32_t dims = bottom_blobs[0].dims;
+  size_t elemsize = bottom_blobs[0].elemsize;
+
+  if (dims == 1) {
+    int32_t out_w = bottom_blobs[0].w;
+    int32_t out_h = bottom_blobs.size();
+
+    ncnn::Mat &top_blob = top_blobs[0];
+    top_blob.create(out_w, out_h, elemsize, opt.blob_allocator);
+    if (top_blob.empty()) return -100;
+
+    unsigned char *outptr = top_blob;
+
+    size_t bytes_per_blob = out_w * elemsize;
+
+    for (size_t b = 0; b < bottom_blobs.size(); ++b) {
+      const unsigned char *ptr = bottom_blobs[b];
+      memcpy(outptr, ptr, bytes_per_blob);
+
+      outptr += bytes_per_blob;
+    }
+
+    return 0;
+  }
+
+  if (dims == 2) {
+    int32_t out_w = bottom_blobs[0].w;
+    int32_t out_h = bottom_blobs[0].h;
+    int32_t out_c = bottom_blobs.size();
+
+    ncnn::Mat &top_blob = top_blobs[0];
+    top_blob.create(out_w, out_h, out_c, elemsize, opt.blob_allocator);
+    if (top_blob.empty()) return -100;
+
+    size_t bytes_per_blob = out_w * out_h * elemsize;
+
+    for (size_t b = 0; b < bottom_blobs.size(); ++b) {
+      unsigned char *outptr = top_blob.channel(b);
+      const unsigned char *ptr = bottom_blobs[b];
+
+      memcpy(outptr, ptr, bytes_per_blob);
+    }
+
+    return 0;
+  }
+
+  NCNN_LOGE("Stack: dim %d is not implemented", dims);
+
+  return -100;
+}
+
+static ncnn::Layer *StackCreator(void * /*userdata*/) { return new Stack(); }
+
+void RegisterStackLayer(ncnn::Net &net) {
+  net.register_custom_layer("Stack", StackCreator);
+}
+
+}  // namespace sherpa_ncnn

+ 48 - 0
sherpa-ncnn/csrc/stack.h

@@ -0,0 +1,48 @@
+/**
+ * Copyright (c)  2022  Xiaomi Corporation (authors: Fangjun Kuang)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef SHERPA_NCNN_CSRC_STACK_H_
+#define SHERPA_NCNN_CSRC_STACK_H_
+
+#include <utility>
+#include <vector>
+
+#include "layer.h"  // NOLINT
+#include "net.h"    // NOLINT
+
+namespace sherpa_ncnn {
+
+class Stack : public ncnn::Layer {
+ public:
+  Stack();
+
+  int32_t load_param(const ncnn::ParamDict &pd) override;
+
+  int32_t forward(const std::vector<ncnn::Mat> &bottom_blobs,
+                  std::vector<ncnn::Mat> &top_blobs,
+                  const ncnn::Option &opt) const override;
+
+ public:
+  int32_t axis;
+};
+
+void RegisterStackLayer(ncnn::Net &net);
+
+}  // namespace sherpa_ncnn
+
+#endif  // SHERPA_NCNN_CSRC_STACK_H_
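Stack mirrors torch.stack(tensors, dim=0): K equal-shape inputs gain a new leading axis of size K, so K 1-D blobs of length W become a (W, K) Mat whose row b is blob b, and K 2-D blobs become a (W, H, K) Mat whose channel b is blob b. A flat reference sketch (it ignores ncnn's per-channel padding, which the memcpy into top_blob.channel(b) above accounts for):

#include <vector>

// Element b along the new leading axis is a verbatim copy of blob b.
static std::vector<float> StackDim0(
    const std::vector<std::vector<float>> &blobs) {
  std::vector<float> out;
  if (blobs.empty()) return out;
  out.reserve(blobs.size() * blobs[0].size());
  for (const auto &b : blobs) {
    out.insert(out.end(), b.begin(), b.end());
  }
  return out;
}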

+ 118 - 0
sherpa-ncnn/csrc/tensorasstrided.cc

@@ -0,0 +1,118 @@
+/**
+ * Copyright (c)  2022  Xiaomi Corporation (authors: Fangjun Kuang)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "sherpa-ncnn/csrc/tensorasstrided.h"
+
+namespace sherpa_ncnn {
+
+TensorAsStrided::TensorAsStrided() {
+  one_blob_only = true;
+  support_inplace = false;
+}
+
+int32_t TensorAsStrided::load_param(const ncnn::ParamDict &pd) {
+  sizes = pd.get(0, ncnn::Mat());
+  strides = pd.get(1, ncnn::Mat());
+  storage_offset = pd.get(2, 0);
+
+  if (sizes.dims != 1 && strides.dims != 1) {
+    if (sizes.dims != 0) {
+      NCNN_LOGE("sizes.dims: %d, strides.dims: %d. They are not 1!\n",
+                sizes.dims, strides.dims);
+      return -100;
+    }
+  }
+
+  if (sizes.w != strides.w) {
+    NCNN_LOGE("sizes.w: %d, strides.w: %d. They are not equal!\n", sizes.w,
+              strides.w);
+    return -100;
+  }
+
+  return 0;
+}
+
+int32_t TensorAsStrided::forward(const ncnn::Mat &bottom_blob,
+                                 ncnn::Mat &top_blob,
+                                 const ncnn::Option &opt) const {
+  const int32_t *p_sizes = sizes;
+  const int32_t *p_strides = strides;
+
+  if (sizes.w == 3) {
+    if (bottom_blob.dims != 3) {
+      NCNN_LOGE("Only 3-D tensors are supported right now");
+      return -100;
+    }
+
+    int32_t inc = bottom_blob.c;
+    int32_t inh = bottom_blob.h;
+    int32_t inw = bottom_blob.w;
+
+    int32_t outc = p_sizes[0];
+    int32_t outh = p_sizes[1];
+    int32_t outw = p_sizes[2];
+
+    if (bottom_blob.c != outc) {
+      NCNN_LOGE("We only implement in_c == out_c right now");
+      return -100;
+    }
+
+    if (p_strides[0] != inh * inw) {
+      NCNN_LOGE("Stride that crosses channels is not supported");
+      return -100;
+    }
+
+    size_t elemsize = bottom_blob.elemsize;
+    top_blob.create(outw, outh, outc, elemsize, opt.blob_allocator);
+
+    int32_t stride1 = p_strides[1];
+    int32_t stride2 = p_strides[2];
+
+#pragma omp parallel for num_threads(opt.num_threads)
+    for (int32_t q = 0; q < outc; q++) {
+      ncnn::Mat out_m = top_blob.channel(q);
+
+      const float *in_m = bottom_blob.channel(q);
+      in_m += storage_offset;
+
+      for (int32_t y = 0; y < outh; ++y) {
+        float *out_ptr = out_m.row(y);
+        const float *in_ptr = in_m + y * stride1;
+        for (int32_t x = 0; x < outw; ++x) {
+          out_ptr[x] = in_ptr[x * stride2];
+        }
+      }
+    }
+
+    return 0;
+  }
+
+  NCNN_LOGE("TensorAsStrided: Only 3-D tensors are supported right now");
+
+  return -100;
+}
+
+static ncnn::Layer *TensorAsStridedCreator(void * /*userdata*/) {
+  return new TensorAsStrided();
+}
+
+void RegisterTensorAsStridedLayer(ncnn::Net &net) {
+  net.register_custom_layer("TensorAsStrided", TensorAsStridedCreator);
+}
+
+}  // namespace sherpa_ncnn

+ 48 - 0
sherpa-ncnn/csrc/tensorasstrided.h

@@ -0,0 +1,48 @@
+/**
+ * Copyright (c)  2022  Xiaomi Corporation (authors: Fangjun Kuang)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef SHERPA_NCNN_CSRC_TENSORASSTRIDED_H_
+#define SHERPA_NCNN_CSRC_TENSORASSTRIDED_H_
+
+#include <utility>
+
+#include "layer.h"  // NOLINT
+#include "net.h"    // NOLINT
+
+namespace sherpa_ncnn {
+
+class TensorAsStrided : public ncnn::Layer {
+ public:
+  TensorAsStrided();
+
+  int32_t load_param(const ncnn::ParamDict &pd) override;
+
+  int32_t forward(const ncnn::Mat &bottom_blob, ncnn::Mat &top_blob,
+                  const ncnn::Option &opt) const override;
+
+ public:
+  ncnn::Mat sizes;
+  ncnn::Mat strides;
+  int32_t storage_offset;
+};
+
+void RegisterTensorAsStridedLayer(ncnn::Net &net);
+
+}  // namespace sherpa_ncnn
+
+#endif  // SHERPA_NCNN_CSRC_TENSORASSTRIDED_H_
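TensorAsStrided follows torch.Tensor.as_strided, restricted here to 3-D inputs whose outermost stride equals one whole input channel: within each channel, out[y][x] reads in[storage_offset + y * strides[1] + x * strides[2]]. A single-channel reference sketch:

#include <cstddef>
#include <cstdint>
#include <vector>

// out[y * outw + x] = in[storage_offset + y * stride1 + x * stride2]
static std::vector<float> AsStrided2d(const std::vector<float> &in,
                                      int32_t outh, int32_t outw,
                                      int32_t stride1, int32_t stride2,
                                      int32_t storage_offset) {
  std::vector<float> out(static_cast<std::size_t>(outh) * outw);
  for (int32_t y = 0; y != outh; ++y) {
    for (int32_t x = 0; x != outw; ++x) {
      out[y * outw + x] = in[storage_offset + y * stride1 + x * stride2];
    }
  }
  return out;
}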

+ 2 - 2
sherpa-ncnn/csrc/zipformer-model.cc

@@ -206,7 +206,7 @@ void ZipformerModel::InitEncoderPostProcessing() {

void ZipformerModel::InitEncoder(const std::string &encoder_param,
                                 const std::string &encoder_bin) {
-  RegisterMetaDataLayer(encoder_);
+  RegisterCustomLayers(encoder_);
  InitNet(encoder_, encoder_param, encoder_bin);
  InitEncoderPostProcessing();
}
@@ -225,7 +225,7 @@ void ZipformerModel::InitJoiner(const std::string &joiner_param,
void ZipformerModel::InitEncoder(AAssetManager *mgr,
                                 const std::string &encoder_param,
                                 const std::string &encoder_bin) {
-  RegisterMetaDataLayer(encoder_);
+  RegisterCustomLayers(encoder_);
  InitNet(mgr, encoder_, encoder_param, encoder_bin);
  InitEncoderPostProcessing();
}