View source

add endpointing for android (#55)

Fangjun Kuang, 2 years ago
parent
commit
4766a6a316

+ 36 - 3
android/SherpaNcnn/app/src/main/java/com/k2fsa/sherpa/ncnn/MainActivity.kt

@@ -42,6 +42,8 @@ class MainActivity : AppCompatActivity() {
     @Volatile
     private var isRecording: Boolean = false
 
+    private var results: MutableList<String> = ArrayList()
+
     override fun onRequestPermissionsResult(
         requestCode: Int,
         permissions: Array<String>,
@@ -91,6 +93,9 @@ class MainActivity : AppCompatActivity() {
             recordButton.setText(R.string.stop)
             isRecording = true
             model.reset()
+            results = ArrayList()
+            textView.text = ""
+
             recordingThread = thread(true) {
                 processSamples()
             }
@@ -101,7 +106,7 @@ class MainActivity : AppCompatActivity() {
             audioRecord!!.release()
             audioRecord = null
             recordButton.setText(R.string.start)
-            textView.text = model.text
+            textView.text = joinText()
             Log.i(TAG, "Stopped recording")
         }
     }
@@ -118,7 +123,22 @@ class MainActivity : AppCompatActivity() {
             if (ret != null && ret > 0) {
                 val samples = FloatArray(ret) { buffer[it] / 32768.0f }
                 model.decodeSamples(samples)
-                runOnUiThread { textView.text = model.text }
+                runOnUiThread {
+                    val isEndpoint = model.isEndpoint()
+                    val text = model.text
+
+                    if (text.isNotBlank()) {
+                        if (isEndpoint) {
+                            results[results.size - 1] = text
+                            results.add("")
+                        } else {
+                            if (results.isEmpty()) results.add("")
+                            results[results.size - 1] = text
+                        }
+                    }
+
+                    textView.text = joinText()
+                }
             }
         }
     }
@@ -153,7 +173,20 @@ class MainActivity : AppCompatActivity() {
         model = SherpaNcnn(
             assetManager = application.assets,
             modelConfig = getModelConfig(type = 1, useGPU = useGPU)!!,
+            decoderConfig=getDecoderConfig(useEndpoint = true),
             fbankConfig = getFbankConfig(),
         )
     }
-}
+
+    private fun joinText(): String {
+        var r = ""
+        var sep = ""
+        results.forEachIndexed { i, s ->
+            if (s.isNotBlank()) {
+                r = r.plus("${sep}${i}: ${s}")
+                sep = "\n"
+            }
+        }
+        return r
+    }
+}

+ 42 - 5
android/SherpaNcnn/app/src/main/java/com/k2fsa/sherpa/ncnn/SherpaNcnn.kt

@@ -2,6 +2,25 @@ package com.k2fsa.sherpa.ncnn
 
 import android.content.res.AssetManager
 
