Procházet zdrojové kódy

Add MetaData layer (#30)

* Add MetaData layer

* fix style issues

* fix CI

* small fixes

* fix style issues
Fangjun Kuang před 2 roky
rodič
revize
af2b31e168

+ 1 - 0
sherpa-ncnn/csrc/CMakeLists.txt

@@ -5,6 +5,7 @@ set(sherpa_ncnn_core_srcs
   decode.cc
   features.cc
   lstm-model.cc
+  meta-data.cc
   model.cc
   symbol-table.cc
   wave-reader.cc

+ 26 - 1
sherpa-ncnn/csrc/conv-emformer-model.cc

@@ -4,9 +4,13 @@
 
 #include "sherpa-ncnn/csrc/conv-emformer-model.h"
 
-#include <regex>
+#include <regex>  // NOLINT
+#include <string>
+#include <utility>
+#include <vector>
 
 #include "net.h"  // NOLINT
+#include "sherpa-ncnn/csrc/meta-data.h"
 
 namespace sherpa_ncnn {
 
@@ -79,7 +83,28 @@ ncnn::Mat ConvEmformerModel::RunJoiner(ncnn::Mat &encoder_out,
 
 void ConvEmformerModel::InitEncoder(const std::string &encoder_param,
                                     const std::string &encoder_bin) {
+  RegisterMetaDataLayer(encoder_);
   InitNet(encoder_, encoder_param, encoder_bin);
+
+  // Now load parameters for member variables
+  for (const auto *layer : encoder_.layers()) {
+    if (layer->type == "SherpaMetaData" && layer->name == "sherpa_meta_data1") {
+      // Note: We don't use dynamic_cast<> here since it will throw
+      // the following error
+      //  error: ‘dynamic_cast’ not permitted with -fno-rtti
+      const auto *meta_data = reinterpret_cast<const MetaData *>(layer);
+
+      num_layers_ = meta_data->arg1;
+      memory_size_ = meta_data->arg2;
+      cnn_module_kernel_ = meta_data->arg3;
+      left_context_length_ = meta_data->arg4;
+      chunk_length_ = meta_data->arg5;
+      right_context_length_ = meta_data->arg6;
+      d_model_ = meta_data->arg7;
+
+      break;
+    }
+  }
 }
 
 void ConvEmformerModel::InitDecoder(const std::string &decoder_param,

+ 14 - 7
sherpa-ncnn/csrc/conv-emformer-model.h

@@ -2,6 +2,12 @@
 //
 // Copyright (c)  2022  Xiaomi Corporation
 
+#ifndef SHERPA_NCNN_CSRC_CONV_EMFORMER_MODEL_H_
+#define SHERPA_NCNN_CSRC_CONV_EMFORMER_MODEL_H_
+#include <string>
+#include <utility>
+#include <vector>
+
 #include "net.h"  // NOLINT
 #include "sherpa-ncnn/csrc/model.h"
 
@@ -53,13 +59,13 @@ class ConvEmformerModel : public Model {
 
   int32_t num_threads_;
 
-  int32_t num_layers_ = 12;
-  int32_t memory_size_ = 32;
-  int32_t cnn_module_kernel_ = 31;
-  int32_t left_context_length_ = 32 / 4;
-  int32_t chunk_length_ = 32;
-  int32_t right_context_length_ = 8;
-  int32_t d_model_ = 512;
+  int32_t num_layers_ = 12;               // arg1
+  int32_t memory_size_ = 32;              // arg2
+  int32_t cnn_module_kernel_ = 31;        // arg3
+  int32_t left_context_length_ = 32 / 4;  // arg4
+  int32_t chunk_length_ = 32;             // arg5
+  int32_t right_context_length_ = 8;      // arg6
+  int32_t d_model_ = 512;                 // arg7
 
   std::vector<int32_t> encoder_input_indexes_;
   std::vector<int32_t> encoder_output_indexes_;
@@ -72,3 +78,4 @@ class ConvEmformerModel : public Model {
 };
 
 }  // namespace sherpa_ncnn
+#endif  // SHERPA_NCNN_CSRC_CONV_EMFORMER_MODEL_H_

+ 56 - 0
sherpa-ncnn/csrc/meta-data.cc

@@ -0,0 +1,56 @@
+/**
+ * Copyright (c)  2022  Xiaomi Corporation (authors: Fangjun Kuang)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "sherpa-ncnn/csrc/meta-data.h"
+
+#include "net.h"  // NOLINT
+namespace sherpa_ncnn {
+
+int MetaData::load_param(const ncnn::ParamDict &pd) {
+  arg0 = pd.get(0, 0), arg1 = pd.get(1, 0), arg2 = pd.get(2, 0);
+  arg3 = pd.get(3, 0), arg4 = pd.get(4, 0), arg5 = pd.get(5, 0);
+  arg6 = pd.get(6, 0), arg7 = pd.get(7, 0), arg8 = pd.get(8, 0);
+  arg9 = pd.get(9, 0), arg10 = pd.get(10, 0), arg11 = pd.get(11, 0);
+  arg12 = pd.get(12, 0), arg13 = pd.get(13, 0), arg14 = pd.get(14, 0);
+  arg15 = pd.get(15, 0), arg16 = pd.get(16, 0), arg17 = pd.get(17, 0);
+  arg18 = pd.get(18, 0), arg19 = pd.get(19, 0), arg20 = pd.get(20, 0);
+  arg21 = pd.get(21, 0), arg22 = pd.get(22, 0), arg23 = pd.get(23, 0);
+
+  // The following 8 attributes are of type float
+  arg24 = pd.get(24, 0.f), arg25 = pd.get(25, 0.f), arg26 = pd.get(26, 0.f);
+  arg27 = pd.get(27, 0.f), arg28 = pd.get(28, 0.f), arg29 = pd.get(29, 0.f);
+  arg30 = pd.get(30, 0.f), arg31 = pd.get(31, 0.f);
+
+  return 0;
+}
+
+static ncnn::Layer *MetaDataCreator(void * /*userdata*/) {
+  return new MetaData();
+}
+
+/*
+In encoder.ncnn.param, you can use
+
+SherpaMetaData sherpa_meta_data1 0 0 0=1 1=12 24=1.5
+ */
+
+void RegisterMetaDataLayer(ncnn::Net &net) {
+  net.register_custom_layer("SherpaMetaData", MetaDataCreator);
+}
+
+}  // namespace sherpa_ncnn

+ 65 - 0
sherpa-ncnn/csrc/meta-data.h

@@ -0,0 +1,65 @@
+/**
+ * Copyright (c)  2022  Xiaomi Corporation (authors: Fangjun Kuang)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef SHERPA_NCNN_CSRC_META_DATA_H_
+#define SHERPA_NCNN_CSRC_META_DATA_H_
+
+#include "layer.h"  // NOLINT
+#include "net.h"    // NOLINT
+
+namespace sherpa_ncnn {
+
+class MetaData : public ncnn::Layer {
+ public:
+  int load_param(const ncnn::ParamDict &pd) override;
+
+  int32_t arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7;
+  int32_t arg8, arg9, arg10, arg11, arg12, arg13, arg14, arg15;
+  int32_t arg16, arg17, arg18, arg19, arg20, arg21, arg22, arg23;
+
+  float arg24, arg25, arg26, arg27, arg28, arg29, arg30, arg31;
+};
+
+/*
+In encoder.ncnn.param, you can use
+
+SherpaMetaData sherpa_meta_data1 0 0 0=1 1=12 24=1.5
+
+For instace, suppose you have a encoder.ncnn.param looks like below:
+
+
+7767517
+1060 1342
+Input                    in0                      0 1 in0
+
+You can change it to
+
+7767517
+1061 1342
+SherpaMetaData            sherpa_meta_data1       0 0 0=1
+Input                    in0                      0 1 in0
+
+Note: You first need to change 1060 to 1061 since we add one layer
+
+ */
+
+void RegisterMetaDataLayer(ncnn::Net &net);
+
+}  // namespace sherpa_ncnn
+
+#endif  // SHERPA_NCNN_CSRC_META_DATA_H_

+ 16 - 11
sherpa-ncnn/csrc/model.cc

@@ -21,6 +21,7 @@
 
 #include "sherpa-ncnn/csrc/conv-emformer-model.h"
 #include "sherpa-ncnn/csrc/lstm-model.h"
+#include "sherpa-ncnn/csrc/meta-data.h"
 
 namespace sherpa_ncnn {
 
@@ -52,17 +53,19 @@ static bool IsLstmModel(const ncnn::Net &net) {
 
 static bool IsConvEmformerModel(const ncnn::Net &net) {
   // Note: We may need to add more constraints if number of models gets larger.
-  if (net.input_indexes().size() < 49) {
-    return false;
-  }
-
-  if (net.output_indexes().size() < 49) {
-    return false;
-  }
-
-  for (const auto &layer : net.layers()) {
-    if (layer->type == "GLU") {
-      return true;
+  //
+  // If the net has a layer of type SherpaMetaData and with name
+  // sherpa_meta_data1 and if attribute 0 is 1, we assume the model is
+  // a ConvEmformer model
+
+  for (const auto *layer : net.layers()) {
+    if (layer->type == "SherpaMetaData" && layer->name == "sherpa_meta_data1") {
+      // Note: We don't use dynamic_cast<> here since it will throw
+      // the following error
+      //  error: ‘dynamic_cast’ not permitted with -fno-rtti
+      const auto *meta_data = reinterpret_cast<const MetaData *>(layer);
+
+      if (meta_data->arg0 == 1) return true;
     }
   }
 
@@ -90,6 +93,8 @@ std::unique_ptr<Model> Model::Create(const ModelConfig &config) {
   // in the future
 
   ncnn::Net net;
+  RegisterMetaDataLayer(net);
+
   auto ret = net.load_param(config.encoder_param.c_str());
   if (ret != 0) {
     NCNN_LOGE("Failed to load %s", config.encoder_param.c_str());

+ 2 - 0
sherpa-ncnn/csrc/model.h

@@ -21,6 +21,8 @@
 
 #include <memory>
 #include <string>
+#include <utility>
+#include <vector>
 
 #include "net.h"  // NOLINT