
Support endpoint detection in Python (#56)

* Fix Python API
Fangjun Kuang, 2 years ago · commit 2e78bbf587

+ 11 - 11
README.md

@@ -15,6 +15,9 @@ We support all platforms that [ncnn](https://github.com/tencent/ncnn) supports.
 Everything can be compiled from source with static link. The generated
 executable depends only on system libraries.
 
+**HINT**: It does not depend on PyTorch or on any inference framework
+other than [ncnn](https://github.com/tencent/ncnn).
+
 Please see the documentation <https://k2-fsa.github.io/sherpa/ncnn/index.html>
 for installation and usages, e.g.,
 
@@ -24,21 +27,18 @@ for installation and usages, e.g.,
 We provide a few YouTube videos for demonstration about real-time speech recognition
 with `sherpa-ncnn` using a microphone:
 
-  - `English`: <https://www.youtube.com/watch?v=m6ynSxycpX0>
-  - `Chinese`: <https://www.youtube.com/watch?v=bbQfoRT75oM>
-  - `Chinese + English` Android demo: <https://www.youtube.com/shorts/S5Owcrb8vzU>
-  - `Chinese (with background noise)` Android demo : <https://www.youtube.com/shorts/KI1-d-W9uZw>
-  - `Chinese` Android demo : <https://www.youtube.com/shorts/lpDAG36T1R4>
-  - `Chinese poem with background music` Android demo : <https://www.youtube.com/shorts/5CJ-r8VNuwo>
-
-**Note**: If you don't have access to YouTube, we provide the links
-in bilibili below:
-
   - `English`: <https://www.bilibili.com/video/BV1TP411p7dh/>
   - `Chinese`: <https://www.bilibili.com/video/BV1214y177vu>
-  - `Chinese + English` Android demo: <https://www.bilibili.com/video/BV1Ge411A7XS>
+
+  - Multilingual (Chinese + English) with endpointing Python demo : <https://www.bilibili.com/video/BV1eK411y788/>
+
+  - **Android demos**
+
+  - Multilingual (Chinese + English) Android demo 1: <https://www.bilibili.com/video/BV1Ge411A7XS>
+  - Multilingual (Chinese + English) Android demo 2: <https://www.bilibili.com/video/BV1eK411y788/>
   - `Chinese (with background noise)` Android demo : <https://www.bilibili.com/video/BV1GR4y167fx>
   - `Chinese` Android demo : <https://www.bilibili.com/video/BV1744y1Z76H>
   - `Chinese poem with background music` Android demo : <https://www.bilibili.com/video/BV1vR4y1k7eo>
 
+
 See also <https://github.com/k2-fsa/sherpa>

+ 1 - 1
android/SherpaNcnn/app/src/main/java/com/k2fsa/sherpa/ncnn/MainActivity.kt

@@ -173,7 +173,7 @@ class MainActivity : AppCompatActivity() {
         model = SherpaNcnn(
             assetManager = application.assets,
             modelConfig = getModelConfig(type = 1, useGPU = useGPU)!!,
-            decoderConfig=getDecoderConfig(useEndpoint = true),
+            decoderConfig = getDecoderConfig(enableEndpoint = true),
             fbankConfig = getFbankConfig(),
         )
     }

+ 3 - 3
android/SherpaNcnn/app/src/main/java/com/k2fsa/sherpa/ncnn/SherpaNcnn.kt

@@ -17,7 +17,7 @@ data class EndpointConfig(
 data class DecoderConfig(
     var method: String = "modified_beam_search", // valid values: greedy_search, modified_beam_search
     var numActivePaths: Int = 4, // used only by modified_beam_search
-    var useEndpoint: Boolean = true,
+    var enableEndpoint: Boolean = true,
     var endpointConfig: EndpointConfig = EndpointConfig(),
 )
 
@@ -169,11 +169,11 @@ fun getModelConfig(type: Int, useGPU: Boolean): ModelConfig? {
     return null
 }
 
-fun getDecoderConfig(useEndpoint: Boolean): DecoderConfig {
+fun getDecoderConfig(enableEndpoint: Boolean): DecoderConfig {
     return DecoderConfig(
         method = "modified_beam_search",
         numActivePaths = 4,
-        useEndpoint = useEndpoint,
+        enableEndpoint = enableEndpoint,
         endpointConfig = EndpointConfig(
             rule1 = EndpointRule(false, 2.4f, 0.0f),
             rule2 = EndpointRule(true, 1.4f, 0.0f),

+ 12 - 0
python-api-examples/README.md

@@ -7,3 +7,15 @@ This file shows how to recognize a file.
 ## speech-recognition-from-microphone.py
 
 This file demonstrates how to do real-time speech recognition with a microphone.
+
+You can find video demos about this file at the following addresses:
+
+  - https://www.bilibili.com/video/BV1K44y197Fg/
+  - https://www.youtube.com/watch?v=74SxVueROok
+
+## speech-recognition-from-microphone-with-endpoint-detection.py
+
+Similar to `speech-recognition-from-microphone.py` but it also enables
+endpoint detection.
+
+You can find a video demo about this file at <https://www.bilibili.com/video/BV1eK411y788>.
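
> **Editor's note.** For quick reference, endpoint detection is switched on through extra keyword arguments of `sherpa_ncnn.Recognizer`. A minimal sketch, assuming the pre-trained conv-emformer model directory used by the examples (the paths are an assumption; substitute whatever model you downloaded):

```python
# Sketch: enabling endpoint detection in the Python API.
# The model directory is a placeholder; replace it with your own model.
import sherpa_ncnn

d = "./sherpa-ncnn-conv-emformer-transducer-2022-12-06"
recognizer = sherpa_ncnn.Recognizer(
    tokens=f"{d}/tokens.txt",
    encoder_param=f"{d}/encoder_jit_trace-pnnx.ncnn.param",
    encoder_bin=f"{d}/encoder_jit_trace-pnnx.ncnn.bin",
    decoder_param=f"{d}/decoder_jit_trace-pnnx.ncnn.param",
    decoder_bin=f"{d}/decoder_jit_trace-pnnx.ncnn.bin",
    joiner_param=f"{d}/joiner_jit_trace-pnnx.ncnn.param",
    joiner_bin=f"{d}/joiner_jit_trace-pnnx.ncnn.bin",
    num_threads=4,
    enable_endpoint_detection=True,  # disabled by default
)
```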

+ 80 - 0
python-api-examples/speech-recognition-from-microphone-with-endpoint-detection.py

@@ -0,0 +1,80 @@
+#!/usr/bin/env python3
+
+# Real-time speech recognition from a microphone with sherpa-ncnn Python API
+# with endpoint detection.
+#
+# Please refer to
+# https://k2-fsa.github.io/sherpa/ncnn/pretrained_models/index.html
+# to download pre-trained models
+
+import sys
+
+try:
+    import sounddevice as sd
+except ImportError:
+    print("Please install sounddevice first. You can use")
+    print()
+    print("  pip install sounddevice")
+    print()
+    print("to install it")
+    sys.exit(-1)
+
+import sherpa_ncnn
+
+
+def create_recognizer():
+    # Please replace the model files if needed.
+    # See https://k2-fsa.github.io/sherpa/ncnn/pretrained_models/index.html
+    # for download links.
+    recognizer = sherpa_ncnn.Recognizer(
+        tokens="./sherpa-ncnn-conv-emformer-transducer-2022-12-06/tokens.txt",
+        encoder_param="./sherpa-ncnn-conv-emformer-transducer-2022-12-06/encoder_jit_trace-pnnx.ncnn.param",
+        encoder_bin="./sherpa-ncnn-conv-emformer-transducer-2022-12-06/encoder_jit_trace-pnnx.ncnn.bin",
+        decoder_param="./sherpa-ncnn-conv-emformer-transducer-2022-12-06/decoder_jit_trace-pnnx.ncnn.param",
+        decoder_bin="./sherpa-ncnn-conv-emformer-transducer-2022-12-06/decoder_jit_trace-pnnx.ncnn.bin",
+        joiner_param="./sherpa-ncnn-conv-emformer-transducer-2022-12-06/joiner_jit_trace-pnnx.ncnn.param",
+        joiner_bin="./sherpa-ncnn-conv-emformer-transducer-2022-12-06/joiner_jit_trace-pnnx.ncnn.bin",
+        num_threads=4,
+        decoding_method="modified_beam_search",
+        enable_endpoint_detection=True,
+        rule1_min_trailing_silence=2.4,
+        rule2_min_trailing_silence=1.2,
+        rule3_min_utterance_length=300,
+    )
+    return recognizer
+
+
+def main():
+    print("Started! Please speak")
+    recognizer = create_recognizer()
+    sample_rate = recognizer.sample_rate
+    samples_per_read = int(0.1 * sample_rate)  # 0.1 second = 100 ms
+    last_result = ""
+    segment_id = 0
+    with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s:
+        while True:
+            samples, _ = s.read(samples_per_read)  # a blocking read
+            samples = samples.reshape(-1)
+            recognizer.accept_waveform(sample_rate, samples)
+
+            is_endpoint = recognizer.is_endpoint
+
+            result = recognizer.text
+            if result and (last_result != result):
+                last_result = result
+                print(f"{segment_id}: {result}")
+
+            if result and is_endpoint:
+                segment_id += 1
+
+
+if __name__ == "__main__":
+    devices = sd.query_devices()
+    print(devices)
+    default_input_device_idx = sd.default.device[0]
+    print(f'Use default device: {devices[default_input_device_idx]["name"]}')
+
+    try:
+        main()
+    except KeyboardInterrupt:
+        print("\nCaught Ctrl + C. Exiting")

+ 1 - 1
python-api-examples/speech-recognition-from-microphone.py

@@ -42,7 +42,7 @@ def main():
     print("Started! Please speak")
     print("Started! Please speak")
     recognizer = create_recognizer()
     recognizer = create_recognizer()
     sample_rate = recognizer.sample_rate
     sample_rate = recognizer.sample_rate
-    samples_per_read = int(0.02 * sample_rate)  # 20ms
+    samples_per_read = int(0.1 * sample_rate)  # 0.1 second = 100 ms
     last_result = ""
     last_result = ""
     with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s:
     with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s:
         while True:
         while True:
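
> **Editor's note.** This change reads 100 ms per blocking call instead of 20 ms, which reduces the per-iteration overhead of the Python loop with no noticeable latency cost. The arithmetic, assuming the 16 kHz sample rate the models in this commit use:

```python
# Samples fetched per blocking read, at a 16 kHz model sample rate.
print(int(0.02 * 16000))  # old: 320 samples (20 ms)
print(int(0.1 * 16000))   # new: 1600 samples (100 ms)
```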

+ 1 - 0
sherpa-ncnn/csrc/CPPLINT.cfg

@@ -0,0 +1 @@
+exclude_files=generate-int8-scale-table.cc

+ 5 - 0
sherpa-ncnn/csrc/endpoint.h

@@ -18,6 +18,7 @@
 #ifndef SHERPA_NCNN_CSRC_ENDPOINT_H_
 #define SHERPA_NCNN_CSRC_ENDPOINT_H_
 
+#include <string>
 #include <vector>
 
 namespace sherpa_ncnn {
@@ -54,6 +55,10 @@ struct EndpointConfig {
   EndpointRule rule2;
   EndpointRule rule3;
 
+  EndpointConfig(const EndpointRule &rule1, const EndpointRule &rule2,
+                 const EndpointRule &rule3)
+      : rule1(rule1), rule2(rule2), rule3(rule3) {}
+
   EndpointConfig()
       : rule1(false, 2.4, 0), rule2(true, 1.4, 0), rule3(false, 0, 20) {}
 

+ 2 - 2
sherpa-ncnn/csrc/greedy-search-decoder.cc

@@ -21,7 +21,7 @@
 
 namespace sherpa_ncnn {
 
-void GreedySearchDecoder::AcceptWaveform(const int32_t sample_rate,
+void GreedySearchDecoder::AcceptWaveform(const float sample_rate,
                                          const float *input_buffer,
                                          int32_t frames_per_buffer) {
   feature_extractor_.AcceptWaveform(sample_rate, input_buffer,
@@ -77,7 +77,7 @@ void GreedySearchDecoder::Decode() {
 
 RecognitionResult GreedySearchDecoder::GetResult() {
   auto ans = result_;
-  if (config_.use_endpoint && IsEndpoint()) {
+  if (config_.enable_endpoint && IsEndpoint()) {
     ResetResult();
     endpoint_start_frame_ = num_processed_;
   }

+ 1 - 1
sherpa-ncnn/csrc/greedy-search-decoder.h

@@ -51,7 +51,7 @@ class GreedySearchDecoder : public Decoder {
     decoder_out_ = model_->RunDecoder(decoder_input_);
   }
 
-  void AcceptWaveform(int32_t sample_rate, const float *input_buffer,
+  void AcceptWaveform(float sample_rate, const float *input_buffer,
                       int32_t frames_per_buffer) override;
 
   void Decode() override;

+ 2 - 2
sherpa-ncnn/csrc/modified-beam-search-decoder.cc

@@ -25,7 +25,7 @@
 
 namespace sherpa_ncnn {
 
-void ModifiedBeamSearchDecoder::AcceptWaveform(const int32_t sample_rate,
+void ModifiedBeamSearchDecoder::AcceptWaveform(const float sample_rate,
                                                const float *input_buffer,
                                                int32_t frames_per_buffer) {
   feature_extractor_.AcceptWaveform(sample_rate, input_buffer,
@@ -113,7 +113,7 @@ RecognitionResult ModifiedBeamSearchDecoder::GetResult() {
   result_.num_trailing_blanks = best_hyp.num_trailing_blanks;
   auto ans = result_;
 
-  if (config_.use_endpoint && IsEndpoint()) {
+  if (config_.enable_endpoint && IsEndpoint()) {
     ResetResult();
     endpoint_start_frame_ = num_processed_;
   }

+ 1 - 1
sherpa-ncnn/csrc/modified-beam-search-decoder.h

@@ -49,7 +49,7 @@ class ModifiedBeamSearchDecoder : public Decoder {
     ResetResult();
   }
 
-  void AcceptWaveform(int32_t sample_rate, const float *input_buffer,
+  void AcceptWaveform(float sample_rate, const float *input_buffer,
                       int32_t frames_per_buffer) override;
 
   void Decode() override;

+ 6 - 4
sherpa-ncnn/csrc/recognizer.cc

@@ -17,6 +17,8 @@
  * limitations under the License.
  */
 
+#include "sherpa-ncnn/csrc/recognizer.h"
+
 #include <memory>
 #include <string>
 #include <vector>
@@ -32,7 +34,7 @@ std::string DecoderConfig::ToString() const {
   os << "DecoderConfig(";
   os << "DecoderConfig(";
   os << "method=\"" << method << "\", ";
   os << "method=\"" << method << "\", ";
   os << "num_active_paths=" << num_active_paths << ", ";
   os << "num_active_paths=" << num_active_paths << ", ";
-  os << "use_endpoint=" << (use_endpoint ? "True" : "False") << ", ";
+  os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ", ";
   os << "endpoint_config=" << endpoint_config.ToString() << ")";
   os << "endpoint_config=" << endpoint_config.ToString() << ")";
 
 
   return os.str();
   return os.str();
@@ -42,8 +44,8 @@ Recognizer::Recognizer(
 #if __ANDROID_API__ >= 9
     AAssetManager *mgr,
 #endif
-    const DecoderConfig decoder_conf, const ModelConfig model_conf,
-    const knf::FbankOptions fbank_opts)
+    const DecoderConfig &decoder_conf, const ModelConfig &model_conf,
+    const knf::FbankOptions &fbank_opts)
     :
 #if __ANDROID_API__ >= 9
       model_(Model::Create(mgr, model_conf)),
@@ -65,7 +67,7 @@ Recognizer::Recognizer(
   }
 }
 
-void Recognizer::AcceptWaveform(int32_t sample_rate, const float *input_buffer,
+void Recognizer::AcceptWaveform(float sample_rate, const float *input_buffer,
                                 int32_t frames_per_buffer) {
   decoder_->AcceptWaveform(sample_rate, input_buffer, frames_per_buffer);
 }

+ 16 - 5
sherpa-ncnn/csrc/recognizer.h

@@ -32,6 +32,7 @@
 
 namespace sherpa_ncnn {
 
+// TODO(fangjun): Add timestamps
 struct RecognitionResult {
   std::vector<int32_t> tokens;
   std::string text;
@@ -47,9 +48,19 @@ struct DecoderConfig {
 
   int32_t num_active_paths = 4;  // for modified beam search
 
-  bool use_endpoint = true;
+  bool enable_endpoint = false;
 
   EndpointConfig endpoint_config;
+
+  DecoderConfig() = default;
+
+  DecoderConfig(const std::string &method, int32_t num_active_paths,
+                bool enable_endpoint, const EndpointConfig &endpoint_config)
+      : method(method),
+        num_active_paths(num_active_paths),
+        enable_endpoint(enable_endpoint),
+        endpoint_config(endpoint_config) {}
+
   std::string ToString() const;
 };
 
@@ -57,7 +68,7 @@ class Decoder {
  public:
   virtual ~Decoder() = default;
 
-  virtual void AcceptWaveform(int32_t sample_rate, const float *input_buffer,
+  virtual void AcceptWaveform(float sample_rate, const float *input_buffer,
                               int32_t frames_per_buffer) = 0;
 
   virtual void Decode() = 0;
@@ -81,12 +92,12 @@ class Recognizer {
 #if __ANDROID_API__ >= 9
       AAssetManager *mgr,
 #endif
-      const DecoderConfig decoder_conf, const ModelConfig model_conf,
-      const knf::FbankOptions fbank_opts);
+      const DecoderConfig &decoder_conf, const ModelConfig &model_conf,
+      const knf::FbankOptions &fbank_opts);
 
   ~Recognizer() = default;
 
-  void AcceptWaveform(int32_t sample_rate, const float *input_buffer,
+  void AcceptWaveform(float sample_rate, const float *input_buffer,
                       int32_t frames_per_buffer);
 
   void Decode();

+ 2 - 2
sherpa-ncnn/jni/jni.cc

@@ -144,8 +144,8 @@ static DecoderConfig GetDecoderConfig(JNIEnv *env, jobject config) {
   fid = env->GetFieldID(cls, "numActivePaths", "I");
   fid = env->GetFieldID(cls, "numActivePaths", "I");
   decoder_config.num_active_paths = env->GetIntField(config, fid);
   decoder_config.num_active_paths = env->GetIntField(config, fid);
 
 
-  fid = env->GetFieldID(cls, "useEndpoint", "Z");
-  decoder_config.use_endpoint = env->GetBooleanField(config, fid);
+  fid = env->GetFieldID(cls, "enableEndpoint", "Z");
+  decoder_config.enable_endpoint = env->GetBooleanField(config, fid);
 
   fid = env->GetFieldID(cls, "endpointConfig",
                         "Lcom/k2fsa/sherpa/ncnn/EndpointConfig;");

+ 2 - 3
sherpa-ncnn/python/csrc/CMakeLists.txt

@@ -1,10 +1,9 @@
 
 include_directories(${PROJECT_SOURCE_DIR})
 set(srcs
-  decode.cc
-  features.cc
-  mat-util.cc
+  endpoint.cc
   model.cc
+  recognizer.cc
   sherpa-ncnn.cc
 )
 

+ 0 - 46
sherpa-ncnn/python/csrc/decode.cc

@@ -1,46 +0,0 @@
-/**
- * Copyright (c)  2022  Xiaomi Corporation (authors: Fangjun Kuang)
- *
- * See LICENSE for clarification regarding multiple authors
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "sherpa-ncnn/python/csrc/decode.h"
-
-#include "sherpa-ncnn/csrc/decode.h"
-#include "sherpa-ncnn/csrc/model.h"
-#include "sherpa-ncnn/python/csrc/mat-util.h"
-
-namespace sherpa_ncnn {
-
-static void PybindGreedySearch(py::module *m) {
-  m->def(
-      "greedy_search",
-      [](Model *model, py::array _encoder_out, py::array _decoder_out,
-         std::vector<int32_t> hyp)
-          -> std::pair<py::array, std::vector<int32_t>> {
-        ncnn::Mat encoder_out = ArrayToMat(_encoder_out);
-        ncnn::Mat decoder_out = ArrayToMat(_decoder_out);
-
-        GreedySearch(model, encoder_out, &decoder_out, &hyp);
-
-        return {MatToArray(decoder_out), hyp};
-      },
-      py::arg("model"), py::arg("encoder_out"), py::arg("decoder_out"),
-      py::arg("hyp"));
-}
-
-void PybindDecode(py::module *m) { PybindGreedySearch(m); }
-
-}  // namespace sherpa_ncnn

+ 114 - 0
sherpa-ncnn/python/csrc/endpoint.cc

@@ -0,0 +1,114 @@
+/**
+ * Copyright (c)  2023  Xiaomi Corporation (authors: Fangjun Kuang)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "sherpa-ncnn/python/csrc/endpoint.h"
+
+#include <memory>
+#include <string>
+
+#include "sherpa-ncnn/csrc/endpoint.h"
+
+namespace sherpa_ncnn {
+
+static constexpr const char *kEndpointRuleInitDoc = R"doc(
+Constructor for EndpointRule.
+
+Args:
+  must_contain_nonsilence:
+    If True, for this endpointing rule to apply there must be nonsilence in the
+    best-path traceback. For decoding, a non-blank token is considered as
+    non-silence.
+  min_trailing_silence:
+    This endpointing rule requires duration of trailing silence (in seconds)
+    to be ``>=`` this value.
+  min_utterance_length:
+    This endpointing rule requires utterance-length (in seconds) to
+    be ``>=`` this value.
+)doc";
+
+static constexpr const char *kEndpointConfigInitDoc = R"doc(
+If any rule in EndpointConfig is activated, it is said that an endpointing
+is detected.
+
+Args:
+  rule1:
+    By default, it times out after 2.4 seconds of silence, even if
+    we decoded nothing.
+  rule2:
+    By default, it times out after 1.2 seconds of silence after decoding
+    something.
+  rule3:
+    By default, it times out after the utterance is 20 seconds long, regardless of
+    anything else.
+)doc";
+
+static void PybindEndpointRule(py::module *m) {
+  using PyClass = EndpointRule;
+  py::class_<PyClass>(*m, "EndpointRule")
+      .def(py::init<bool, float, float>(), py::arg("must_contain_nonsilence"),
+           py::arg("min_trailing_silence"), py::arg("min_utterance_length"),
+           kEndpointRuleInitDoc)
+      .def("__str__", &PyClass::ToString)
+      .def_readwrite("must_contain_nonsilence",
+                     &PyClass::must_contain_nonsilence)
+      .def_readwrite("min_trailing_silence", &PyClass::min_trailing_silence)
+      .def_readwrite("min_utterance_length", &PyClass::min_utterance_length);
+}
+
+static void PybindEndpointConfig(py::module *m) {
+  using PyClass = EndpointConfig;
+  py::class_<PyClass>(*m, "EndpointConfig")
+      .def(
+          py::init(
+              [](float rule1_min_trailing_silence,
+                 float rule2_min_trailing_silence,
+                 float rule3_min_utterance_length) -> std::unique_ptr<PyClass> {
+                EndpointRule rule1(false, rule1_min_trailing_silence, 0);
+                EndpointRule rule2(true, rule2_min_trailing_silence, 0);
+                EndpointRule rule3(false, 0, rule3_min_utterance_length);
+
+                return std::make_unique<EndpointConfig>(rule1, rule2, rule3);
+              }),
+          py::arg("rule1_min_trailing_silence"),
+          py::arg("rule2_min_trailing_silence"),
+          py::arg("rule3_min_utterance_length"))
+      .def(py::init([](const EndpointRule &rule1, const EndpointRule &rule2,
+                       const EndpointRule &rule3) -> std::unique_ptr<PyClass> {
+             auto ans = std::make_unique<PyClass>();
+             ans->rule1 = rule1;
+             ans->rule2 = rule2;
+             ans->rule3 = rule3;
+             return ans;
+           }),
+           py::arg("rule1") = EndpointRule(false, 2.4, 0),
+           py::arg("rule2") = EndpointRule(true, 1.2, 0),
+           py::arg("rule3") = EndpointRule(false, 0, 20),
+           kEndpointConfigInitDoc)
+      .def("__str__",
+           [](const PyClass &self) -> std::string { return self.ToString(); })
+      .def_readwrite("rule1", &PyClass::rule1)
+      .def_readwrite("rule2", &PyClass::rule2)
+      .def_readwrite("rule3", &PyClass::rule3);
+}
+
+void PybindEndpoint(py::module *m) {
+  PybindEndpointRule(m);
+  PybindEndpointConfig(m);
+}
+
+}  // namespace sherpa_ncnn
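
> **Editor's note.** `PybindEndpointConfig` exposes two constructors: one taking three floats (the form the high-level wrapper uses) and one taking three `EndpointRule` objects. A sketch of the float form; it hard-codes `must_contain_nonsilence` to False/True/False for rules 1–3, exactly as the `py::init` lambda above does:

```python
# Sketch: the float overload of EndpointConfig builds the three rules
# internally, mirroring the py::init lambda in the bindings above.
from _sherpa_ncnn import EndpointConfig

config = EndpointConfig(
    rule1_min_trailing_silence=2.4,
    rule2_min_trailing_silence=1.2,
    rule3_min_utterance_length=20,
)
print(config.rule2.min_trailing_silence)  # 1.2
```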

+ 5 - 5
sherpa-ncnn/python/csrc/decode.h → sherpa-ncnn/python/csrc/endpoint.h

@@ -1,5 +1,5 @@
 /**
- * Copyright (c)  2022  Xiaomi Corporation (authors: Fangjun Kuang)
+ * Copyright (c)  2023  Xiaomi Corporation (authors: Fangjun Kuang)
  *
  * See LICENSE for clarification regarding multiple authors
  *
@@ -16,15 +16,15 @@
  * limitations under the License.
  */
 
-#ifndef SHERPA_NCNN_PYTHON_CSRC_DECODE_H_
-#define SHERPA_NCNN_PYTHON_CSRC_DECODE_H_
+#ifndef SHERPA_NCNN_PYTHON_CSRC_ENDPOINT_H_
+#define SHERPA_NCNN_PYTHON_CSRC_ENDPOINT_H_
 
 #include "sherpa-ncnn/python/csrc/sherpa-ncnn.h"
 
 namespace sherpa_ncnn {
 
-void PybindDecode(py::module *m);
+void PybindEndpoint(py::module *m);
 
 }  // namespace sherpa_ncnn
 
-#endif  // SHERPA_NCNN_PYTHON_CSRC_DECODE_H_
+#endif  // SHERPA_NCNN_PYTHON_CSRC_ENDPOINT_H_

+ 0 - 56
sherpa-ncnn/python/csrc/features.cc

@@ -1,56 +0,0 @@
-/**
- * Copyright (c)  2022  Xiaomi Corporation (authors: Fangjun Kuang)
- *
- * See LICENSE for clarification regarding multiple authors
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "sherpa-ncnn/csrc/features.h"
-
-#include "sherpa-ncnn/python/csrc/mat-util.h"
-#include "sherpa-ncnn/python/csrc/sherpa-ncnn.h"
-
-namespace sherpa_ncnn {
-
-void PybindFeatures(py::module *m) {
-  using PyClass = FeatureExtractor;
-
-  py::class_<PyClass>(*m, "FeatureExtractor")
-      .def(py::init([](int32_t feature_dim,
-                       float sample_rate) -> std::unique_ptr<PyClass> {
-             knf::FbankOptions fbank_opts;
-             fbank_opts.frame_opts.dither = 0;
-             fbank_opts.frame_opts.snip_edges = false;
-             fbank_opts.frame_opts.samp_freq = sample_rate;
-             fbank_opts.mel_opts.num_bins = feature_dim;
-
-             return std::make_unique<PyClass>(fbank_opts);
-           }),
-           py::arg("feature_dim"), py::arg("sample_rate"))
-      .def("accept_waveform",
-           [](PyClass &self, float sample_rate, py::array_t<float> waveform) {
-             self.AcceptWaveform(sample_rate, waveform.data(), waveform.size());
-           })
-      .def("input_finished", &PyClass::InputFinished)
-      .def_property_readonly("num_frames_ready", &PyClass::NumFramesReady)
-      .def("is_last_frame", &PyClass::IsLastFrame, py::arg("frame"))
-      .def("get_frames",
-           [](PyClass &self, int32_t frame_index, int32_t n) -> py::array {
-             ncnn::Mat frames = self.GetFrames(frame_index, n);
-             return MatToArray(frames);
-           })
-      .def("reset", &PyClass::Reset);
-}
-
-}  // namespace sherpa_ncnn

+ 0 - 97
sherpa-ncnn/python/csrc/mat-util.cc

@@ -1,97 +0,0 @@
-/**
- * Copyright (c)  2022  Xiaomi Corporation (authors: Fangjun Kuang)
- *
- * See LICENSE for clarification regarding multiple authors
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "sherpa-ncnn/python/csrc/mat-util.h"
-
-namespace sherpa_ncnn {
-
-struct KeepMatAlive {
-  explicit KeepMatAlive(ncnn::Mat m) : m(m) {}
-
-  ncnn::Mat m;
-};
-
-py::array_t<float> MatToArray(ncnn::Mat m) {
-  std::vector<py::ssize_t> shape;
-  std::vector<py::ssize_t> strides;
-  if (m.dims == 1) {
-    shape.push_back(m.w);
-    strides.push_back(m.elemsize);
-  } else if (m.dims == 2) {
-    shape.push_back(m.h);
-    shape.push_back(m.w);
-    strides.push_back(m.w * m.elemsize);
-    strides.push_back(m.elemsize);
-  } else if (m.dims == 3) {
-    shape.push_back(m.c);
-    shape.push_back(m.h);
-    shape.push_back(m.w);
-    strides.push_back(m.cstep * m.elemsize);
-    strides.push_back(m.w * m.elemsize);
-    strides.push_back(m.elemsize);
-  } else if (m.dims == 4) {
-    shape.push_back(m.c);
-    shape.push_back(m.d);
-    shape.push_back(m.h);
-    shape.push_back(m.w);
-    strides.push_back(m.cstep * m.elemsize);
-    strides.push_back(m.w * m.h * m.elemsize);
-    strides.push_back(m.w * m.elemsize);
-    strides.push_back(m.elemsize);
-  }
-
-  auto keep_mat_alive = new KeepMatAlive(m);
-  py::capsule handle(keep_mat_alive, [](void *p) {
-    delete reinterpret_cast<KeepMatAlive *>(p);
-  });
-
-  return py::array_t<float>(shape, strides, (float *)m.data, handle);
-}
-
-ncnn::Mat ArrayToMat(py::array array) {
-  py::buffer_info info = array.request();
-  size_t elemsize = info.itemsize;
-
-  ncnn::Mat ans;
-
-  if (info.ndim == 1) {
-    ans = ncnn::Mat((int)info.shape[0], info.ptr, elemsize);
-  } else if (info.ndim == 2) {
-    ans = ncnn::Mat((int)info.shape[1], (int)info.shape[0], info.ptr, elemsize);
-  } else if (info.ndim == 3) {
-    ans = ncnn::Mat((int)info.shape[2], (int)info.shape[1], (int)info.shape[0],
-                    info.ptr, elemsize);
-
-    // in ncnn, buffer to construct ncnn::Mat need align to ncnn::alignSize
-    // with (w * h * elemsize, 16) / elemsize, but the buffer from numpy not
-    // so we set the cstep as numpy's cstep
-    ans.cstep = (int)info.shape[2] * (int)info.shape[1];
-  } else if (info.ndim == 4) {
-    ans = ncnn::Mat((int)info.shape[3], (int)info.shape[2], (int)info.shape[1],
-                    (int)info.shape[0], info.ptr, elemsize);
-
-    // in ncnn, buffer to construct ncnn::Mat need align to ncnn::alignSize
-    // with (w * h * d elemsize, 16) / elemsize, but the buffer from numpy not
-    // so we set the cstep as numpy's cstep
-    ans.cstep = (int)info.shape[3] * (int)info.shape[2] * (int)info.shape[1];
-  }
-
-  return ans;
-}
-
-}  // namespace sherpa_ncnn

+ 0 - 37
sherpa-ncnn/python/csrc/mat-util.h

@@ -1,37 +0,0 @@
-/**
- * Copyright (c)  2022  Xiaomi Corporation (authors: Fangjun Kuang)
- *
- * See LICENSE for clarification regarding multiple authors
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef SHERPA_NCNN_PYTHON_CSRC_MAT_UTIL_H_
-#define SHERPA_NCNN_PYTHON_CSRC_MAT_UTIL_H_
-
-#include "mat.h"
-#include "sherpa-ncnn/python/csrc/sherpa-ncnn.h"
-
-namespace sherpa_ncnn {
-
-// Convert a ncnn::Mat to a numpy array. Data is shared.
-//
-// @param m It should be a float unpacked matrix
-py::array_t<float> MatToArray(ncnn::Mat m);
-
-// convert an array to a ncnn::Mat
-ncnn::Mat ArrayToMat(py::array array);
-
-}  // namespace sherpa_ncnn
-
-#endif  // SHERPA_NCNN_PYTHON_CSRC_MODEL_UTIL_H_

+ 7 - 62
sherpa-ncnn/python/csrc/model.cc

@@ -22,7 +22,6 @@
 #include <string>
 
 #include "sherpa-ncnn/csrc/model.h"
-#include "sherpa-ncnn/python/csrc/mat-util.h"
 
 
 namespace sherpa_ncnn {
 namespace sherpa_ncnn {
 
 
@@ -48,6 +47,8 @@ Args:
     Path to joiner.ncnn.bin.
   num_threads:
     Number of threads to use for neural network computation.
+  tokens:
+    Path to tokens.txt
 )doc";
 )doc";
 
 
 static void PybindModelConfig(py::module *m) {
 static void PybindModelConfig(py::module *m) {
@@ -58,8 +59,8 @@ static void PybindModelConfig(py::module *m) {
                        const std::string &decoder_param,
                        const std::string &decoder_bin,
                        const std::string &joiner_param,
-                       const std::string &joiner_bin,
-                       int32_t num_threads) -> std::unique_ptr<PyClass> {
+                       const std::string &joiner_bin, int32_t num_threads,
+                       const std::string &tokens) -> std::unique_ptr<PyClass> {
              auto ans = std::make_unique<PyClass>();
              ans->encoder_param = encoder_param;
              ans->encoder_bin = encoder_bin;
@@ -67,6 +68,7 @@ static void PybindModelConfig(py::module *m) {
              ans->decoder_bin = decoder_bin;
              ans->joiner_param = joiner_param;
              ans->joiner_bin = joiner_bin;
+             ans->tokens = tokens;
 
              ans->use_vulkan_compute = false;
 
@@ -79,66 +81,9 @@ static void PybindModelConfig(py::module *m) {
            py::arg("encoder_param"), py::arg("encoder_bin"),
            py::arg("decoder_param"), py::arg("decoder_bin"),
            py::arg("joiner_param"), py::arg("joiner_bin"),
-           py::arg("num_threads"), kModelConfigInitDoc);
+           py::arg("num_threads"), py::arg("tokens"), kModelConfigInitDoc);
 }
 
-void PybindModel(py::module *m) {
-  PybindModelConfig(m);
-
-  using PyClass = Model;
-  py::class_<PyClass>(*m, "Model")
-      .def_static("create", &PyClass::Create, py::arg("config"))
-      .def(
-          "run_encoder",
-          [](PyClass &self, py::array _features,
-             const std::vector<py::array> &_states)
-              -> std::pair<py::array, std::vector<py::array>> {
-            ncnn::Mat features = ArrayToMat(_features);
-
-            std::vector<ncnn::Mat> states;
-            states.reserve(_states.size());
-            for (const auto &s : _states) {
-              states.push_back(ArrayToMat(s));
-            }
-
-            ncnn::Mat encoder_out;
-            std::vector<ncnn::Mat> _next_states;
-
-            std::tie(encoder_out, _next_states) =
-                self.RunEncoder(features, states);
-
-            std::vector<py::array> next_states;
-            next_states.reserve(_next_states.size());
-            for (const auto &m : _next_states) {
-              next_states.push_back(MatToArray(m));
-            }
-
-            return std::make_pair(MatToArray(encoder_out), next_states);
-          },
-          py::arg("features"), py::arg("states"))
-      .def(
-          "run_decoder",
-          [](PyClass &self, py::array _decoder_input) -> py::array {
-            ncnn::Mat decoder_input = ArrayToMat(_decoder_input);
-            ncnn::Mat decoder_out = self.RunDecoder(decoder_input);
-            return MatToArray(decoder_out);
-          },
-          py::arg("decoder_input"))
-      .def(
-          "run_joiner",
-          [](PyClass &self, py::array _encoder_out,
-             py::array _decoder_out) -> py::array {
-            ncnn::Mat encoder_out = ArrayToMat(_encoder_out);
-            ncnn::Mat decoder_out = ArrayToMat(_decoder_out);
-            ncnn::Mat joiner_out = self.RunJoiner(encoder_out, decoder_out);
-
-            return MatToArray(joiner_out);
-          },
-          py::arg("encoder_out"), py::arg("decoder_out"))
-      .def_property_readonly("context_size", &PyClass::ContextSize)
-      .def_property_readonly("blank_id", &PyClass::BlankId)
-      .def_property_readonly("segment", &PyClass::Segment)
-      .def_property_readonly("offset", &PyClass::Offset);
-}
+void PybindModel(py::module *m) { PybindModelConfig(m); }
 
 }  // namespace sherpa_ncnn
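
> **Editor's note.** With this change `ModelConfig` also carries the path to `tokens.txt`, so the Python wrapper no longer has to parse the symbol table itself. A sketch of the updated constructor; all paths are placeholders:

```python
# Sketch: ModelConfig now takes the tokens file as well.
# Point the placeholder paths at a real pre-trained model.
from _sherpa_ncnn import ModelConfig

model_config = ModelConfig(
    encoder_param="encoder_jit_trace-pnnx.ncnn.param",
    encoder_bin="encoder_jit_trace-pnnx.ncnn.bin",
    decoder_param="decoder_jit_trace-pnnx.ncnn.param",
    decoder_bin="decoder_jit_trace-pnnx.ncnn.bin",
    joiner_param="joiner_jit_trace-pnnx.ncnn.param",
    joiner_bin="joiner_jit_trace-pnnx.ncnn.bin",
    num_threads=4,
    tokens="tokens.txt",
)
```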

+ 104 - 0
sherpa-ncnn/python/csrc/recognizer.cc

@@ -0,0 +1,104 @@
+/**
+ * Copyright (c)  2023  Xiaomi Corporation (authors: Fangjun Kuang)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "sherpa-ncnn/python/csrc/recognizer.h"
+
+#include <memory>
+#include <string>
+
+#include "sherpa-ncnn/csrc/recognizer.h"
+
+namespace sherpa_ncnn {
+
+static constexpr const char *kDecoderConfigInitDoc = R"doc(
+Constructor for DecoderConfig.
+
+Args:
+  method:
+    Decoding method. Supported values are: greedy_search, modified_beam_search.
+  num_active_paths:
+    Used only when method is modified_beam_search. It specifies the number of
+    actives paths during beam search.
+  enable_endpoint:
+    True to enable endpoint detection. False to disable endpoint detection.
+  endpoint_config:
+    Used only when ``enable_endpoint`` is True.
+)doc";
+
+static void PybindRecognitionResult(py::module *m) {
+  using PyClass = RecognitionResult;
+  py::class_<PyClass>(*m, "RecognitionResult")
+      .def_property_readonly(
+          "text", [](PyClass &self) -> std::string { return self.text; });
+}
+
+static void PybindDecoderConfig(py::module *m) {
+  using PyClass = DecoderConfig;
+  py::class_<PyClass>(*m, "DecoderConfig")
+      .def(py::init<const std::string &, int32_t, bool,
+                    const EndpointConfig &>(),
+           py::arg("method"), py::arg("num_active_paths"),
+           py::arg("enable_endpoint"), py::arg("endpoint_config"),
+           kDecoderConfigInitDoc)
+      .def("__str__", &PyClass::ToString)
+      .def_property_readonly("method",
+                             [](const PyClass &self) { return self.method; })
+      .def_property_readonly(
+          "num_active_paths",
+          [](const PyClass &self) { return self.num_active_paths; })
+      .def_property_readonly(
+          "enable_endpoint",
+          [](const PyClass &self) { return self.enable_endpoint; })
+      .def_property_readonly("endpoint_config", [](const PyClass &self) {
+        return self.endpoint_config;
+      });
+}
+
+void PybindRecognizer(py::module *m) {
+  PybindRecognitionResult(m);
+  PybindDecoderConfig(m);
+
+  using PyClass = Recognizer;
+  py::class_<PyClass>(*m, "Recognizer")
+      .def(py::init([](const DecoderConfig &decoder_config,
+                       const ModelConfig &model_config,
+                       float sample_rate = 16000) -> std::unique_ptr<PyClass> {
+             knf::FbankOptions fbank_opts;
+             fbank_opts.frame_opts.dither = 0;
+             fbank_opts.frame_opts.snip_edges = false;
+             fbank_opts.frame_opts.samp_freq = sample_rate;
+             fbank_opts.mel_opts.num_bins = 80;
+
+             return std::make_unique<PyClass>(decoder_config, model_config,
+                                              fbank_opts);
+           }),
+           py::arg("decoder_config"), py::arg("model_config"),
+           py::arg("sample_rate") = 16000)
+      .def("accept_waveform",
+           [](PyClass &self, float sample_rate, py::array_t<float> waveform) {
+             self.AcceptWaveform(sample_rate, waveform.data(), waveform.size());
+           })
+      .def("input_finished", &PyClass::InputFinished)
+      .def("decode", &PyClass::Decode)
+      .def_property_readonly("result",
+                             [](PyClass &self) { return self.GetResult(); })
+      .def("is_endpoint", &PyClass::IsEndpoint)
+      .def("reset", &PyClass::Reset);
+}
+
+}  // namespace sherpa_ncnn
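
> **Editor's note.** Put together, the new bindings let you drive recognition from Python without the old Model/FeatureExtractor/greedy_search plumbing. A minimal sketch against the raw `_sherpa_ncnn` module, with placeholder model paths; the high-level `sherpa_ncnn.Recognizer` wrapper below does the same thing with a friendlier signature:

```python
# Sketch: the low-level _sherpa_ncnn API added by this commit.
import numpy as np
from _sherpa_ncnn import DecoderConfig, EndpointConfig, ModelConfig, Recognizer

model_config = ModelConfig(
    encoder_param="encoder.ncnn.param",  # placeholder paths
    encoder_bin="encoder.ncnn.bin",
    decoder_param="decoder.ncnn.param",
    decoder_bin="decoder.ncnn.bin",
    joiner_param="joiner.ncnn.param",
    joiner_bin="joiner.ncnn.bin",
    num_threads=4,
    tokens="tokens.txt",
)
decoder_config = DecoderConfig(
    method="greedy_search",
    num_active_paths=4,
    enable_endpoint=True,
    endpoint_config=EndpointConfig(),  # default rules
)
recognizer = Recognizer(
    decoder_config=decoder_config,
    model_config=model_config,
    sample_rate=16000,
)

recognizer.accept_waveform(16000, np.zeros(1600, dtype=np.float32))
recognizer.decode()
print(recognizer.result.text, recognizer.is_endpoint())
```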

+ 5 - 5
sherpa-ncnn/python/csrc/features.h → sherpa-ncnn/python/csrc/recognizer.h

@@ -1,5 +1,5 @@
 /**
- * Copyright (c)  2022  Xiaomi Corporation (authors: Fangjun Kuang)
+ * Copyright (c)  2023  Xiaomi Corporation (authors: Fangjun Kuang)
  *
  * See LICENSE for clarification regarding multiple authors
  *
@@ -16,15 +16,15 @@
  * limitations under the License.
  */
 
-#ifndef SHERPA_NCNN_PYTHON_CSRC_FEATURES_H_
-#define SHERPA_NCNN_PYTHON_CSRC_FEATURES_H_
+#ifndef SHERPA_NCNN_PYTHON_CSRC_RECOGNIZER_H_
+#define SHERPA_NCNN_PYTHON_CSRC_RECOGNIZER_H_
 
 #include "sherpa-ncnn/python/csrc/sherpa-ncnn.h"
 
 namespace sherpa_ncnn {
 
-void PybindFeatures(py::module *m);
+void PybindRecognizer(py::module *m);
 
 }  // namespace sherpa_ncnn
 
-#endif  // SHERPA_NCNN_PYTHON_CSRC_FEATURES_H_
+#endif  // SHERPA_NCNN_PYTHON_CSRC_RECOGNIZER_H_

+ 4 - 6
sherpa-ncnn/python/csrc/sherpa-ncnn.cc

@@ -18,20 +18,18 @@
 
 #include "sherpa-ncnn/python/csrc/sherpa-ncnn.h"
 
-#include "sherpa-ncnn/python/csrc/decode.h"
-#include "sherpa-ncnn/python/csrc/features.h"
+#include "sherpa-ncnn/python/csrc/endpoint.h"
 #include "sherpa-ncnn/python/csrc/model.h"
 #include "sherpa-ncnn/python/csrc/model.h"
+#include "sherpa-ncnn/python/csrc/recognizer.h"
 
 
 namespace sherpa_ncnn {
 namespace sherpa_ncnn {
 
 
 PYBIND11_MODULE(_sherpa_ncnn, m) {
 PYBIND11_MODULE(_sherpa_ncnn, m) {
   m.doc() = "pybind11 binding of sherpa-ncnn";
   m.doc() = "pybind11 binding of sherpa-ncnn";
 
 
+  PybindEndpoint(&m);
   PybindModel(&m);
-
-  PybindFeatures(&m);
-
-  PybindDecode(&m);
+  PybindRecognizer(&m);
 }
 
 }  // namespace sherpa_ncnn

+ 0 - 2
sherpa-ncnn/python/sherpa_ncnn/__init__.py

@@ -1,3 +1 @@
-from _sherpa_ncnn import FeatureExtractor, Model, ModelConfig, greedy_search
-
 from .recognizer import Recognizer

+ 66 - 52
sherpa-ncnn/python/sherpa_ncnn/recognizer.py

@@ -1,24 +1,19 @@
 from pathlib import Path
 
 import numpy as np
-from _sherpa_ncnn import FeatureExtractor, Model, ModelConfig, greedy_search
+from _sherpa_ncnn import (
+    DecoderConfig,
+    EndpointConfig,
+    EndpointRule,
+    ModelConfig,
+)
+from _sherpa_ncnn import Recognizer as _Recognizer
 
 
 def _assert_file_exists(f: str):
     assert Path(f).is_file(), f"{f} does not exist"
 
 
-def _read_tokens(tokens):
-    sym_table = {}
-    with open(tokens, "r", encoding="utf-8") as f:
-        for line in f:
-            sym, i = line.split()
-            sym = sym.replace("▁", " ")
-            sym_table[int(i)] = sym
-
-    return sym_table
-
-
 class Recognizer(object):
     """A class for streaming speech recognition.
 
@@ -88,6 +83,12 @@ class Recognizer(object):
         joiner_param: str,
         joiner_bin: str,
         num_threads: int = 4,
+        decoding_method: str = "greedy_search",
+        num_active_paths: int = 4,
+        enable_endpoint_detection: bool = False,
+        rule1_min_trailing_silence: float = 2.4,
+        rule2_min_trailing_silence: float = 1.2,
+        rule3_min_utterance_length: float = 20,
     ):
         """
         Please refer to
@@ -101,6 +102,7 @@ class Recognizer(object):
             columns::
 
                 symbol integer_id
+
           encoder_param:
             Path to ``encoder.ncnn.param``.
           encoder_bin:
           encoder_bin:
@@ -115,6 +117,28 @@ class Recognizer(object):
             Path to ``joiner.ncnn.bin``.
             Path to ``joiner.ncnn.bin``.
           num_threads:
           num_threads:
             Number of threads for neural network computation.
             Number of threads for neural network computation.
+          decoding_method:
+            Valid decoding methods are: greedy_search, modified_beam_search.
+          num_active_paths:
+            Used only when decoding_method is modified_beam_search. Its value
+            is ignored when decoding_method is greedy_search. It specifies
+            the maximum number of paths to use in beam search.
+          enable_endpoint_detection:
+            True to enable endpoint detection. False to disable endpoint
+            detection.
+          rule1_min_trailing_silence:
+            Used only when enable_endpoint_detection is True. If the duration
+            of trailing silence in seconds is larger than this value, we assume
+            an endpoint is detected.
+          rule2_min_trailing_silence:
+            Used only when enable_endpoint_detection is True. If we have decoded
+            something that is nonsilence and if the duration of trailing silence
+            in seconds is larger than this value, we assume an endpoint is
+            detected.
+          rule3_min_utterance_length:
+            Used only when enable_endpoint_detection is True. If the utterance
+            length in seconds is larger than this value, we assume an endpoint
+            is detected.
         """
         """
         _assert_file_exists(tokens)
         _assert_file_exists(tokens)
         _assert_file_exists(encoder_param)
         _assert_file_exists(encoder_param)
@@ -125,8 +149,10 @@ class Recognizer(object):
         _assert_file_exists(joiner_bin)
         _assert_file_exists(joiner_bin)
 
 
         assert num_threads > 0, num_threads
         assert num_threads > 0, num_threads
-
-        self.sym_table = _read_tokens(tokens)
+        assert decoding_method in (
+            "greedy_search",
+            "modified_beam_search",
+        ), decoding_method
 
 
         model_config = ModelConfig(
         model_config = ModelConfig(
             encoder_param=encoder_param,
             encoder_param=encoder_param,
@@ -136,23 +162,30 @@ class Recognizer(object):
             joiner_param=joiner_param,
             joiner_param=joiner_param,
             joiner_bin=joiner_bin,
             joiner_bin=joiner_bin,
             num_threads=num_threads,
             num_threads=num_threads,
+            tokens=tokens,
         )
 
-        self.model = Model.create(model_config)
-        self.sample_rate = 16000
-
-        self.feature_extractor = FeatureExtractor(
-            feature_dim=80,
-            sample_rate=self.sample_rate,
+        endpoint_config = EndpointConfig(
+            rule1_min_trailing_silence=rule1_min_trailing_silence,
+            rule2_min_trailing_silence=rule2_min_trailing_silence,
+            rule3_min_utterance_length=rule3_min_utterance_length,
         )
 
-        self.num_processed = 0  # number of processed feature frames so far
-        self.states = []  # model state
+        decoder_config = DecoderConfig(
+            method=decoding_method,
+            num_active_paths=num_active_paths,
+            enable_endpoint=enable_endpoint_detection,
+            endpoint_config=endpoint_config,
+        )
 
-        self.hyp = [0] * self.model.context_size  # initial hypothesis
+        # all of our current models are using 16 kHz audio samples
+        self.sample_rate = 16000
 
-        decoder_input = np.array(self.hyp, dtype=np.int32)
-        self.decoder_out = self.model.run_decoder(decoder_input)
+        self.recognizer = _Recognizer(
+            decoder_config=decoder_config,
+            model_config=model_config,
+            sample_rate=self.sample_rate,
+        )
 
     def accept_waveform(self, sample_rate: float, waveform: np.array):
         """Decode audio samples.
@@ -165,37 +198,18 @@ class Recognizer(object):
             range ``[-1, 1]``.
         """
         assert sample_rate == self.sample_rate, (sample_rate, self.sample_rate)
-        self.feature_extractor.accept_waveform(sample_rate, waveform)
-
-        self._decode()
+        self.recognizer.accept_waveform(sample_rate, waveform)
+        self.recognizer.decode()
 
     def input_finished(self):
         """Signal that no more audio samples are available."""
-        self.feature_extractor.input_finished()
-        self._decode()
+        self.recognizer.input_finished()
+        self.recognizer.decode()
 
     @property
     def text(self):
-        context_size = self.model.context_size
-        text = [self.sym_table[token] for token in self.hyp[context_size:]]
-        return "".join(text)
-
-    def _decode(self):
-        segment = self.model.segment
-        offset = self.model.offset
-
-        while self.feature_extractor.num_frames_ready - self.num_processed >= segment:
-            features = self.feature_extractor.get_frames(self.num_processed, segment)
-            self.num_processed += offset
+        return self.recognizer.result.text
 
-            encoder_out, self.states = self.model.run_encoder(
-                features=features,
-                states=self.states,
-            )
-
-            self.decoder_out, self.hyp = greedy_search(
-                model=self.model,
-                encoder_out=encoder_out,
-                decoder_out=self.decoder_out,
-                hyp=self.hyp,
-            )
+    @property
+    def is_endpoint(self):
+        return self.recognizer.is_endpoint()
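
> **Editor's note.** For completeness, an end-to-end sketch of the reworked high-level wrapper with endpoint detection. Model paths are placeholders, and the silent buffer stands in for real audio, so no text will actually be produced:

```python
# Sketch: the reworked sherpa_ncnn.Recognizer, fed chunk by chunk.
import numpy as np
import sherpa_ncnn

recognizer = sherpa_ncnn.Recognizer(
    tokens="tokens.txt",  # placeholder paths
    encoder_param="encoder.ncnn.param",
    encoder_bin="encoder.ncnn.bin",
    decoder_param="decoder.ncnn.param",
    decoder_bin="decoder.ncnn.bin",
    joiner_param="joiner.ncnn.param",
    joiner_bin="joiner.ncnn.bin",
    num_threads=4,
    decoding_method="modified_beam_search",
    enable_endpoint_detection=True,
)

samples = np.zeros(16000, dtype=np.float32)  # stand-in for 1 s of audio
for off in range(0, len(samples), 1600):  # 100 ms chunks
    recognizer.accept_waveform(recognizer.sample_rate, samples[off : off + 1600])
    if recognizer.text and recognizer.is_endpoint:
        print("segment:", recognizer.text)
recognizer.input_finished()
print("final:", recognizer.text)
```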