Browse Source

minor fix : code style && typo (#59)

PF Luo 2 years ago
parent
commit
8b9d083f82

+ 2 - 2
.github/scripts/run-test.sh

@@ -184,8 +184,8 @@ for wave in ${waves[@]}; do
     $repo/decoder_jit_trace-epoch-30-avg-10-pnnx.ncnn.bin \
     $repo/joiner_jit_trace-epoch-30-avg-10-pnnx.ncnn.param \
     $repo/joiner_jit_trace-epoch-30-avg-10-pnnx.ncnn.bin \
-    $wave
-    4
+    $wave \
+    4 \
     "modified_beam_search"
 done
 

+ 4 - 8
sherpa-ncnn/csrc/modified-beam-search-decoder.cc

@@ -33,7 +33,6 @@ void ModifiedBeamSearchDecoder::AcceptWaveform(const float sample_rate,
 }
 
 void ModifiedBeamSearchDecoder::BuildDecoderInput(Hypothesis hyp) {
-  decoder_input_.reshape(context_size_);
   for (int32_t i = 0; i != context_size_; ++i) {
     static_cast<int32_t *>(decoder_input_)[i] =
         *(hyp.ys.end() - context_size_ + i);
@@ -58,11 +57,7 @@ void ModifiedBeamSearchDecoder::Decode() {
     std::vector<Hypothesis> prev;
     /* encoder_out_.w == encoder_out_dim, encoder_out_.h == num_frames. */
     for (int32_t t = 0; t != encoder_out_.h; ++t) {
-      prev.clear();
-      prev.reserve(config_.num_active_paths);
-      for (auto &h : cur) {
-        prev.push_back(std::move(h.second));
-      }
+      prev = std::move(cur.Vec());
       cur.clear();
 
       for (const auto &h : prev) {
@@ -76,19 +71,20 @@ void ModifiedBeamSearchDecoder::Decode() {
         // update active_paths
         auto topk =
             topk_index(joiner_out_ptr, joiner_out.w, config_.num_active_paths);
-        for (int i = 0; i != topk.size(); i++) {
+        for (int i = 0; i != topk.size(); ++i) {
           Hypothesis new_hyp = h;
           int32_t new_token = topk[i];
           if (new_token != blank_id_) {
             new_hyp.ys.push_back(new_token);
             new_hyp.num_trailing_blanks = 0;
           } else {
-            new_hyp.num_trailing_blanks += 1;
+            ++new_hyp.num_trailing_blanks;
           }
           new_hyp.log_prob += joiner_out_ptr[new_token];
           cur.Add(std::move(new_hyp));
         }
       }
+      // prune active_paths
       while (cur.Size() > config_.num_active_paths) {
         auto least_hyp = cur.GetLeastProbable(true);
         cur.Remove(least_hyp);

+ 3 - 3
sherpa-ncnn/csrc/sherpa-ncnn-microphone.cc

@@ -45,7 +45,7 @@ static void Handler(int sig) {
 };
 
 int main(int32_t argc, char *argv[]) {
-  if (argc != 8 && argc != 9) {
+  if (argc < 8 || argc > 10) {
     const char *usage = R"usage(
 Usage:
   ./bin/sherpa-ncnn-microphone \
@@ -56,7 +56,7 @@ Usage:
     /path/to/decoder.ncnn.bin \
     /path/to/joiner.ncnn.param \
     /path/to/joiner.ncnn.bin \
-    [num_threads]
+    [num_threads] [decode_method, can be greedy_search/modified_beam_search]
 
 Please refer to
 https://k2-fsa.github.io/sherpa/ncnn/pretrained_models/index.html
@@ -91,7 +91,7 @@ for a list of pre-trained models to download.
   sherpa_ncnn::DecoderConfig decoder_conf;
   if (argc == 10) {
     std::string method = argv[9];
-    if (method.compare("greed_search") ||
+    if (method.compare("greedy_search") ||
         method.compare("modified_beam_search")) {
       decoder_conf.method = method;
     }

+ 1 - 1
sherpa-ncnn/csrc/sherpa-ncnn.cc

@@ -66,7 +66,7 @@ for a list of pre-trained models to download.
   sherpa_ncnn::DecoderConfig decoder_conf;
   if (argc == 11) {
     std::string method = argv[10];
-    if (method.compare("greed_search") ||
+    if (method.compare("greedy_search") ||
         method.compare("modified_beam_search")) {
       decoder_conf.method = method;
     }