
Update python example for streaming recognition. (#136)

Also, we no longer require wave_file.sample_rate == recognizer.sample_rate.
If they are different, we will do resampling inside sherpa-ncnn.
Fangjun Kuang 2 years ago
parent
commit
48f7834eb3
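
Below is a minimal usage sketch of what this change allows, assuming the same pre-trained model as in the example and a hypothetical 8 kHz mono, 16-bit recording at ./8k.wav; the constructor arguments beyond tokens and encoder_param follow the model's file-naming pattern and are assumptions here, not copied from this commit. The wave file's own sample rate is passed to accept_waveform, and resampling to the model's rate happens inside sherpa-ncnn:

    import wave

    import numpy as np
    import sherpa_ncnn

    d = "./sherpa-ncnn-conv-emformer-transducer-2022-12-06"
    recognizer = sherpa_ncnn.Recognizer(
        tokens=f"{d}/tokens.txt",
        encoder_param=f"{d}/encoder_jit_trace-pnnx.ncnn.param",
        encoder_bin=f"{d}/encoder_jit_trace-pnnx.ncnn.bin",
        decoder_param=f"{d}/decoder_jit_trace-pnnx.ncnn.param",
        decoder_bin=f"{d}/decoder_jit_trace-pnnx.ncnn.bin",
        joiner_param=f"{d}/joiner_jit_trace-pnnx.ncnn.param",
        joiner_bin=f"{d}/joiner_jit_trace-pnnx.ncnn.bin",
        num_threads=4,
    )

    # Hypothetical 8 kHz file; its rate differs from recognizer.sample_rate
    # (16 kHz for this model), which this commit makes acceptable.
    with wave.open("./8k.wav") as f:
        wave_file_sample_rate = f.getframerate()
        samples_int16 = np.frombuffer(f.readframes(f.getnframes()), dtype=np.int16)
        samples_float32 = samples_int16.astype(np.float32) / 32768

    # Pass the file's own sample rate; sherpa-ncnn resamples internally.
    recognizer.accept_waveform(wave_file_sample_rate, samples_float32)
    tail_paddings = np.zeros(int(wave_file_sample_rate * 0.5), dtype=np.float32)
    recognizer.accept_waveform(wave_file_sample_rate, tail_paddings)
    recognizer.input_finished()
    print(recognizer.text)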

+ 24 - 9
python-api-examples/decode-file.py

@@ -10,6 +10,7 @@ to install sherpa-ncnn and to download the pre-trained models
 used in this file.
 """
 
+import time
 import wave
 
 import numpy as np
@@ -17,6 +18,8 @@ import sherpa_ncnn
 
 
 def main():
+    # Please refer to https://k2-fsa.github.io/sherpa/ncnn/index.html
+    # to download the model files
     recognizer = sherpa_ncnn.Recognizer(
         tokens="./sherpa-ncnn-conv-emformer-transducer-2022-12-06/tokens.txt",
         encoder_param="./sherpa-ncnn-conv-emformer-transducer-2022-12-06/encoder_jit_trace-pnnx.ncnn.param",
@@ -30,10 +33,9 @@ def main():
 
     filename = "./sherpa-ncnn-conv-emformer-transducer-2022-12-06/test_wavs/1.wav"
     with wave.open(filename) as f:
-        assert f.getframerate() == recognizer.sample_rate, (
-            f.getframerate(),
-            recognizer.sample_rate,
-        )
+        # Note: If wave_file_sample_rate is different from
+        # recognizer.sample_rate, we will do resampling inside sherpa-ncnn
+        wave_file_sample_rate = f.getframerate()
         assert f.getnchannels() == 1, f.getnchannels()
         assert f.getsampwidth() == 2, f.getsampwidth()  # it is in bytes
         num_samples = f.getnframes()
@@ -43,14 +45,27 @@ def main():
 
         samples_float32 = samples_float32 / 32768
 
-    recognizer.accept_waveform(recognizer.sample_rate, samples_float32)
+    # simulate streaming
+    chunk_size = int(0.1 * wave_file_sample_rate)  # 0.1 seconds
+    start = 0
+    while start < samples_float32.shape[0]:
+        end = start + chunk_size
+        end = min(end, samples_float32.shape[0])
+        recognizer.accept_waveform(wave_file_sample_rate, samples_float32[start:end])
+        start = end
+        text = recognizer.text
+        if text:
+            print(text)
 
-    tail_paddings = np.zeros(int(recognizer.sample_rate * 0.5), dtype=np.float32)
-    recognizer.accept_waveform(recognizer.sample_rate, tail_paddings)
+        # simulate streaming by sleeping
+        time.sleep(0.1)
 
+    tail_paddings = np.zeros(int(wave_file_sample_rate * 0.5), dtype=np.float32)
+    recognizer.accept_waveform(wave_file_sample_rate, tail_paddings)
     recognizer.input_finished()
-
-    print(recognizer.text)
+    text = recognizer.text
+    if text:
+        print(text)
 
 
 if __name__ == "__main__":

+ 7 - 3
sherpa-ncnn/python/sherpa_ncnn/recognizer.py

@@ -91,6 +91,7 @@ class Recognizer(object):
         rule2_min_trailing_silence: float = 1.2,
         rule3_min_utterance_length: int = 20,
         max_feature_vectors: int = -1,
+        model_sample_rate: int = 16000,
     ):
         """
         Please refer to
@@ -144,6 +145,8 @@ class Recognizer(object):
           max_feature_vectors:
             It specifies the number of feature frames to cache. Use -1
             to cache all processed frames
+          model_sample_rate:
+            Sample rate expected by the model
         """
         _assert_file_exists(tokens)
         _assert_file_exists(encoder_param)
@@ -159,7 +162,7 @@ class Recognizer(object):
             "modified_beam_search",
         ), decoding_method
         feat_config = FeatureExtractorConfig(
-            sampling_rate=16000,
+            sampling_rate=model_sample_rate,
             feature_dim=80,
             max_feature_vectors=-1,
         )
@@ -204,12 +207,13 @@ class Recognizer(object):
 
         Args:
           sample_rate:
-            Sample rate of the input audio samples. It should be 16000.
+            Sample rate of the input audio samples. You must use the same
+            value across different calls to `accept_waveform`! If it
+            is different from self.sample_rate, we will do resampling inside.
           waveform:
             A 1-D float32 array containing audio samples in the
             range ``[-1, 1]``.
         """
-        assert sample_rate == self.sample_rate, (sample_rate, self.sample_rate)
         self.stream.accept_waveform(sample_rate, waveform)
         self._decode()
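
As a small hedged sketch of the contract described above (not code from this commit): every chunk fed to accept_waveform must use one and the same input sample rate, which no longer has to equal recognizer.sample_rate, and a model expecting something other than 16000 Hz can be declared via the new model_sample_rate constructor argument. The helper below and its name are illustrative only:

    import numpy as np
    import sherpa_ncnn


    def stream_in_chunks(
        recognizer: "sherpa_ncnn.Recognizer",
        samples: np.ndarray,
        input_sample_rate: int,
        chunk_seconds: float = 0.1,
    ) -> str:
        """Feed float32 samples to the recognizer in fixed-size chunks.

        The same input_sample_rate is used for every call to
        accept_waveform; if it differs from recognizer.sample_rate,
        sherpa-ncnn resamples internally.
        """
        chunk = int(chunk_seconds * input_sample_rate)
        for start in range(0, samples.shape[0], chunk):
            recognizer.accept_waveform(input_sample_rate, samples[start : start + chunk])

        # Flush with 0.5 seconds of trailing silence, as decode-file.py does.
        tail = np.zeros(int(input_sample_rate * 0.5), dtype=np.float32)
        recognizer.accept_waveform(input_sample_rate, tail)
        recognizer.input_finished()
        return recognizer.text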