Browse Source

Fix hotword bugs (#262)

Fangjun Kuang 1 year ago
parent
commit
8be3e0818b

+ 1 - 1
CMakeLists.txt

@@ -1,7 +1,7 @@
 cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
 cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
 project(sherpa-ncnn)
 project(sherpa-ncnn)
 
 
-set(SHERPA_NCNN_VERSION "2.1.0")
+set(SHERPA_NCNN_VERSION "2.1.1")
 
 
 # Disable warning about
 # Disable warning about
 #
 #

+ 1 - 1
sherpa-ncnn/csrc/modified-beam-search-decoder.cc

@@ -251,7 +251,7 @@ void ModifiedBeamSearchDecoder::Decode(ncnn::Mat encoder_out, Stream *s,
         new_hyp.ys.push_back(new_token);
         new_hyp.ys.push_back(new_token);
         new_hyp.num_trailing_blanks = 0;
         new_hyp.num_trailing_blanks = 0;
         new_hyp.timestamps.push_back(t + frame_offset);
         new_hyp.timestamps.push_back(t + frame_offset);
-        if (s != nullptr && s->GetContextGraph() != nullptr) {
+        if (s && s->GetContextGraph()) {
           auto context_res =
           auto context_res =
               s->GetContextGraph()->ForwardOneStep(context_state, new_token);
               s->GetContextGraph()->ForwardOneStep(context_state, new_token);
           context_score = context_res.first;
           context_score = context_res.first;

+ 6 - 8
sherpa-ncnn/csrc/recognizer.cc

@@ -172,11 +172,7 @@ class Recognizer::Impl {
   void DecodeStream(Stream *s) const {
   void DecodeStream(Stream *s) const {
     int32_t segment = model_->Segment();
     int32_t segment = model_->Segment();
     int32_t offset = model_->Offset();
     int32_t offset = model_->Offset();
-    bool has_context_graph = false;
 
 
-    if (!has_context_graph && s->GetContextGraph()) {
-      has_context_graph = true;
-    }
     ncnn::Mat features = s->GetFrames(s->GetNumProcessedFrames(), segment);
     ncnn::Mat features = s->GetFrames(s->GetNumProcessedFrames(), segment);
     s->GetNumProcessedFrames() += offset;
     s->GetNumProcessedFrames() += offset;
     std::vector<ncnn::Mat> states = s->GetStates();
     std::vector<ncnn::Mat> states = s->GetStates();
@@ -184,7 +180,7 @@ class Recognizer::Impl {
     ncnn::Mat encoder_out;
     ncnn::Mat encoder_out;
     std::tie(encoder_out, states) = model_->RunEncoder(features, states);
     std::tie(encoder_out, states) = model_->RunEncoder(features, states);
 
 
-    if (has_context_graph) {
+    if (s->GetContextGraph()) {
       decoder_->Decode(encoder_out, s, &s->GetResult());
       decoder_->Decode(encoder_out, s, &s->GetResult());
     } else {
     } else {
       decoder_->Decode(encoder_out, &s->GetResult());
       decoder_->Decode(encoder_out, &s->GetResult());
@@ -216,7 +212,7 @@ class Recognizer::Impl {
     }
     }
     // Caution: We need to keep the decoder output state
     // Caution: We need to keep the decoder output state
     ncnn::Mat decoder_out = s->GetResult().decoder_out;
     ncnn::Mat decoder_out = s->GetResult().decoder_out;
-    s->SetResult(decoder_->GetEmptyResult());
+    s->SetResult(r);
     s->GetResult().decoder_out = decoder_out;
     s->GetResult().decoder_out = decoder_out;
 
 
     // don't reset encoder state
     // don't reset encoder state
@@ -284,8 +280,10 @@ class Recognizer::Impl {
           int32_t number = sym_[word];
           int32_t number = sym_[word];
           tmp.push_back(number);
           tmp.push_back(number);
         } else {
         } else {
-          NCNN_LOGE("Cannot find ID for hotword %s at line: %s", word.c_str(),
-                    line.c_str());
+          NCNN_LOGE(
+              "Cannot find ID for hotword %s at line: %s. (Hint: words on the "
+              "same line are separated by spaces)",
+              word.c_str(), line.c_str());
           exit(-1);
           exit(-1);
         }
         }
       }
       }