+data class EndpointRule(
+    var mustContainNonSilence: Boolean,
+    var minTrailingSilence: Float,
+    var minUtteranceLength: Float,
+)
+
+data class EndpointConfig(
+    var rule1: EndpointRule = EndpointRule(false, 2.4f, 0.0f),
+    var rule2: EndpointRule = EndpointRule(true, 1.4f, 0.0f),
+    var rule3: EndpointRule = EndpointRule(false, 0.0f, 20.0f)
+)
+
+data class DecoderConfig(
+    var method: String = "modified_beam_search", // valid values: greedy_search, modified_beam_search
+    var numActivePaths: Int = 4, // used only by modified_beam_search
+    var useEndpoint: Boolean = true,
+    var endpointConfig: EndpointConfig = EndpointConfig(),
+)
+
 data class FrameExtractionOptions(
     var sampFreq: Float = 16000.0f,
     var frameShiftMs: Float = 10.0f,
@@ -52,12 +71,13 @@ data class ModelConfig(
 class SherpaNcnn(
     assetManager: AssetManager,
     modelConfig: ModelConfig,
+    decoderConfig: DecoderConfig,
     var fbankConfig: FbankOptions,
 ) {
     private val ptr: Long
 
     init {
-        ptr = new(assetManager, modelConfig, fbankConfig)
+        ptr = new(assetManager, modelConfig, decoderConfig, fbankConfig)
     }
 
     protected fun finalize() {
@@ -69,6 +89,7 @@ class SherpaNcnn(
 
     fun inputFinished() = inputFinished(ptr)
     fun reset() = reset(ptr)
+    fun isEndpoint(): Boolean = isEndpoint(ptr)
 
     val text: String
         get() = getText(ptr)
@@ -76,6 +97,7 @@ class SherpaNcnn(
     private external fun new(
         assetManager: AssetManager,
         modelConfig: ModelConfig,
+        decoderConfig: DecoderConfig,
         fbankConfig: FbankOptions
     ): Long
 
@@ -84,6 +106,7 @@ class SherpaNcnn(
     private external fun inputFinished(ptr: Long)
     private external fun getText(ptr: Long): String
     private external fun reset(ptr: Long)
+    private external fun isEndpoint(ptr: Long): Boolean
 
     companion object {
         init {
@@ -116,12 +139,12 @@ fun getModelConfig(type: Int, useGPU: Boolean): ModelConfig? {
         1 -> {
             val modelDir = "sherpa-ncnn-conv-emformer-transducer-2022-12-06"
             return ModelConfig(
-                encoderParam = "$modelDir/encoder_jit_trace-pnnx.ncnn.param",
-                encoderBin = "$modelDir/encoder_jit_trace-pnnx.ncnn.bin",
+                encoderParam = "$modelDir/encoder_jit_trace-pnnx.ncnn.int8.param",
+                encoderBin = "$modelDir/encoder_jit_trace-pnnx.ncnn.int8.bin",
                 decoderParam = "$modelDir/decoder_jit_trace-pnnx.ncnn.param",
                 decoderBin = "$modelDir/decoder_jit_trace-pnnx.ncnn.bin",
-                joinerParam = "$modelDir/joiner_jit_trace-pnnx.ncnn.param",
-                joinerBin = "$modelDir/joiner_jit_trace-pnnx.ncnn.bin",
+                joinerParam = "$modelDir/joiner_jit_trace-pnnx.ncnn.int8.param",
+                joinerBin = "$modelDir/joiner_jit_trace-pnnx.ncnn.int8.bin",
                 tokens = "$modelDir/tokens.txt",
                 numThreads = 4,
                 useGPU = useGPU,
@@ -144,4 +167,18 @@ fun getModelConfig(type: Int, useGPU: Boolean): ModelConfig? {
         }
     }
     return null
+}
+
+fun getDecoderConfig(useEndpoint: Boolean): DecoderConfig {
+    return DecoderConfig(
+        method = "modified_beam_search",
+        numActivePaths = 4,
+        useEndpoint = useEndpoint,
+        endpointConfig = EndpointConfig(
+            rule1 = EndpointRule(false, 2.4f, 0.0f),
+            rule2 = EndpointRule(true, 1.4f, 0.0f),
+            rule3 = EndpointRule(false, 0.0f, 20.0f)
+        )
+    )
+
 }

+ 24 - 0
sherpa-ncnn/csrc/endpoint.cc

@@ -17,6 +17,7 @@
  */
 #include "sherpa-ncnn/csrc/endpoint.h"
 
+#include <sstream>
 #include <string>
 
 namespace sherpa_ncnn {
@@ -32,6 +33,29 @@ static bool RuleActivated(const EndpointRule &rule,
   return ans;
 }
 
+std::string EndpointRule::ToString() const {
+  std::ostringstream os;
+
+  os << "EndpointRule(";
+  os << "must_contain_nonsilence="
+     << (must_contain_nonsilence ? "True" : "False") << ", ";
+  os << "min_trailing_silence=" << min_trailing_silence << ", ";
+  os << "min_utterance_length=" << min_utterance_length << ")";
+
+  return os.str();
+}
+
+std::string EndpointConfig::ToString() const {
+  std::ostringstream os;
+
+  os << "EndpointConfig(";
+  os << "rule1=" << rule1.ToString() << ", ";
+  os << "rule2=" << rule2.ToString() << ", ";
+  os << "rule3=" << rule3.ToString() << ")";
+
+  return os.str();
+}
+
 bool Endpoint::IsEndpoint(const int num_frames_decoded,
                           const int trailing_silence_frames,
                           const float frame_shift_in_seconds) const {

+ 4 - 0
sherpa-ncnn/csrc/endpoint.h

@@ -40,6 +40,8 @@ struct EndpointRule {
       : must_contain_nonsilence(must_contain_nonsilence),
         min_trailing_silence(min_trailing_silence),
         min_utterance_length(min_utterance_length) {}
+
+  std::string ToString() const;
 };
 
 struct EndpointConfig {
@@ -54,6 +56,8 @@ struct EndpointConfig {
 
   EndpointConfig()
       : rule1(false, 2.4, 0), rule2(true, 1.4, 0), rule3(false, 0, 20) {}
+
+  std::string ToString() const;
 };
 
 class Endpoint {

+ 10 - 1
sherpa-ncnn/csrc/greedy-search-decoder.cc

@@ -88,9 +88,18 @@ void GreedySearchDecoder::InputFinished() {
   feature_extractor_.InputFinished();
 }
 
-bool GreedySearchDecoder::IsEndpoint() const {
+bool GreedySearchDecoder::IsEndpoint() {
   return endpoint_->IsEndpoint(num_processed_ - endpoint_start_frame_,
                                result_.num_trailing_blanks * 4, 10 / 1000.0);
 }
 
+void GreedySearchDecoder::Reset() {
+  ResetResult();
+  BuildDecoderInput();
+  decoder_out_ = model_->RunDecoder(decoder_input_);
+  feature_extractor_.Reset();
+  num_processed_ = 0;
+  endpoint_start_frame_ = 0;
+}
+
 }  // namespace sherpa_ncnn

+ 3 - 1
sherpa-ncnn/csrc/greedy-search-decoder.h

@@ -60,7 +60,9 @@ class GreedySearchDecoder : public Decoder {
 
   void ResetResult() override;
 
-  bool IsEndpoint() const override;
+  bool IsEndpoint() override;
+
+  void Reset() override;
 
   void InputFinished() override;
 

+ 11 - 14
sherpa-ncnn/csrc/model.cc

@@ -27,20 +27,17 @@ namespace sherpa_ncnn {
 
 std::string ModelConfig::ToString() const {
   std::ostringstream os;
-  os << "encoder_param: " << encoder_param << "\n";
-  os << "encoder_bin: " << encoder_bin << "\n";
-
-  os << "decoder_param: " << decoder_param << "\n";
-  os << "decoder_bin: " << decoder_bin << "\n";
-
-  os << "joiner_param: " << joiner_param << "\n";
-  os << "joiner_bin: " << joiner_bin << "\n";
-
-  os << "tokens: " << tokens << "\n";
-
-  os << "encoder num_threads: " << encoder_opt.num_threads << "\n";
-  os << "decoder num_threads: " << decoder_opt.num_threads << "\n";
-  os << "joiner num_threads: " << joiner_opt.num_threads << "\n";
+  os << "ModelConfig(";
+  os << "encoder_param=\"" << encoder_param << "\", ";
+  os << "encoder_bin=\"" << encoder_bin << "\", ";
+  os << "decoder_param=\"" << decoder_param << "\", ";
+  os << "decoder_bin=\"" << decoder_bin << "\", ";
+  os << "joiner_param=\"" << joiner_param << "\", ";
+  os << "joiner_bin=\"" << joiner_bin << "\", ";
+  os << "tokens=\"" << tokens << "\", ";
+  os << "encoder num_threads=" << encoder_opt.num_threads << ", ";
+  os << "decoder num_threads=" << decoder_opt.num_threads << ", ";
+  os << "joiner num_threads=" << joiner_opt.num_threads << ")";
 
   return os.str();
 }

+ 13 - 2
sherpa-ncnn/csrc/modified-beam-search-decoder.cc

@@ -45,6 +45,7 @@ void ModifiedBeamSearchDecoder::ResetResult() {
   std::vector<int32_t> blanks(context_size_, blank_id_);
   Hypotheses blank_hyp({{blanks, 0}});
   result_.hyps = std::move(blank_hyp);
+  result_.num_trailing_blanks = 0;
 }
 
 void ModifiedBeamSearchDecoder::Decode() {
@@ -109,8 +110,9 @@ RecognitionResult ModifiedBeamSearchDecoder::GetResult() {
     }
   }
   result_.text = std::move(best_hyp_text);
-  auto ans = result_;
   result_.num_trailing_blanks = best_hyp.num_trailing_blanks;
+  auto ans = result_;
+
   if (config_.use_endpoint && IsEndpoint()) {
     ResetResult();
     endpoint_start_frame_ = num_processed_;
@@ -122,9 +124,18 @@ void ModifiedBeamSearchDecoder::InputFinished() {
   feature_extractor_.InputFinished();
 }
 
-bool ModifiedBeamSearchDecoder::IsEndpoint() const {
+bool ModifiedBeamSearchDecoder::IsEndpoint() {
+  auto best_hyp = result_.hyps.GetMostProbable(true);
+  result_.num_trailing_blanks = best_hyp.num_trailing_blanks;
   return endpoint_->IsEndpoint(num_processed_ - endpoint_start_frame_,
                                result_.num_trailing_blanks * 4, 10 / 1000.0);
 }
 
+void ModifiedBeamSearchDecoder::Reset() {
+  ResetResult();
+  feature_extractor_.Reset();
+  num_processed_ = 0;
+  endpoint_start_frame_ = 0;
+}
+
 }  // namespace sherpa_ncnn

+ 3 - 1
sherpa-ncnn/csrc/modified-beam-search-decoder.h

@@ -58,7 +58,9 @@ class ModifiedBeamSearchDecoder : public Decoder {
 
   void ResetResult() override;
 
-  bool IsEndpoint() const override;
+  bool IsEndpoint() override;
+
+  void Reset() override;
 
   void InputFinished() override;
 

+ 15 - 1
sherpa-ncnn/csrc/recognizer.cc

@@ -26,6 +26,18 @@
 
 namespace sherpa_ncnn {
 
+std::string DecoderConfig::ToString() const {
+  std::ostringstream os;
+
+  os << "DecoderConfig(";
+  os << "method=\"" << method << "\", ";
+  os << "num_active_paths=" << num_active_paths << ", ";
+  os << "use_endpoint=" << (use_endpoint ? "True" : "False") << ", ";
+  os << "endpoint_config=" << endpoint_config.ToString() << ")";
+
+  return os.str();
+}
+
 Recognizer::Recognizer(
 #if __ANDROID_API__ >= 9
     AAssetManager *mgr,
@@ -62,7 +74,9 @@ void Recognizer::Decode() { decoder_->Decode(); }
 
 RecognitionResult Recognizer::GetResult() { return decoder_->GetResult(); }
 
-bool Recognizer::IsEndpoint() const { return decoder_->IsEndpoint(); }
+bool Recognizer::IsEndpoint() { return decoder_->IsEndpoint(); }
+
+void Recognizer::Reset() { return decoder_->Reset(); }
 
 void Recognizer::InputFinished() { return decoder_->InputFinished(); }
 

+ 7 - 2
sherpa-ncnn/csrc/recognizer.h

@@ -50,6 +50,7 @@ struct DecoderConfig {
   bool use_endpoint = true;
 
   EndpointConfig endpoint_config;
+  std::string ToString() const;
 };
 
 class Decoder {
@@ -67,7 +68,9 @@ class Decoder {
 
   virtual void InputFinished() = 0;
 
-  virtual bool IsEndpoint() const = 0;
+  virtual bool IsEndpoint() = 0;
+
+  virtual void Reset() = 0;
 };
 
 class Recognizer {
@@ -92,7 +95,9 @@ class Recognizer {
 
   void InputFinished();
 
-  bool IsEndpoint() const;
+  bool IsEndpoint();
+
+  void Reset();
 
  private:
   std::unique_ptr<Model> model_;

+ 5 - 2
sherpa-ncnn/csrc/sherpa-ncnn-microphone.cc

@@ -58,8 +58,9 @@ Usage:
     /path/to/joiner.ncnn.bin \
     [num_threads]
 
-You can download pre-trained models from the following repository:
-https://huggingface.co/csukuangfj/sherpa-ncnn-2022-09-05
+Please refer to
+https://k2-fsa.github.io/sherpa/ncnn/pretrained_models/index.html
+for a list of pre-trained models to download.
 )usage";
     fprintf(stderr, "%s\n", usage);
     fprintf(stderr, "argc, %d\n", argc);
@@ -101,6 +102,8 @@ https://huggingface.co/csukuangfj/sherpa-ncnn-2022-09-05
   fbank_opts.frame_opts.samp_freq = expected_sampling_rate;
   fbank_opts.mel_opts.num_bins = 80;
 
+  fprintf(stderr, "%s\n", decoder_conf.ToString().c_str());
+
   sherpa_ncnn::Recognizer recognizer(decoder_conf, model_conf, fbank_opts);
 
   sherpa_ncnn::Microphone mic;

+ 3 - 2
sherpa-ncnn/csrc/sherpa-ncnn.cc

@@ -38,8 +38,9 @@ Usage:
     /path/to/joiner.ncnn.bin \
     /path/to/foo.wav [num_threads] [decode_method, can be greedy_search/modified_beam_search]
 
-You can download pre-trained models from the following repository:
-https://huggingface.co/csukuangfj/sherpa-ncnn-2022-09-05
+Please refer to
+https://k2-fsa.github.io/sherpa/ncnn/pretrained_models/index.html
+for a list of pre-trained models to download.
 )usage";
     std::cerr << usage << "\n";
 

+ 83 - 5
sherpa-ncnn/jni/jni.cc

@@ -61,6 +61,10 @@ class SherpaNcnn {
     return result.text;
   }
 
+  bool IsEndpoint() { return recognizer_.IsEndpoint(); }
+
+  void Reset() { return recognizer_.Reset(); }
+
  private:
   sherpa_ncnn::Recognizer recognizer_;
   std::vector<float> tail_padding_;
@@ -126,6 +130,68 @@ static ModelConfig GetModelConfig(JNIEnv *env, jobject config) {
   return model_config;
 }
 
+static DecoderConfig GetDecoderConfig(JNIEnv *env, jobject config) {
+  DecoderConfig decoder_config;
+
+  jclass cls = env->GetObjectClass(config);
+
+  jfieldID fid = env->GetFieldID(cls, "method", "Ljava/lang/String;");
+  jstring s = (jstring)env->GetObjectField(config, fid);
+  const char *p = env->GetStringUTFChars(s, nullptr);
+  decoder_config.method = p;
+  env->ReleaseStringUTFChars(s, p);
+
+  fid = env->GetFieldID(cls, "numActivePaths", "I");
+  decoder_config.num_active_paths = env->GetIntField(config, fid);
+
+  fid = env->GetFieldID(cls, "useEndpoint", "Z");
+  decoder_config.use_endpoint = env->GetBooleanField(config, fid);
+
+  fid = env->GetFieldID(cls, "endpointConfig",
+                        "Lcom/k2fsa/sherpa/ncnn/EndpointConfig;");
+  jobject endpoint_config = env->GetObjectField(config, fid);
+  jclass endpoint_config_cls = env->GetObjectClass(endpoint_config);
+
+  fid = env->GetFieldID(endpoint_config_cls, "rule1",
+                        "Lcom/k2fsa/sherpa/ncnn/EndpointRule;");
+  jobject rule1 = env->GetObjectField(endpoint_config, fid);
+  jclass rule_class = env->GetObjectClass(rule1);
+
+  fid = env->GetFieldID(endpoint_config_cls, "rule2",
+                        "Lcom/k2fsa/sherpa/ncnn/EndpointRule;");
+  jobject rule2 = env->GetObjectField(endpoint_config, fid);
+
+  fid = env->GetFieldID(endpoint_config_cls, "rule3",
+                        "Lcom/k2fsa/sherpa/ncnn/EndpointRule;");
+  jobject rule3 = env->GetObjectField(endpoint_config, fid);
+
+  fid = env->GetFieldID(rule_class, "mustContainNonSilence", "Z");
+  decoder_config.endpoint_config.rule1.must_contain_nonsilence =
+      env->GetBooleanField(rule1, fid);
+  decoder_config.endpoint_config.rule2.must_contain_nonsilence =
+      env->GetBooleanField(rule2, fid);
+  decoder_config.endpoint_config.rule3.must_contain_nonsilence =
+      env->GetBooleanField(rule3, fid);
+
+  fid = env->GetFieldID(rule_class, "minTrailingSilence", "F");
+  decoder_config.endpoint_config.rule1.min_trailing_silence =
+      env->GetFloatField(rule1, fid);
+  decoder_config.endpoint_config.rule2.min_trailing_silence =
+      env->GetFloatField(rule2, fid);
+  decoder_config.endpoint_config.rule3.min_trailing_silence =
+      env->GetFloatField(rule3, fid);
+
+  fid = env->GetFieldID(rule_class, "minUtteranceLength", "F");
+  decoder_config.endpoint_config.rule1.min_utterance_length =
+      env->GetFloatField(rule1, fid);
+  decoder_config.endpoint_config.rule2.min_utterance_length =
+      env->GetFloatField(rule2, fid);
+  decoder_config.endpoint_config.rule3.min_utterance_length =
+      env->GetFloatField(rule3, fid);
+
+  return decoder_config;
+}
+
 static knf::FbankOptions GetFbankOptions(JNIEnv *env, jobject opts) {
   jclass cls = env->GetObjectClass(opts);
   jfieldID fid;
@@ -231,16 +297,18 @@ static knf::FbankOptions GetFbankOptions(JNIEnv *env, jobject opts) {
 SHERPA_EXTERN_C
 JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_ncnn_SherpaNcnn_new(
     JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _model_config,
-    jobject _fbank_config) {
+    jobject _decoder_config, jobject _fbank_config) {
   AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
   if (!mgr) {
     NCNN_LOGE("Failed to get asset manager: %p", mgr);
   }
 
-  sherpa_ncnn::ModelConfig model_config =
-      sherpa_ncnn::GetModelConfig(env, _model_config);
+  auto model_config = sherpa_ncnn::GetModelConfig(env, _model_config);
+  auto decoder_config = sherpa_ncnn::GetDecoderConfig(env, _decoder_config);
 
-  sherpa_ncnn::DecoderConfig decoder_config;
+  NCNN_LOGE("------model_config------\n%s\n", model_config.ToString().c_str());
+  NCNN_LOGE("------decoder_config------\n%s\n",
+            decoder_config.ToString().c_str());
 
   knf::FbankOptions fbank_opts =
       sherpa_ncnn::GetFbankOptions(env, _fbank_config);
@@ -259,7 +327,17 @@ JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_ncnn_SherpaNcnn_delete(
 
 SHERPA_EXTERN_C
 JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_ncnn_SherpaNcnn_reset(
-    JNIEnv *env, jobject /*obj*/, jlong ptr) {}
+    JNIEnv *env, jobject /*obj*/, jlong ptr) {
+  auto model = reinterpret_cast<sherpa_ncnn::SherpaNcnn *>(ptr);
+  model->Reset();
+}
+
+SHERPA_EXTERN_C
+JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_ncnn_SherpaNcnn_isEndpoint(
+    JNIEnv *env, jobject /*obj*/, jlong ptr) {
+  auto model = reinterpret_cast<sherpa_ncnn::SherpaNcnn *>(ptr);
+  return model->IsEndpoint();
+}
 
 SHERPA_EXTERN_C
 JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_ncnn_SherpaNcnn_decodeSamples(