瀏覽代碼

support hotwords in C++ (#257)

HalFTeen 1 年之前
父節點
當前提交
0f8e46db9b

+ 45 - 0
.github/scripts/run-test.sh

@@ -544,3 +544,48 @@ for wave in ${waves[@]}; do
 done
 
 rm -rf $repo
+
+log "------------------------------------------------------------"
+log "Run hotwords test (Chinese)"
+log "------------------------------------------------------------"
+repo_url=https://huggingface.co/HalFTeen/sherpa-ncnn-hotwords-test/
+log "Start testing ${repo_url}"
+repo=$(basename $repo_url)
+log "Download pretrained model and test-data from $repo_url"
+GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
+pushd $repo
+git lfs pull --include "encoder_jit_trace-pnnx.ncnn.bin"
+git lfs pull --include "decoder_jit_trace-pnnx.ncnn.bin"
+git lfs pull --include "joiner_jit_trace-pnnx.ncnn.bin"
+popd
+
+
+log "----test $m without hotwords---"
+time $EXE \
+  $repo/tokens.txt \
+  $repo/encoder_jit_trace-pnnx.ncnn.param \
+  $repo/encoder_jit_trace-pnnx.ncnn.bin \
+  $repo/decoder_jit_trace-pnnx.ncnn.param \
+  $repo/decoder_jit_trace-pnnx.ncnn.bin \
+  $repo/joiner_jit_trace-pnnx.ncnn.param \
+  $repo/joiner_jit_trace-pnnx.ncnn.bin \
+  $repo/hotwords.wav \
+  4 \
+  modified_beam_search
+
+
+log "----test $m with hotwords---"
+time $EXE \
+  $repo/tokens.txt \
+  $repo/encoder_jit_trace-pnnx.ncnn.param \
+  $repo/encoder_jit_trace-pnnx.ncnn.bin \
+  $repo/decoder_jit_trace-pnnx.ncnn.param \
+  $repo/decoder_jit_trace-pnnx.ncnn.bin \
+  $repo/joiner_jit_trace-pnnx.ncnn.param \
+  $repo/joiner_jit_trace-pnnx.ncnn.bin \
+  $repo/hotwords.wav \
+  4 \
+  modified_beam_search \
+  $repo/hotwords.txt 1.6
+
+rm -rf $repo

+ 12 - 4
c-api-examples/decode-file-c-api.c

@@ -40,7 +40,7 @@ const char *kUsage =
     "for a list of pre-trained models to download.\n";
 
 int32_t main(int32_t argc, char *argv[]) {
-  if (argc < 9 || argc > 11) {
+  if (argc < 9 || argc > 13) {
     fprintf(stderr, "%s\n", kUsage);
     return -1;
   }
@@ -62,7 +62,7 @@ int32_t main(int32_t argc, char *argv[]) {
 
   config.decoder_config.decoding_method = "greedy_search";
 
-  if (argc == 11) {
+  if (argc >= 11) {
     config.decoder_config.decoding_method = argv[10];
   }
   config.decoder_config.num_active_paths = 4;
@@ -73,7 +73,16 @@ int32_t main(int32_t argc, char *argv[]) {
 
   config.feat_config.sampling_rate = 16000;
   config.feat_config.feature_dim = 80;
-
+  if(argc >= 12) {
+    config.hotwords_file = argv[11];
+  } else {
+    config.hotwords_file = "";
+  }
+  if(argc == 13) {
+    config.hotwords_score = atof(argv[12]);
+  } else {
+    config.hotwords_score = 1.5;
+  }
   SherpaNcnnRecognizer *recognizer = CreateRecognizer(&config);
 
   const char *wav_filename = argv[8];
@@ -92,7 +101,6 @@ int32_t main(int32_t argc, char *argv[]) {
 
   int16_t buffer[N];
   float samples[N];
-
   SherpaNcnnStream *s = CreateStream(recognizer);
 
   SherpaNcnnDisplay *display = CreateDisplay(50);

+ 2 - 0
sherpa-ncnn/c-api/c-api.cc

@@ -66,6 +66,8 @@ SherpaNcnnRecognizer *CreateRecognizer(
   config.decoder_config.method = in_config->decoder_config.decoding_method;
   config.decoder_config.num_active_paths =
       in_config->decoder_config.num_active_paths;
+  config.hotwords_file = in_config->hotwords_file;
+  config.hotwords_score = in_config->hotwords_score;
 
   config.enable_endpoint = in_config->enable_endpoint;
 

+ 8 - 0
sherpa-ncnn/c-api/c-api.h

@@ -133,6 +133,14 @@ SHERPA_NCNN_API typedef struct SherpaNcnnRecognizerConfig {
   /// this value.
   /// Used only when enable_endpoint is not 0.
   float rule3_min_utterance_length;
+  
+  /// hotwords file, each line is a hotword which is segmented into char by space
+  /// if language is something like CJK, segment manually,
+  /// if language is something like English, segment by bpe model.
+  const char *hotwords_file;
+
+  /// scale of hotwords, used only when hotwords_file is not empty
+  float hotwords_score;
 } SherpaNcnnRecognizerConfig;
 
 SHERPA_NCNN_API typedef struct SherpaNcnnResult {

+ 1 - 0
sherpa-ncnn/csrc/CMakeLists.txt

@@ -1,6 +1,7 @@
 include_directories(${CMAKE_SOURCE_DIR})
 
 set(sherpa_ncnn_core_srcs
+  context-graph.cc
   conv-emformer-model.cc
   decoder.cc
   endpoint.cc

+ 95 - 0
sherpa-ncnn/csrc/context-graph.cc

@@ -0,0 +1,95 @@
+// sherpa-ncnn/csrc/context-graph.cc
+//
+// Copyright (c)  2023  Xiaomi Corporation
+
+#include "sherpa-ncnn/csrc/context-graph.h"
+
+#include <cassert>
+#include <queue>
+#include <utility>
+
+namespace sherpa_ncnn {
+void ContextGraph::Build(
+    const std::vector<std::vector<int32_t>> &token_ids) const {
+  for (int32_t i = 0; i < token_ids.size(); ++i) {
+    auto node = root_.get();
+    for (int32_t j = 0; j < token_ids[i].size(); ++j) {
+      int32_t token = token_ids[i][j];
+      if (0 == node->next.count(token)) {
+        bool is_end = j == token_ids[i].size() - 1;
+        node->next[token] = std::make_unique<ContextState>(
+            token, context_score_, node->node_score + context_score_,
+            is_end ? node->node_score + context_score_ : 0, is_end);
+      }
+      node = node->next[token].get();
+    }
+  }
+  FillFailOutput();
+}
+
+std::pair<float, const ContextState *> ContextGraph::ForwardOneStep(
+    const ContextState *state, int32_t token) const {
+  const ContextState *node;
+  float score;
+  if (1 == state->next.count(token)) {
+    node = state->next.at(token).get();
+    score = node->token_score;
+  } else {
+    node = state->fail;
+    while (0 == node->next.count(token)) {
+      node = node->fail;
+      if (-1 == node->token) break;  // root
+    }
+    if (1 == node->next.count(token)) {
+      node = node->next.at(token).get();
+    }
+    score = node->node_score - state->node_score;
+  }
+  return std::make_pair(score + node->output_score, node);
+}
+
+std::pair<float, const ContextState *> ContextGraph::Finalize(
+    const ContextState *state) const {
+  float score = -state->node_score;
+  return std::make_pair(score, root_.get());
+}
+
+void ContextGraph::FillFailOutput() const {
+  std::queue<const ContextState *> node_queue;
+  for (auto &kv : root_->next) {
+    kv.second->fail = root_.get();
+    node_queue.push(kv.second.get());
+  }
+  while (!node_queue.empty()) {
+    auto current_node = node_queue.front();
+    node_queue.pop();
+    for (auto &kv : current_node->next) {
+      auto fail = current_node->fail;
+      if (1 == fail->next.count(kv.first)) {
+        fail = fail->next.at(kv.first).get();
+      } else {
+        fail = fail->fail;
+        while (0 == fail->next.count(kv.first)) {
+          fail = fail->fail;
+          if (-1 == fail->token) break;
+        }
+        if (1 == fail->next.count(kv.first))
+          fail = fail->next.at(kv.first).get();
+      }
+      kv.second->fail = fail;
+      // fill the output arc
+      auto output = fail;
+      while (!output->is_end) {
+        output = output->fail;
+        if (-1 == output->token) {
+          output = nullptr;
+          break;
+        }
+      }
+      kv.second->output = output;
+      kv.second->output_score += output == nullptr ? 0 : output->output_score;
+      node_queue.push(kv.second.get());
+    }
+  }
+}
+}  // namespace sherpa_ncnn

+ 65 - 0
sherpa-ncnn/csrc/context-graph.h

@@ -0,0 +1,65 @@
+// sherpa-ncnn/csrc/context-graph.h
+//
+// Copyright (c)  2023  Xiaomi Corporation
+
+#ifndef SHERPA_NCNN_CSRC_CONTEXT_GRAPH_H_
+#define SHERPA_NCNN_CSRC_CONTEXT_GRAPH_H_
+
+#include <memory>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+
+namespace sherpa_ncnn {
+
+class ContextGraph;
+using ContextGraphPtr = std::shared_ptr<ContextGraph>;
+
+struct ContextState {
+  int32_t token;
+  float token_score;
+  float node_score;
+  float output_score;
+  bool is_end;
+  std::unordered_map<int32_t, std::unique_ptr<ContextState>> next;
+  const ContextState *fail = nullptr;
+  const ContextState *output = nullptr;
+
+  ContextState() = default;
+  ContextState(int32_t token, float token_score, float node_score,
+               float output_score, bool is_end)
+      : token(token),
+        token_score(token_score),
+        node_score(node_score),
+        output_score(output_score),
+        is_end(is_end) {}
+};
+
+class ContextGraph {
+ public:
+  ContextGraph() = default;
+  ContextGraph(const std::vector<std::vector<int32_t>> &token_ids,
+               float hotwords_score)
+      : context_score_(hotwords_score) {
+    root_ = std::make_unique<ContextState>(-1, 0, 0, 0, false);
+    root_->fail = root_.get();
+    Build(token_ids);
+  }
+
+  std::pair<float, const ContextState *> ForwardOneStep(
+      const ContextState *state, int32_t token_id) const;
+  std::pair<float, const ContextState *> Finalize(
+      const ContextState *state) const;
+
+  const ContextState *Root() const { return root_.get(); }
+
+ private:
+  float context_score_;
+  std::unique_ptr<ContextState> root_;
+  void Build(const std::vector<std::vector<int32_t>> &token_ids) const;
+  void FillFailOutput() const;
+};
+
+}  // namespace sherpa_ncnn
+#endif  // SHERPA_NCNN_CSRC_CONTEXT_GRAPH_H_

+ 2 - 1
sherpa-ncnn/csrc/decoder.h

@@ -59,7 +59,7 @@ struct DecoderResult {
   // used only for modified_beam_search
   Hypotheses hyps;
 };
-
+class Stream;
 class Decoder {
  public:
   virtual ~Decoder() = default;
@@ -88,6 +88,7 @@ class Decoder {
    * and there are no paddings.
    */
   virtual void Decode(ncnn::Mat encoder_out, DecoderResult *result) = 0;
+  virtual void Decode(ncnn::Mat encoder_out, Stream *s, DecoderResult *result){};
 };
 
 }  // namespace sherpa_ncnn

+ 7 - 3
sherpa-ncnn/csrc/hypothesis.h

@@ -24,6 +24,7 @@
 #include <unordered_map>
 #include <utility>
 #include <vector>
+#include "sherpa-ncnn/csrc/context-graph.h"
 
 namespace sherpa_ncnn {
 
@@ -37,12 +38,13 @@ struct Hypothesis {
 
   // The total score of ys in log space.
   double log_prob = 0;
-
+  const ContextState *context_state;
   int32_t num_trailing_blanks = 0;
 
   Hypothesis() = default;
-  Hypothesis(const std::vector<int32_t> &ys, double log_prob)
-      : ys(ys), log_prob(log_prob) {}
+  Hypothesis(const std::vector<int32_t> &ys, double log_prob,
+            const ContextState *context_state = nullptr)
+      : ys(ys), log_prob(log_prob), context_state(context_state) {}
 
   // If two Hypotheses have the same `Key`, then they contain
   // the same token sequence.
@@ -104,6 +106,8 @@ class Hypotheses {
 
   const auto begin() const { return hyps_dict_.begin(); }
   const auto end() const { return hyps_dict_.end(); }
+  auto begin() { return hyps_dict_.begin(); }
+  auto end() { return hyps_dict_.end(); }
 
   void Clear() { hyps_dict_.clear(); }
 

+ 88 - 1
sherpa-ncnn/csrc/modified-beam-search-decoder.cc

@@ -37,7 +37,7 @@ DecoderResult ModifiedBeamSearchDecoder::GetEmptyResult() const {
   Hypotheses blank_hyp({{blanks, 0}});
 
   r.hyps = std::move(blank_hyp);
-
+  r.tokens = std::move(blanks);
   return r;
 }
 
@@ -195,4 +195,91 @@ void ModifiedBeamSearchDecoder::Decode(ncnn::Mat encoder_out,
   result->num_trailing_blanks = hyp.num_trailing_blanks;
 }
 
+
+void ModifiedBeamSearchDecoder::Decode(ncnn::Mat encoder_out, Stream *s,
+                                       DecoderResult *result) {
+  int32_t context_size = model_->ContextSize();
+  Hypotheses cur = std::move(result->hyps);
+  /* encoder_out.w == encoder_out_dim, encoder_out.h == num_frames. */
+  for (int32_t t = 0; t != encoder_out.h; ++t) {
+    std::vector<Hypothesis> prev = cur.GetTopK(num_active_paths_, true);
+    cur.Clear();
+
+
+    ncnn::Mat decoder_input = BuildDecoderInput(prev);
+    ncnn::Mat decoder_out;
+    if (t == 0 && prev.size() == 1 && prev[0].ys.size() == context_size &&
+        !result->decoder_out.empty()) {
+      // When an endpoint is detected, we keep the decoder_out
+      decoder_out = result->decoder_out;
+    } else {
+      decoder_out = RunDecoder2D(model_, decoder_input);
+    }
+
+    // decoder_out.w == decoder_dim
+    // decoder_out.h == num_active_paths
+	ncnn::Mat encoder_out_t(encoder_out.w, 1, encoder_out.row(t));
+
+    ncnn::Mat joiner_out = model_->RunJoiner(encoder_out_t, decoder_out);
+    // joiner_out.w == vocab_size
+    // joiner_out.h == num_active_paths
+    LogSoftmax(&joiner_out);
+
+
+    float *p_joiner_out = joiner_out;
+
+    for (int32_t i = 0; i != joiner_out.h; ++i) {
+      float prev_log_prob = prev[i].log_prob;
+      for (int32_t k = 0; k != joiner_out.w; ++k, ++p_joiner_out) {
+        *p_joiner_out += prev_log_prob;
+      }
+    }
+
+    auto topk = TopkIndex(static_cast<float *>(joiner_out),
+                          joiner_out.w * joiner_out.h, num_active_paths_);
+
+    int32_t frame_offset = result->frame_offset;
+    for (auto i : topk) {
+      int32_t hyp_index = i / joiner_out.w;
+      int32_t new_token = i % joiner_out.w;
+
+      const float *p = joiner_out.row(hyp_index);
+
+      Hypothesis new_hyp = prev[hyp_index];
+      // const float prev_lm_log_prob = new_hyp.lm_log_prob;
+      float context_score = 0;
+      auto context_state = new_hyp.context_state;
+      // blank id is fixed to 0
+      if (new_token != 0) {
+        new_hyp.ys.push_back(new_token);
+        new_hyp.num_trailing_blanks = 0;
+        new_hyp.timestamps.push_back(t + frame_offset);
+        if (s != nullptr && s->GetContextGraph() != nullptr) {
+          auto context_res = s->GetContextGraph()->ForwardOneStep(
+              context_state, new_token);
+          context_score = context_res.first;
+          new_hyp.context_state = context_res.second;
+        }
+      } else {
+        ++new_hyp.num_trailing_blanks;
+      }
+      // We have already added prev[hyp_index].log_prob to p[new_token]
+      new_hyp.log_prob = p[new_token] + context_score;
+
+      cur.Add(std::move(new_hyp));
+    }
+  }
+
+  result->hyps = std::move(cur);
+  result->frame_offset += encoder_out.h;
+  auto hyp = result->hyps.GetMostProbable(true);
+
+  // set decoder_out in case of endpointing
+  ncnn::Mat decoder_input = BuildDecoderInput({hyp});
+  result->decoder_out = model_->RunDecoder(decoder_input);
+
+  result->tokens = std::move(hyp.ys);
+  result->num_trailing_blanks = hyp.num_trailing_blanks;
+}
+
 }  // namespace sherpa_ncnn

+ 3 - 0
sherpa-ncnn/csrc/modified-beam-search-decoder.h

@@ -25,6 +25,8 @@
 #include "mat.h"  // NOLINT
 #include "sherpa-ncnn/csrc/decoder.h"
 #include "sherpa-ncnn/csrc/model.h"
+#include "sherpa-ncnn/csrc/stream.h"
+#include "sherpa-ncnn/csrc/context-graph.h"
 
 namespace sherpa_ncnn {
 
@@ -38,6 +40,7 @@ class ModifiedBeamSearchDecoder : public Decoder {
   void StripLeadingBlanks(DecoderResult *r) const override;
 
   void Decode(ncnn::Mat encoder_out, DecoderResult *result) override;
+  void Decode(ncnn::Mat encoder_out, Stream *s, DecoderResult *result) override;
 
  private:
   ncnn::Mat BuildDecoderInput(const std::vector<Hypothesis> &hyps) const;

+ 97 - 8
sherpa-ncnn/csrc/recognizer.cc

@@ -18,7 +18,8 @@
  */
 
 #include "sherpa-ncnn/csrc/recognizer.h"
-
+#include <iostream>
+#include <fstream>
 #include <memory>
 #include <string>
 #include <utility>
@@ -75,8 +76,12 @@ std::string RecognizerConfig::ToString() const {
   os << "feat_config=" << feat_config.ToString() << ", ";
   os << "model_config=" << model_config.ToString() << ", ";
   os << "decoder_config=" << decoder_config.ToString() << ", ";
+  os << "max_active_paths=" << max_active_paths << ", ";
   os << "endpoint_config=" << endpoint_config.ToString() << ", ";
-  os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ")";
+  os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ", ";
+  os << "hotwords_file=\"" << hotwords_file << "\", ";
+  os << "hotwrods_score=" << hotwords_score << ", ";
+  os << "decoding_method=\"" << decoding_method << "\")";
 
   return os.str();
 }
@@ -93,6 +98,30 @@ class Recognizer::Impl {
     } else if (config.decoder_config.method == "modified_beam_search") {
       decoder_ = std::make_unique<ModifiedBeamSearchDecoder>(
           model_.get(), config.decoder_config.num_active_paths);
+      std::vector<int32_t> tmp;
+      /*each line in hotwords file is a string which is segmented by space*/
+      std::ifstream file(config_.hotwords_file);
+      if (file) {
+        std::string line;
+        std::string word;
+        while (std::getline(file, line)) {
+          std::istringstream iss(line);
+          while(iss >> word){
+            if (sym_.contains(word)) {
+              int number = sym_[word];
+              tmp.push_back(number);
+            } else {
+              NCNN_LOGE("hotword %s can't find id. line: %s", word.c_str(), line.c_str());
+              exit(-1);
+            }
+          }
+          hotwords_.push_back(tmp);
+          tmp.clear();
+        }
+      } else {
+        NCNN_LOGE("open file failed: %s, hotwords will not be used", 
+                 config_.hotwords_file.c_str());
+      }
     } else {
       NCNN_LOGE("Unsupported method: %s", config.decoder_config.method.c_str());
       exit(-1);
@@ -110,6 +139,30 @@ class Recognizer::Impl {
     } else if (config.decoder_config.method == "modified_beam_search") {
       decoder_ = std::make_unique<ModifiedBeamSearchDecoder>(
           model_.get(), config.decoder_config.num_active_paths);
+      std::vector<int32_t> tmp;
+      /*each line in hotwords file is a string which is segmented by space*/
+      std::ifstream file(config_.hotwords_file);
+      if (file) {
+        std::string line;
+        std::string word;
+        while (std::getline(file, line)) {
+          std::istringstream iss(line);
+          while(iss >> word){
+            if (sym_.contains(word)) {
+              int number = sym_[word];
+              tmp.push_back(number);
+            } else {
+              NCNN_LOGE("hotword %s can't find id. line: %s", word.c_str(), line.c_str());
+              exit(-1);
+            }
+          }
+          hotwords_.push_back(tmp);
+          tmp.clear();
+        }
+      } else {
+        NCNN_LOGE("open file failed: %s, hotwords will not be used", 
+                 config_.hotwords_file.c_str());
+      }
     } else {
       NCNN_LOGE("Unsupported method: %s", config.decoder_config.method.c_str());
       exit(-1);
@@ -118,10 +171,29 @@ class Recognizer::Impl {
 #endif
 
   std::unique_ptr<Stream> CreateStream() const {
-    auto stream = std::make_unique<Stream>(config_.feat_config);
-    stream->SetResult(decoder_->GetEmptyResult());
-    stream->SetStates(model_->GetEncoderInitStates());
-    return stream;
+    if(hotwords_.empty()) {
+      auto stream = std::make_unique<Stream>(config_.feat_config);
+      stream->SetResult(decoder_->GetEmptyResult());
+      stream->SetStates(model_->GetEncoderInitStates());
+      return stream;
+    } else {
+      auto r = decoder_->GetEmptyResult();
+      auto context_graph =
+          std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score);
+      auto stream =
+          std::make_unique<Stream>(config_.feat_config, context_graph);
+      if (config_.decoder_config.method == "modified_beam_search" &&
+          nullptr != stream->GetContextGraph()) {
+        std::cout<<"create contexts stream"<<std::endl;
+        // r.hyps has only one element.
+        for (auto it = r.hyps.begin(); it != r.hyps.end(); ++it) {
+          it->second.context_state = stream->GetContextGraph()->Root();
+        }
+      }
+      stream->SetResult(r);
+      stream->SetStates(model_->GetEncoderInitStates());
+      return stream;
+    }
   }
 
   bool IsReady(Stream *s) const {
@@ -131,16 +203,24 @@ class Recognizer::Impl {
   void DecodeStream(Stream *s) const {
     int32_t segment = model_->Segment();
     int32_t offset = model_->Offset();
+    bool has_context_graph = false;
 
+    if (!has_context_graph && s->GetContextGraph()) {
+      has_context_graph = true;
+    }
     ncnn::Mat features = s->GetFrames(s->GetNumProcessedFrames(), segment);
     s->GetNumProcessedFrames() += offset;
     std::vector<ncnn::Mat> states = s->GetStates();
 
     ncnn::Mat encoder_out;
     std::tie(encoder_out, states) = model_->RunEncoder(features, states);
-    s->SetStates(states);
 
-    decoder_->Decode(encoder_out, &s->GetResult());
+    if (has_context_graph) {
+      decoder_->Decode(encoder_out, s, &s->GetResult());
+    } else {
+      decoder_->Decode(encoder_out, &s->GetResult());
+    }
+    s->SetStates(states);
   }
 
   bool IsEndpoint(Stream *s) const {
@@ -158,6 +238,13 @@ class Recognizer::Impl {
   }
 
   void Reset(Stream *s) const {
+    auto r = decoder_->GetEmptyResult();
+    if (config_.decoding_method == "modified_beam_search" &&
+        nullptr != s->GetContextGraph()) {
+      for (auto it = r.hyps.begin(); it != r.hyps.end(); ++it) {
+        it->second.context_state = s->GetContextGraph()->Root();
+      }
+    }
     // Caution: We need to keep the decoder output state
     ncnn::Mat decoder_out = s->GetResult().decoder_out;
     s->SetResult(decoder_->GetEmptyResult());
@@ -190,6 +277,7 @@ class Recognizer::Impl {
   std::unique_ptr<Decoder> decoder_;
   Endpoint endpoint_;
   SymbolTable sym_;
+  std::vector<std::vector<int32_t>> hotwords_;
 };
 
 Recognizer::Recognizer(const RecognizerConfig &config)
@@ -206,6 +294,7 @@ std::unique_ptr<Stream> Recognizer::CreateStream() const {
   return impl_->CreateStream();
 }
 
+
 bool Recognizer::IsReady(Stream *s) const { return impl_->IsReady(s); }
 
 void Recognizer::DecodeStream(Stream *s) const { impl_->DecodeStream(s); }

+ 15 - 4
sherpa-ncnn/csrc/recognizer.h

@@ -48,21 +48,32 @@ struct RecognizerConfig {
   FeatureExtractorConfig feat_config;
   ModelConfig model_config;
   DecoderConfig decoder_config;
-
+  std::string decoding_method;
+  std::string hotwords_file;
   EndpointConfig endpoint_config;
   bool enable_endpoint = false;
-
+  // used only for modified_beam_search
+  int32_t max_active_paths = 4;
+  /// used only for modified_beam_search
+  float hotwords_score = 1.5;
   RecognizerConfig() = default;
 
   RecognizerConfig(const FeatureExtractorConfig &feat_config,
                    const ModelConfig &model_config,
                    const DecoderConfig decoder_config,
-                   const EndpointConfig &endpoint_config, bool enable_endpoint)
+                   const EndpointConfig &endpoint_config, bool enable_endpoint,
+                   const std::string &decoding_method,
+                   const std::string &hotwords_file,
+                   int32_t max_active_paths, float hotwords_score)
       : feat_config(feat_config),
         model_config(model_config),
         decoder_config(decoder_config),
         endpoint_config(endpoint_config),
-        enable_endpoint(enable_endpoint) {}
+        enable_endpoint(enable_endpoint),
+        decoding_method(decoding_method),
+        hotwords_file(hotwords_file),
+        max_active_paths(max_active_paths),
+        hotwords_score(hotwords_score) {}
 
   std::string ToString() const;
 };

+ 14 - 5
sherpa-ncnn/csrc/sherpa-ncnn.cc

@@ -18,7 +18,7 @@
  */
 
 #include <stdio.h>
-
+#include <fstream>
 #include <algorithm>
 #include <chrono>  // NOLINT
 #include <iostream>
@@ -28,7 +28,7 @@
 #include "sherpa-ncnn/csrc/wave-reader.h"
 
 int32_t main(int32_t argc, char *argv[]) {
-  if (argc < 9 || argc > 11) {
+  if (argc < 9 || argc > 13) {
     const char *usage = R"usage(
 Usage:
   ./bin/sherpa-ncnn \
@@ -66,13 +66,23 @@ for a list of pre-trained models to download.
   config.model_config.joiner_opt.num_threads = num_threads;
 
   float expected_sampling_rate = 16000;
-  if (argc == 11) {
+  if (argc >= 11) {
     std::string method = argv[10];
     if (method == "greedy_search" || method == "modified_beam_search") {
       config.decoder_config.method = method;
     }
   }
-
+  std::cout<<"decode method:"<<config.decoder_config.method<<std::endl;
+  if(argc >= 12) {
+	config.hotwords_file = argv[11];
+  } else {
+    config.hotwords_file = "";
+  }
+  if(argc == 13) {
+    config.hotwords_score = atof(argv[12]);
+  } else {
+    config.hotwords_file = 1.5;
+  }
   config.feat_config.sampling_rate = expected_sampling_rate;
   config.feat_config.feature_dim = 80;
 
@@ -96,7 +106,6 @@ for a list of pre-trained models to download.
 
   auto begin = std::chrono::steady_clock::now();
   std::cout << "Started!\n";
-
   auto stream = recognizer.CreateStream();
   stream->AcceptWaveform(expected_sampling_rate, samples.data(),
                          samples.size());

+ 10 - 4
sherpa-ncnn/csrc/stream.cc

@@ -22,8 +22,8 @@ namespace sherpa_ncnn {
 
 class Stream::Impl {
  public:
-  explicit Impl(const FeatureExtractorConfig &config)
-      : feat_extractor_(config) {}
+  explicit Impl(const FeatureExtractorConfig &config,ContextGraphPtr context_graph)
+      : feat_extractor_(config), context_graph_(context_graph) {}
 
   void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
     feat_extractor_.AcceptWaveform(sampling_rate, waveform, n);
@@ -62,16 +62,19 @@ class Stream::Impl {
 
   std::vector<ncnn::Mat> &GetStates() { return states_; }
 
+  const ContextGraphPtr &GetContextGraph() const { return context_graph_; }
+
  private:
   FeatureExtractor feat_extractor_;
+  ContextGraphPtr context_graph_;
   int32_t num_processed_frames_ = 0;  // before subsampling
   int32_t start_frame_index_ = 0;
   DecoderResult result_;
   std::vector<ncnn::Mat> states_;
 };
 
-Stream::Stream(const FeatureExtractorConfig &config)
-    : impl_(std::make_unique<Impl>(config)) {}
+Stream::Stream(const FeatureExtractorConfig &config, ContextGraphPtr context_graph)
+    : impl_(std::make_unique<Impl>(config, context_graph)) {}
 
 Stream::~Stream() = default;
 
@@ -108,4 +111,7 @@ void Stream::SetStates(const std::vector<ncnn::Mat> &states) {
 
 std::vector<ncnn::Mat> &Stream::GetStates() { return impl_->GetStates(); }
 
+const ContextGraphPtr &Stream::GetContextGraph() const {
+  return impl_->GetContextGraph();
+  }
 }  // namespace sherpa_ncnn

+ 9 - 1
sherpa-ncnn/csrc/stream.h

@@ -24,11 +24,13 @@
 
 #include "sherpa-ncnn/csrc/decoder.h"
 #include "sherpa-ncnn/csrc/features.h"
+#include "sherpa-ncnn/csrc/context-graph.h"
 
 namespace sherpa_ncnn {
 class Stream {
  public:
-  explicit Stream(const FeatureExtractorConfig &config);
+  explicit Stream(const FeatureExtractorConfig &config = {},
+				ContextGraphPtr context_graph = nullptr);
   ~Stream();
 
   /**
@@ -80,6 +82,12 @@ class Stream {
 
   void SetStates(const std::vector<ncnn::Mat> &states);
   std::vector<ncnn::Mat> &GetStates();
+  /**
+   * Get the context graph corresponding to this stream.
+   *
+   * @return Return the context graph for this stream.
+   */
+  const ContextGraphPtr &GetContextGraph() const;
 
  private:
   class Impl;