Parcourir la source

Add hotwords support to Swift/Go/Python/C#/Kotlin APIs (#260)

Fangjun Kuang il y a 1 an
Parent
commit
401de81194

+ 1 - 1
CMakeLists.txt

@@ -1,7 +1,7 @@
 cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
 project(sherpa-ncnn)
 
-set(SHERPA_NCNN_VERSION "2.0.7")
+set(SHERPA_NCNN_VERSION "2.1.0")
 
 # Disable warning about
 #

+ 2 - 0
android/SherpaNcnn/app/src/main/java/com/k2fsa/sherpa/ncnn/SherpaNcnn.kt

@@ -33,6 +33,8 @@ data class RecognizerConfig(
     var rule1MinTrailingSilence: Float = 2.4f,
     var rule2MinTrailingSilence: Float = 1.0f,
     var rule3MinUtteranceLength: Float = 30.0f,
+    var hotwordsFile: String = "",
+    var hotwordsScore: Float = 1.5f,
 )
 
 class SherpaNcnn(

+ 8 - 6
c-api-examples/decode-file-c-api.c

@@ -45,6 +45,8 @@ int32_t main(int32_t argc, char *argv[]) {
     return -1;
   }
   SherpaNcnnRecognizerConfig config;
+  memset(&config, 0, sizeof(config));
+
   config.model_config.tokens = argv[1];
   config.model_config.encoder_param = argv[2];
   config.model_config.encoder_bin = argv[3];
@@ -57,6 +59,7 @@ int32_t main(int32_t argc, char *argv[]) {
   if (argc >= 10 && atoi(argv[9]) > 0) {
     num_threads = atoi(argv[9]);
   }
+
   config.model_config.num_threads = num_threads;
   config.model_config.use_vulkan_compute = 0;
 
@@ -65,6 +68,7 @@ int32_t main(int32_t argc, char *argv[]) {
   if (argc >= 11) {
     config.decoder_config.decoding_method = argv[10];
   }
+
   config.decoder_config.num_active_paths = 4;
   config.enable_endpoint = 0;
   config.rule1_min_trailing_silence = 2.4;
@@ -73,16 +77,14 @@ int32_t main(int32_t argc, char *argv[]) {
 
   config.feat_config.sampling_rate = 16000;
   config.feat_config.feature_dim = 80;
-  if(argc >= 12) {
+  if (argc >= 12) {
     config.hotwords_file = argv[11];
-  } else {
-    config.hotwords_file = "";
   }
-  if(argc == 13) {
+
+  if (argc == 13) {
     config.hotwords_score = atof(argv[12]);
-  } else {
-    config.hotwords_score = 1.5;
   }
+
   SherpaNcnnRecognizer *recognizer = CreateRecognizer(&config);
 
   const char *wav_filename = argv[8];

+ 5 - 0
scripts/dotnet/sherpa-ncnn.cs

@@ -63,6 +63,11 @@ namespace SherpaNcnn
         public float Rule1MinTrailingSilence;
         public float Rule2MinTrailingSilence;
         public float Rule3MinUtteranceLength;
+
+        [MarshalAs(UnmanagedType.LPStr)]
+        public string HotwordsFile;
+
+        public float HotwordsScore;
     }
 
     // please see

+ 8 - 0
scripts/go/sherpa_ncnn.go

@@ -84,6 +84,9 @@ type RecognizerConfig struct {
 	Rule1MinTrailingSilence float32
 	Rule2MinTrailingSilence float32
 	Rule3MinUtteranceLength float32
+
+	HotwordsFile  string
+	HotwordsScore float32
 }
 
 // It contains the recognition result for a online stream.
@@ -148,6 +151,11 @@ func NewRecognizer(config *RecognizerConfig) *Recognizer {
 	c.rule2_min_trailing_silence = C.float(config.Rule2MinTrailingSilence)
 	c.rule3_min_utterance_length = C.float(config.Rule3MinUtteranceLength)
 
+	c.hotwords_file = C.CString(config.HotwordsFile)
+	defer C.free(unsafe.Pointer(c.hotwords_file))
+
+	c.hotwords_score = C.float(config.HotwordsScore)
+
 	recognizer := &Recognizer{}
 	recognizer.impl = C.CreateRecognizer(&c)
 

+ 14 - 5
sherpa-ncnn/c-api/c-api.cc

@@ -39,6 +39,8 @@ struct SherpaNcnnDisplay {
   std::unique_ptr<sherpa_ncnn::Display> impl;
 };
 
+#define SHERPA_NCNN_OR(x, y) (x ? x : y)
+
 SherpaNcnnRecognizer *CreateRecognizer(
     const SherpaNcnnRecognizerConfig *in_config) {
   // model_config
@@ -56,7 +58,7 @@ SherpaNcnnRecognizer *CreateRecognizer(
   config.model_config.use_vulkan_compute =
       in_config->model_config.use_vulkan_compute;
 
-  int32_t num_threads = in_config->model_config.num_threads;
+  int32_t num_threads = SHERPA_NCNN_OR(in_config->model_config.num_threads, 1);
 
   config.model_config.encoder_opt.num_threads = num_threads;
   config.model_config.decoder_opt.num_threads = num_threads;
@@ -66,8 +68,9 @@ SherpaNcnnRecognizer *CreateRecognizer(
   config.decoder_config.method = in_config->decoder_config.decoding_method;
   config.decoder_config.num_active_paths =
       in_config->decoder_config.num_active_paths;
-  config.hotwords_file = in_config->hotwords_file;
-  config.hotwords_score = in_config->hotwords_score;
+
+  config.hotwords_file = SHERPA_NCNN_OR(in_config->hotwords_file, "");
+  config.hotwords_score = SHERPA_NCNN_OR(in_config->hotwords_score, 1.5);
 
   config.enable_endpoint = in_config->enable_endpoint;
 
@@ -80,11 +83,17 @@ SherpaNcnnRecognizer *CreateRecognizer(
   config.endpoint_config.rule3.min_utterance_length =
       in_config->rule3_min_utterance_length;
 
-  config.feat_config.sampling_rate = in_config->feat_config.sampling_rate;
-  config.feat_config.feature_dim = in_config->feat_config.feature_dim;
+  config.feat_config.sampling_rate =
+      SHERPA_NCNN_OR(in_config->feat_config.sampling_rate, 16000);
+
+  config.feat_config.feature_dim =
+      SHERPA_NCNN_OR(in_config->feat_config.feature_dim, 80);
 
   auto recognizer = std::make_unique<sherpa_ncnn::Recognizer>(config);
+
   if (!recognizer->GetModel()) {
+    NCNN_LOGE("Failed to create the recognizer! Please check your config: %s",
+              config.ToString().c_str());
     return nullptr;
   }
 

+ 4 - 4
sherpa-ncnn/c-api/c-api.h

@@ -133,10 +133,10 @@ SHERPA_NCNN_API typedef struct SherpaNcnnRecognizerConfig {
   /// this value.
   /// Used only when enable_endpoint is not 0.
   float rule3_min_utterance_length;
-  
-  /// hotwords file, each line is a hotword which is segmented into char by space
-  /// if language is something like CJK, segment manually,
-  /// if language is something like English, segment by bpe model.
+
+  /// hotwords file, each line is a hotword which is segmented into char by
+  /// space if language is something like CJK, segment manually, if language is
+  /// something like English, segment by bpe model.
   const char *hotwords_file;
 
   /// scale of hotwords, used only when hotwords_file is not empty

+ 7 - 1
sherpa-ncnn/csrc/decoder.h

@@ -59,7 +59,9 @@ struct DecoderResult {
   // used only for modified_beam_search
   Hypotheses hyps;
 };
+
 class Stream;
+
 class Decoder {
  public:
   virtual ~Decoder() = default;
@@ -88,7 +90,11 @@ class Decoder {
    * and there are no paddings.
    */
   virtual void Decode(ncnn::Mat encoder_out, DecoderResult *result) = 0;
-  virtual void Decode(ncnn::Mat encoder_out, Stream *s, DecoderResult *result){};
+
+  virtual void Decode(ncnn::Mat encoder_out, Stream *s, DecoderResult *result) {
+    NCNN_LOGE("Please override it!");
+    exit(-1);
+  }
 };
 
 }  // namespace sherpa_ncnn

+ 2 - 1
sherpa-ncnn/csrc/hypothesis.h

@@ -24,6 +24,7 @@
 #include <unordered_map>
 #include <utility>
 #include <vector>
+
 #include "sherpa-ncnn/csrc/context-graph.h"
 
 namespace sherpa_ncnn {
@@ -43,7 +44,7 @@ struct Hypothesis {
 
   Hypothesis() = default;
   Hypothesis(const std::vector<int32_t> &ys, double log_prob,
-            const ContextState *context_state = nullptr)
+             const ContextState *context_state = nullptr)
       : ys(ys), log_prob(log_prob), context_state(context_state) {}
 
   // If two Hypotheses have the same `Key`, then they contain

+ 3 - 6
sherpa-ncnn/csrc/modified-beam-search-decoder.cc

@@ -195,7 +195,6 @@ void ModifiedBeamSearchDecoder::Decode(ncnn::Mat encoder_out,
   result->num_trailing_blanks = hyp.num_trailing_blanks;
 }
 
-
 void ModifiedBeamSearchDecoder::Decode(ncnn::Mat encoder_out, Stream *s,
                                        DecoderResult *result) {
   int32_t context_size = model_->ContextSize();
@@ -205,7 +204,6 @@ void ModifiedBeamSearchDecoder::Decode(ncnn::Mat encoder_out, Stream *s,
     std::vector<Hypothesis> prev = cur.GetTopK(num_active_paths_, true);
     cur.Clear();
 
-
     ncnn::Mat decoder_input = BuildDecoderInput(prev);
     ncnn::Mat decoder_out;
     if (t == 0 && prev.size() == 1 && prev[0].ys.size() == context_size &&
@@ -218,14 +216,13 @@ void ModifiedBeamSearchDecoder::Decode(ncnn::Mat encoder_out, Stream *s,
 
     // decoder_out.w == decoder_dim
     // decoder_out.h == num_active_paths
-	ncnn::Mat encoder_out_t(encoder_out.w, 1, encoder_out.row(t));
+    ncnn::Mat encoder_out_t(encoder_out.w, 1, encoder_out.row(t));
 
     ncnn::Mat joiner_out = model_->RunJoiner(encoder_out_t, decoder_out);
     // joiner_out.w == vocab_size
     // joiner_out.h == num_active_paths
     LogSoftmax(&joiner_out);
 
-
     float *p_joiner_out = joiner_out;
 
     for (int32_t i = 0; i != joiner_out.h; ++i) {
@@ -255,8 +252,8 @@ void ModifiedBeamSearchDecoder::Decode(ncnn::Mat encoder_out, Stream *s,
         new_hyp.num_trailing_blanks = 0;
         new_hyp.timestamps.push_back(t + frame_offset);
         if (s != nullptr && s->GetContextGraph() != nullptr) {
-          auto context_res = s->GetContextGraph()->ForwardOneStep(
-              context_state, new_token);
+          auto context_res =
+              s->GetContextGraph()->ForwardOneStep(context_state, new_token);
           context_score = context_res.first;
           new_hyp.context_state = context_res.second;
         }

+ 79 - 57
sherpa-ncnn/csrc/recognizer.cc

@@ -18,7 +18,7 @@
  */
 
 #include "sherpa-ncnn/csrc/recognizer.h"
-#include <iostream>
+
 #include <fstream>
 #include <memory>
 #include <string>
@@ -29,6 +29,14 @@
 #include "sherpa-ncnn/csrc/greedy-search-decoder.h"
 #include "sherpa-ncnn/csrc/modified-beam-search-decoder.h"
 
+#if __ANDROID_API__ >= 9
+#include <strstream>
+
+#include "android/asset_manager.h"
+#include "android/asset_manager_jni.h"
+#include "android/log.h"
+#endif
+
 namespace sherpa_ncnn {
 
 static RecognitionResult Convert(const DecoderResult &src,
@@ -76,12 +84,10 @@ std::string RecognizerConfig::ToString() const {
   os << "feat_config=" << feat_config.ToString() << ", ";
   os << "model_config=" << model_config.ToString() << ", ";
   os << "decoder_config=" << decoder_config.ToString() << ", ";
-  os << "max_active_paths=" << max_active_paths << ", ";
   os << "endpoint_config=" << endpoint_config.ToString() << ", ";
   os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ", ";
   os << "hotwords_file=\"" << hotwords_file << "\", ";
-  os << "hotwrods_score=" << hotwords_score << ", ";
-  os << "decoding_method=\"" << decoding_method << "\")";
+  os << "hotwords_score=" << hotwords_score << ")";
 
   return os.str();
 }
@@ -98,29 +104,9 @@ class Recognizer::Impl {
     } else if (config.decoder_config.method == "modified_beam_search") {
       decoder_ = std::make_unique<ModifiedBeamSearchDecoder>(
           model_.get(), config.decoder_config.num_active_paths);
-      std::vector<int32_t> tmp;
-      /*each line in hotwords file is a string which is segmented by space*/
-      std::ifstream file(config_.hotwords_file);
-      if (file) {
-        std::string line;
-        std::string word;
-        while (std::getline(file, line)) {
-          std::istringstream iss(line);
-          while(iss >> word){
-            if (sym_.contains(word)) {
-              int number = sym_[word];
-              tmp.push_back(number);
-            } else {
-              NCNN_LOGE("hotword %s can't find id. line: %s", word.c_str(), line.c_str());
-              exit(-1);
-            }
-          }
-          hotwords_.push_back(tmp);
-          tmp.clear();
-        }
-      } else {
-        NCNN_LOGE("open file failed: %s, hotwords will not be used", 
-                 config_.hotwords_file.c_str());
+
+      if (!config_.hotwords_file.empty()) {
+        InitHotwords();
       }
     } else {
       NCNN_LOGE("Unsupported method: %s", config.decoder_config.method.c_str());
@@ -139,29 +125,9 @@ class Recognizer::Impl {
     } else if (config.decoder_config.method == "modified_beam_search") {
       decoder_ = std::make_unique<ModifiedBeamSearchDecoder>(
           model_.get(), config.decoder_config.num_active_paths);
-      std::vector<int32_t> tmp;
-      /*each line in hotwords file is a string which is segmented by space*/
-      std::ifstream file(config_.hotwords_file);
-      if (file) {
-        std::string line;
-        std::string word;
-        while (std::getline(file, line)) {
-          std::istringstream iss(line);
-          while(iss >> word){
-            if (sym_.contains(word)) {
-              int number = sym_[word];
-              tmp.push_back(number);
-            } else {
-              NCNN_LOGE("hotword %s can't find id. line: %s", word.c_str(), line.c_str());
-              exit(-1);
-            }
-          }
-          hotwords_.push_back(tmp);
-          tmp.clear();
-        }
-      } else {
-        NCNN_LOGE("open file failed: %s, hotwords will not be used", 
-                 config_.hotwords_file.c_str());
+
+      if (!config_.hotwords_file.empty()) {
+        InitHotwords(mgr);
       }
     } else {
       NCNN_LOGE("Unsupported method: %s", config.decoder_config.method.c_str());
@@ -171,27 +137,30 @@ class Recognizer::Impl {
 #endif
 
   std::unique_ptr<Stream> CreateStream() const {
-    if(hotwords_.empty()) {
+    if (hotwords_.empty()) {
       auto stream = std::make_unique<Stream>(config_.feat_config);
       stream->SetResult(decoder_->GetEmptyResult());
       stream->SetStates(model_->GetEncoderInitStates());
       return stream;
     } else {
       auto r = decoder_->GetEmptyResult();
+
       auto context_graph =
           std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score);
+
       auto stream =
           std::make_unique<Stream>(config_.feat_config, context_graph);
-      if (config_.decoder_config.method == "modified_beam_search" &&
-          nullptr != stream->GetContextGraph()) {
-        std::cout<<"create contexts stream"<<std::endl;
+
+      if (stream->GetContextGraph()) {
         // r.hyps has only one element.
         for (auto it = r.hyps.begin(); it != r.hyps.end(); ++it) {
           it->second.context_state = stream->GetContextGraph()->Root();
         }
       }
+
       stream->SetResult(r);
       stream->SetStates(model_->GetEncoderInitStates());
+
       return stream;
     }
   }
@@ -239,8 +208,8 @@ class Recognizer::Impl {
 
   void Reset(Stream *s) const {
     auto r = decoder_->GetEmptyResult();
-    if (config_.decoding_method == "modified_beam_search" &&
-        nullptr != s->GetContextGraph()) {
+
+    if (s->GetContextGraph()) {
       for (auto it = r.hyps.begin(); it != r.hyps.end(); ++it) {
         it->second.context_state = s->GetContextGraph()->Root();
       }
@@ -271,6 +240,60 @@ class Recognizer::Impl {
 
   const Model *GetModel() const { return model_.get(); }
 
+ private:
+#if __ANDROID_API__ >= 9
+  void InitHotwords(AAssetManager *mgr) {
+    AAsset *asset = AAssetManager_open(mgr, config_.hotwords_file.c_str(),
+                                       AASSET_MODE_BUFFER);
+    if (!asset) {
+      __android_log_print(ANDROID_LOG_FATAL, "sherpa-ncnn",
+                          "hotwords_file: Load %s failed",
+                          config_.hotwords_file.c_str());
+      exit(-1);
+    }
+
+    auto p = reinterpret_cast<const char *>(AAsset_getBuffer(asset));
+    size_t asset_length = AAsset_getLength(asset);
+    std::istrstream is(p, asset_length);
+    InitHotwords(is);
+    AAsset_close(asset);
+  }
+#endif
+
+  void InitHotwords() {
+    // each line in hotwords_file contains space-separated words
+
+    std::ifstream is(config_.hotwords_file);
+    if (!is) {
+      NCNN_LOGE("Open hotwords file failed: %s", config_.hotwords_file.c_str());
+      exit(-1);
+    }
+
+    InitHotwords(is);
+  }
+
+  void InitHotwords(std::istream &is) {
+    std::vector<int32_t> tmp;
+    std::string line;
+    std::string word;
+
+    while (std::getline(is, line)) {
+      std::istringstream iss(line);
+      while (iss >> word) {
+        if (sym_.contains(word)) {
+          int32_t number = sym_[word];
+          tmp.push_back(number);
+        } else {
+          NCNN_LOGE("Cannot find ID for hotword %s at line: %s", word.c_str(),
+                    line.c_str());
+          exit(-1);
+        }
+      }
+
+      hotwords_.push_back(std::move(tmp));
+    }
+  }
+
  private:
   RecognizerConfig config_;
   std::unique_ptr<Model> model_;
@@ -294,7 +317,6 @@ std::unique_ptr<Stream> Recognizer::CreateStream() const {
   return impl_->CreateStream();
 }
 
-
 bool Recognizer::IsReady(Stream *s) const { return impl_->IsReady(s); }
 
 void Recognizer::DecodeStream(Stream *s) const { impl_->DecodeStream(s); }

+ 10 - 9
sherpa-ncnn/csrc/recognizer.h

@@ -31,6 +31,11 @@
 #include "sherpa-ncnn/csrc/stream.h"
 #include "sherpa-ncnn/csrc/symbol-table.h"
 
+#if __ANDROID_API__ >= 9
+#include "android/asset_manager.h"
+#include "android/asset_manager_jni.h"
+#endif
+
 namespace sherpa_ncnn {
 
 struct RecognitionResult {
@@ -48,31 +53,27 @@ struct RecognizerConfig {
   FeatureExtractorConfig feat_config;
   ModelConfig model_config;
   DecoderConfig decoder_config;
-  std::string decoding_method;
-  std::string hotwords_file;
   EndpointConfig endpoint_config;
   bool enable_endpoint = false;
-  // used only for modified_beam_search
-  int32_t max_active_paths = 4;
+
+  std::string hotwords_file;
+
   /// used only for modified_beam_search
   float hotwords_score = 1.5;
+
   RecognizerConfig() = default;
 
   RecognizerConfig(const FeatureExtractorConfig &feat_config,
                    const ModelConfig &model_config,
                    const DecoderConfig decoder_config,
                    const EndpointConfig &endpoint_config, bool enable_endpoint,
-                   const std::string &decoding_method,
-                   const std::string &hotwords_file,
-                   int32_t max_active_paths, float hotwords_score)
+                   const std::string &hotwords_file, float hotwords_score)
       : feat_config(feat_config),
         model_config(model_config),
         decoder_config(decoder_config),
         endpoint_config(endpoint_config),
         enable_endpoint(enable_endpoint),
-        decoding_method(decoding_method),
         hotwords_file(hotwords_file),
-        max_active_paths(max_active_paths),
         hotwords_score(hotwords_score) {}
 
   std::string ToString() const;

+ 10 - 11
sherpa-ncnn/csrc/sherpa-ncnn.cc

@@ -18,9 +18,10 @@
  */
 
 #include <stdio.h>
-#include <fstream>
+
 #include <algorithm>
 #include <chrono>  // NOLINT
+#include <fstream>
 #include <iostream>
 
 #include "net.h"  // NOLINT
@@ -72,26 +73,24 @@ for a list of pre-trained models to download.
       config.decoder_config.method = method;
     }
   }
-  std::cout<<"decode method:"<<config.decoder_config.method<<std::endl;
-  if(argc >= 12) {
-	config.hotwords_file = argv[11];
-  } else {
-    config.hotwords_file = "";
+
+  if (argc >= 12) {
+    config.hotwords_file = argv[11];
   }
-  if(argc == 13) {
+
+  if (argc == 13) {
     config.hotwords_score = atof(argv[12]);
-  } else {
-    config.hotwords_file = 1.5;
   }
+
   config.feat_config.sampling_rate = expected_sampling_rate;
   config.feat_config.feature_dim = 80;
 
+  std::cout << config.ToString() << "\n";
+
   sherpa_ncnn::Recognizer recognizer(config);
 
   std::string wav_filename = argv[8];
 
-  std::cout << config.ToString() << "\n";
-
   bool is_ok = false;
   std::vector<float> samples =
       sherpa_ncnn::ReadWave(wav_filename, expected_sampling_rate, &is_ok);

+ 5 - 3
sherpa-ncnn/csrc/stream.cc

@@ -22,7 +22,8 @@ namespace sherpa_ncnn {
 
 class Stream::Impl {
  public:
-  explicit Impl(const FeatureExtractorConfig &config,ContextGraphPtr context_graph)
+  explicit Impl(const FeatureExtractorConfig &config,
+                ContextGraphPtr context_graph)
       : feat_extractor_(config), context_graph_(context_graph) {}
 
   void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
@@ -73,7 +74,8 @@ class Stream::Impl {
   std::vector<ncnn::Mat> states_;
 };
 
-Stream::Stream(const FeatureExtractorConfig &config, ContextGraphPtr context_graph)
+Stream::Stream(const FeatureExtractorConfig &config,
+               ContextGraphPtr context_graph)
     : impl_(std::make_unique<Impl>(config, context_graph)) {}
 
 Stream::~Stream() = default;
@@ -113,5 +115,5 @@ std::vector<ncnn::Mat> &Stream::GetStates() { return impl_->GetStates(); }
 
 const ContextGraphPtr &Stream::GetContextGraph() const {
   return impl_->GetContextGraph();
-  }
+}
 }  // namespace sherpa_ncnn

+ 2 - 2
sherpa-ncnn/csrc/stream.h

@@ -22,15 +22,15 @@
 #include <memory>
 #include <vector>
 
+#include "sherpa-ncnn/csrc/context-graph.h"
 #include "sherpa-ncnn/csrc/decoder.h"
 #include "sherpa-ncnn/csrc/features.h"
-#include "sherpa-ncnn/csrc/context-graph.h"
 
 namespace sherpa_ncnn {
 class Stream {
  public:
   explicit Stream(const FeatureExtractorConfig &config = {},
-				ContextGraphPtr context_graph = nullptr);
+                  ContextGraphPtr context_graph = nullptr);
   ~Stream();
 
   /**

+ 2 - 2
sherpa-ncnn/csrc/symbol-table.cc

@@ -23,12 +23,12 @@
 #include <fstream>
 #include <sstream>
 
-
 #if __ANDROID_API__ >= 9
+#include <strstream>
+
 #include "android/asset_manager.h"
 #include "android/asset_manager_jni.h"
 #include "android/log.h"
-#include <strstream>
 #endif
 
 namespace sherpa_ncnn {

+ 9 - 0
sherpa-ncnn/jni/jni.cc

@@ -259,6 +259,15 @@ static RecognizerConfig ParseConfig(JNIEnv *env, jobject _config) {
   config.endpoint_config.rule3.min_utterance_length =
       env->GetFloatField(_config, fid);
 
+  fid = env->GetFieldID(cls, "hotwordsFile", "Ljava/lang/String;");
+  jstring s = (jstring)env->GetObjectField(_config, fid);
+  const char *p = env->GetStringUTFChars(s, nullptr);
+  config.hotwords_file = p;
+  env->ReleaseStringUTFChars(s, p);
+
+  fid = env->GetFieldID(cls, "hotwordsScore", "F");
+  config.hotwords_score = env->GetFloatField(_config, fid);
+
   NCNN_LOGE("------config------\n%s\n", config.ToString().c_str());
 
   return config;

+ 7 - 3
sherpa-ncnn/python/csrc/recognizer.cc

@@ -63,16 +63,20 @@ static void PybindRecognizerConfig(py::module *m) {
   using PyClass = RecognizerConfig;
   py::class_<PyClass>(*m, "RecognizerConfig")
       .def(py::init<const FeatureExtractorConfig &, const ModelConfig &,
-                    const DecoderConfig &, const EndpointConfig &, bool>(),
+                    const DecoderConfig &, const EndpointConfig &, bool,
+                    const std::string &, float>(),
            py::arg("feat_config"), py::arg("model_config"),
            py::arg("decoder_config"), py::arg("endpoint_config"),
-           py::arg("enable_endpoint"), kRecognizerConfigInitDoc)
+           py::arg("enable_endpoint"), py::arg("hotwords_file") = "",
+           py::arg("hotwords_score") = 1.5, kRecognizerConfigInitDoc)
       .def("__str__", &PyClass::ToString)
       .def_readwrite("feat_config", &PyClass::feat_config)
       .def_readwrite("model_config", &PyClass::model_config)
       .def_readwrite("decoder_config", &PyClass::decoder_config)
       .def_readwrite("endpoint_config", &PyClass::endpoint_config)
-      .def_readwrite("enable_endpoint", &PyClass::enable_endpoint);
+      .def_readwrite("enable_endpoint", &PyClass::enable_endpoint)
+      .def_readwrite("hotwords_file", &PyClass::hotwords_file)
+      .def_readwrite("hotwords_score", &PyClass::hotwords_score);
 }
 
 void PybindRecognizer(py::module *m) {

+ 12 - 0
sherpa-ncnn/python/sherpa_ncnn/recognizer.py

@@ -91,6 +91,8 @@ class Recognizer(object):
         rule2_min_trailing_silence: int = 1.2,
         rule3_min_utterance_length: int = 20,
         model_sample_rate: int = 16000,
+        hotwords_file: str = "",
+        hotwords_score: float = 1.5,
     ):
         """
         Please refer to
@@ -143,6 +145,14 @@ class Recognizer(object):
             is detected.
           model_sample_rate:
             Sample rate expected by the model
+          hotwords_file:
+            Optional. If not empty, it specifies the hotwords file.
+            Each line in the hotwords file is a hotword. A hotword
+            consists of words separated by spaces.
+            Used only when decoding_method is modified_beam_search.
+          hotwords_score:
+            The scale applied to hotwords score. Used only
+            when hotwords_file is not empty.
         """
         _assert_file_exists(tokens)
         _assert_file_exists(encoder_param)
@@ -190,6 +200,8 @@ class Recognizer(object):
             decoder_config=decoder_config,
             endpoint_config=endpoint_config,
             enable_endpoint=enable_endpoint_detection,
+            hotwords_file=hotwords_file,
+            hotwords_score=hotwords_score,
         )
 
         self.sample_rate = self.config.feat_config.sampling_rate

+ 6 - 2
swift-api-examples/SherpaNcnn.swift

@@ -118,7 +118,9 @@ func sherpaNcnnRecognizerConfig(
     enableEndpoint: Bool = false,
     rule1MinTrailingSilence: Float = 2.4,
     rule2MinTrailingSilence: Float = 1.2,
-    rule3MinUtteranceLength: Float = 30
+    rule3MinUtteranceLength: Float = 30,
+    hotwordsFile: String = "",
+    hotwordsScore: Float = 1.5
 ) -> SherpaNcnnRecognizerConfig {
     return SherpaNcnnRecognizerConfig(
         feat_config: featConfig,
@@ -127,7 +129,9 @@ func sherpaNcnnRecognizerConfig(
         enable_endpoint: enableEndpoint ? 1 : 0,
         rule1_min_trailing_silence: rule1MinTrailingSilence,
         rule2_min_trailing_silence: rule2MinTrailingSilence,
-        rule3_min_utterance_length: rule3MinUtteranceLength)
+        rule3_min_utterance_length: rule3MinUtteranceLength,
+        hotwords_file: toCPointer(hotwordsFile),
+        hotwords_score: hotwordsScore)
 }
 
 /// Wrapper for recognition result.