|
@@ -1,24 +1,19 @@
|
|
|
from pathlib import Path
|
|
|
|
|
|
import numpy as np
|
|
|
-from _sherpa_ncnn import FeatureExtractor, Model, ModelConfig, greedy_search
|
|
|
+from _sherpa_ncnn import (
|
|
|
+ DecoderConfig,
|
|
|
+ EndpointConfig,
|
|
|
+ EndpointRule,
|
|
|
+ ModelConfig,
|
|
|
+)
|
|
|
+from _sherpa_ncnn import Recognizer as _Recognizer
|
|
|
|
|
|
|
|
|
def _assert_file_exists(f: str):
|
|
|
assert Path(f).is_file(), f"{f} does not exist"
|
|
|
|
|
|
|
|
|
-def _read_tokens(tokens):
|
|
|
- sym_table = {}
|
|
|
- with open(tokens, "r", encoding="utf-8") as f:
|
|
|
- for line in f:
|
|
|
- sym, i = line.split()
|
|
|
- sym = sym.replace("▁", " ")
|
|
|
- sym_table[int(i)] = sym
|
|
|
-
|
|
|
- return sym_table
|
|
|
-
|
|
|
-
|
|
|
class Recognizer(object):
|
|
|
"""A class for streaming speech recognition.
|
|
|
|
|
@@ -88,6 +83,12 @@ class Recognizer(object):
|
|
|
joiner_param: str,
|
|
|
joiner_bin: str,
|
|
|
num_threads: int = 4,
|
|
|
+ decoding_method: str = "greedy_search",
|
|
|
+ num_active_paths: int = 4,
|
|
|
+ enable_endpoint_detection: bool = False,
|
|
|
+ rule1_min_trailing_silence: float = 2.4,
|
|
|
+ rule2_min_trailing_silence: float = 1.2,
|
|
|
+ rule3_min_utterance_length: float = 20,
|
|
|
):
|
|
|
"""
|
|
|
Please refer to
|
|
@@ -101,6 +102,7 @@ class Recognizer(object):
|
|
|
columns::
|
|
|
|
|
|
symbol integer_id
|
|
|
+
|
|
|
encoder_param:
|
|
|
Path to ``encoder.ncnn.param``.
|
|
|
encoder_bin:
|
|
@@ -115,6 +117,28 @@ class Recognizer(object):
|
|
|
Path to ``joiner.ncnn.bin``.
|
|
|
num_threads:
|
|
|
Number of threads for neural network computation.
|
|
|
+ decoding_method:
|
|
|
+ Valid decoding methods are: greedy_search, modified_beam_search.
|
|
|
+ num_active_paths:
|
|
|
+ Used only when decoding_method is modified_beam_search. Its value
|
|
|
+ is ignored when decoding_method is greedy_search. It specifies
|
|
|
+ the maximum number of paths to use in beam search.
|
|
|
+ enable_endpoint_detection:
|
|
|
+ True to enable endpoint detection. False to disable endpoint
|
|
|
+ detection.
|
|
|
+ rule1_min_trailing_silence:
|
|
|
+ Used only when enable_endpoint_detection is True. If the duration
|
|
|
+ of trailing silence in seconds is larger than this value, we assume
|
|
|
+ an endpoint is detected.
|
|
|
+ rule2_min_trailing_silence:
|
|
|
+ Used only when enable_endpoint_detection is True. If we have decoded
|
|
|
+ something that is nonsilence and if the duration of trailing silence
|
|
|
+ in seconds is larger than this value, we assume an endpoint is
|
|
|
+ detected.
|
|
|
+ rule3_min_utterance_length:
|
|
|
+ Used only when enable_endpoint_detection is True. If the utterance
|
|
|
+ length in seconds is larger than this value, we assume an endpoint
|
|
|
+ is detected.
|
|
|
"""
|
|
|
_assert_file_exists(tokens)
|
|
|
_assert_file_exists(encoder_param)
|
|
@@ -125,8 +149,10 @@ class Recognizer(object):
|
|
|
_assert_file_exists(joiner_bin)
|
|
|
|
|
|
assert num_threads > 0, num_threads
|
|
|
-
|
|
|
- self.sym_table = _read_tokens(tokens)
|
|
|
+ assert decoding_method in (
|
|
|
+ "greedy_search",
|
|
|
+ "modified_beam_search",
|
|
|
+ ), decoding_method
|
|
|
|
|
|
model_config = ModelConfig(
|
|
|
encoder_param=encoder_param,
|
|
@@ -136,23 +162,30 @@ class Recognizer(object):
|
|
|
joiner_param=joiner_param,
|
|
|
joiner_bin=joiner_bin,
|
|
|
num_threads=num_threads,
|
|
|
+ tokens=tokens,
|
|
|
)
|
|
|
|
|
|
- self.model = Model.create(model_config)
|
|
|
- self.sample_rate = 16000
|
|
|
-
|
|
|
- self.feature_extractor = FeatureExtractor(
|
|
|
- feature_dim=80,
|
|
|
- sample_rate=self.sample_rate,
|
|
|
+ endpoint_config = EndpointConfig(
|
|
|
+ rule1_min_trailing_silence=rule1_min_trailing_silence,
|
|
|
+ rule2_min_trailing_silence=rule2_min_trailing_silence,
|
|
|
+ rule3_min_utterance_length=rule3_min_utterance_length,
|
|
|
)
|
|
|
|
|
|
- self.num_processed = 0 # number of processed feature frames so far
|
|
|
- self.states = [] # model state
|
|
|
+ decoder_config = DecoderConfig(
|
|
|
+ method=decoding_method,
|
|
|
+ num_active_paths=num_active_paths,
|
|
|
+ enable_endpoint=enable_endpoint_detection,
|
|
|
+ endpoint_config=endpoint_config,
|
|
|
+ )
|
|
|
|
|
|
- self.hyp = [0] * self.model.context_size # initial hypothesis
|
|
|
+ # all of our current models are using 16 kHz audio samples
|
|
|
+ self.sample_rate = 16000
|
|
|
|
|
|
- decoder_input = np.array(self.hyp, dtype=np.int32)
|
|
|
- self.decoder_out = self.model.run_decoder(decoder_input)
|
|
|
+ self.recognizer = _Recognizer(
|
|
|
+ decoder_config=decoder_config,
|
|
|
+ model_config=model_config,
|
|
|
+ sample_rate=self.sample_rate,
|
|
|
+ )
|
|
|
|
|
|
def accept_waveform(self, sample_rate: float, waveform: np.array):
|
|
|
"""Decode audio samples.
|
|
@@ -165,37 +198,18 @@ class Recognizer(object):
|
|
|
range ``[-1, 1]``.
|
|
|
"""
|
|
|
assert sample_rate == self.sample_rate, (sample_rate, self.sample_rate)
|
|
|
- self.feature_extractor.accept_waveform(sample_rate, waveform)
|
|
|
-
|
|
|
- self._decode()
|
|
|
+ self.recognizer.accept_waveform(sample_rate, waveform)
|
|
|
+ self.recognizer.decode()
|
|
|
|
|
|
def input_finished(self):
|
|
|
"""Signal that no more audio samples are available."""
|
|
|
- self.feature_extractor.input_finished()
|
|
|
- self._decode()
|
|
|
+ self.recognizer.input_finished()
|
|
|
+ self.recognizer.decode()
|
|
|
|
|
|
@property
|
|
|
def text(self):
|
|
|
- context_size = self.model.context_size
|
|
|
- text = [self.sym_table[token] for token in self.hyp[context_size:]]
|
|
|
- return "".join(text)
|
|
|
-
|
|
|
- def _decode(self):
|
|
|
- segment = self.model.segment
|
|
|
- offset = self.model.offset
|
|
|
-
|
|
|
- while self.feature_extractor.num_frames_ready - self.num_processed >= segment:
|
|
|
- features = self.feature_extractor.get_frames(self.num_processed, segment)
|
|
|
- self.num_processed += offset
|
|
|
+ return self.recognizer.result.text
|
|
|
|
|
|
- encoder_out, self.states = self.model.run_encoder(
|
|
|
- features=features,
|
|
|
- states=self.states,
|
|
|
- )
|
|
|
-
|
|
|
- self.decoder_out, self.hyp = greedy_search(
|
|
|
- model=self.model,
|
|
|
- encoder_out=encoder_out,
|
|
|
- decoder_out=self.decoder_out,
|
|
|
- hyp=self.hyp,
|
|
|
- )
|
|
|
+ @property
|
|
|
+ def is_endpoint(self):
|
|
|
+ return self.recognizer.is_endpoint()
|