|
@@ -21,11 +21,8 @@
|
|
|
#include <stdlib.h>
|
|
|
|
|
|
#include "portaudio.h" // NOLINT
|
|
|
-#include "sherpa-ncnn/csrc/decode.h"
|
|
|
-#include "sherpa-ncnn/csrc/features.h"
|
|
|
#include "sherpa-ncnn/csrc/microphone.h"
|
|
|
-#include "sherpa-ncnn/csrc/model.h"
|
|
|
-#include "sherpa-ncnn/csrc/symbol-table.h"
|
|
|
+#include "sherpa-ncnn/csrc/recognizer.h"
|
|
|
|
|
|
bool stop = false;
|
|
|
|
|
@@ -34,14 +31,14 @@ static int RecordCallback(const void *input_buffer, void * /*output_buffer*/,
|
|
|
const PaStreamCallbackTimeInfo * /*time_info*/,
|
|
|
PaStreamCallbackFlags /*status_flags*/,
|
|
|
void *user_data) {
|
|
|
- auto feature_extractor =
|
|
|
- reinterpret_cast<sherpa_ncnn::FeatureExtractor *>(user_data);
|
|
|
+ auto recognizer = reinterpret_cast<sherpa_ncnn::Recognizer *>(user_data);
|
|
|
|
|
|
- feature_extractor->AcceptWaveform(
|
|
|
+ recognizer->AcceptWaveform(
|
|
|
16000, reinterpret_cast<const float *>(input_buffer), frames_per_buffer);
|
|
|
|
|
|
return stop ? paComplete : paContinue;
|
|
|
}
|
|
|
+
|
|
|
static void Handler(int sig) {
|
|
|
stop = true;
|
|
|
fprintf(stderr, "\nexiting...\n");
|
|
@@ -71,44 +68,42 @@ https://huggingface.co/csukuangfj/sherpa-ncnn-2022-09-05
|
|
|
}
|
|
|
signal(SIGINT, Handler);
|
|
|
|
|
|
- sherpa_ncnn::ModelConfig config;
|
|
|
-
|
|
|
- config.tokens = argv[1];
|
|
|
- config.encoder_param = argv[2];
|
|
|
- config.encoder_bin = argv[3];
|
|
|
- config.decoder_param = argv[4];
|
|
|
- config.decoder_bin = argv[5];
|
|
|
- config.joiner_param = argv[6];
|
|
|
- config.joiner_bin = argv[7];
|
|
|
-
|
|
|
- int32_t num_threads = 4;
|
|
|
- if (argc == 9) {
|
|
|
+ sherpa_ncnn::ModelConfig model_conf;
|
|
|
+ model_conf.tokens = argv[1];
|
|
|
+ model_conf.encoder_param = argv[2];
|
|
|
+ model_conf.encoder_bin = argv[3];
|
|
|
+ model_conf.decoder_param = argv[4];
|
|
|
+ model_conf.decoder_bin = argv[5];
|
|
|
+ model_conf.joiner_param = argv[6];
|
|
|
+ model_conf.joiner_bin = argv[7];
|
|
|
+ int num_threads = 4;
|
|
|
+ if (argc >= 9 && atoi(argv[8]) > 0) {
|
|
|
num_threads = atoi(argv[8]);
|
|
|
}
|
|
|
-
|
|
|
- config.encoder_opt.num_threads = num_threads;
|
|
|
- config.decoder_opt.num_threads = num_threads;
|
|
|
- config.joiner_opt.num_threads = num_threads;
|
|
|
-
|
|
|
- sherpa_ncnn::SymbolTable sym(config.tokens);
|
|
|
- fprintf(stderr, "%s\n", config.ToString().c_str());
|
|
|
-
|
|
|
- auto model = sherpa_ncnn::Model::Create(config);
|
|
|
- if (!model) {
|
|
|
- fprintf(stderr, "Failed to create a model\n");
|
|
|
- exit(EXIT_FAILURE);
|
|
|
+ model_conf.encoder_opt.num_threads = num_threads;
|
|
|
+ model_conf.decoder_opt.num_threads = num_threads;
|
|
|
+ model_conf.joiner_opt.num_threads = num_threads;
|
|
|
+
|
|
|
+ fprintf(stderr, "%s\n", model_conf.ToString().c_str());
|
|
|
+
|
|
|
+ const float expected_sampling_rate = 16000;
|
|
|
+ sherpa_ncnn::DecoderConfig decoder_conf;
|
|
|
+ if (argc == 10) {
|
|
|
+ std::string method = argv[9];
|
|
|
+ if (method.compare("greed_search") ||
|
|
|
+ method.compare("modified_beam_search")) {
|
|
|
+ decoder_conf.method = method;
|
|
|
+ }
|
|
|
}
|
|
|
-
|
|
|
- float sample_rate = 16000;
|
|
|
- sherpa_ncnn::Microphone mic;
|
|
|
-
|
|
|
knf::FbankOptions fbank_opts;
|
|
|
fbank_opts.frame_opts.dither = 0;
|
|
|
fbank_opts.frame_opts.snip_edges = false;
|
|
|
- fbank_opts.frame_opts.samp_freq = sample_rate;
|
|
|
+ fbank_opts.frame_opts.samp_freq = expected_sampling_rate;
|
|
|
fbank_opts.mel_opts.num_bins = 80;
|
|
|
|
|
|
- sherpa_ncnn::FeatureExtractor feature_extractor(fbank_opts);
|
|
|
+ sherpa_ncnn::Recognizer recognizer(decoder_conf, model_conf, fbank_opts);
|
|
|
+
|
|
|
+ sherpa_ncnn::Microphone mic;
|
|
|
|
|
|
PaDeviceIndex num_devices = Pa_GetDeviceCount();
|
|
|
fprintf(stderr, "Num devices: %d\n", num_devices);
|
|
@@ -131,6 +126,7 @@ https://huggingface.co/csukuangfj/sherpa-ncnn-2022-09-05
|
|
|
|
|
|
param.suggestedLatency = info->defaultLowInputLatency;
|
|
|
param.hostApiSpecificStreamInfo = nullptr;
|
|
|
+ const float sample_rate = 16000;
|
|
|
|
|
|
PaStream *stream;
|
|
|
PaError err =
|
|
@@ -139,7 +135,7 @@ https://huggingface.co/csukuangfj/sherpa-ncnn-2022-09-05
|
|
|
0, // frames per buffer
|
|
|
paClipOff, // we won't output out of range samples
|
|
|
// so don't bother clipping them
|
|
|
- RecordCallback, &feature_extractor);
|
|
|
+ RecordCallback, &recognizer);
|
|
|
if (err != paNoError) {
|
|
|
fprintf(stderr, "portaudio error: %s\n", Pa_GetErrorText(err));
|
|
|
exit(EXIT_FAILURE);
|
|
@@ -153,47 +149,12 @@ https://huggingface.co/csukuangfj/sherpa-ncnn-2022-09-05
|
|
|
exit(EXIT_FAILURE);
|
|
|
}
|
|
|
|
|
|
- int32_t segment = model->Segment();
|
|
|
- int32_t offset = model->Offset();
|
|
|
-
|
|
|
- int32_t context_size = model->ContextSize();
|
|
|
- int32_t blank_id = model->BlankId();
|
|
|
-
|
|
|
- std::vector<int32_t> hyp(context_size, blank_id);
|
|
|
-
|
|
|
- ncnn::Mat decoder_input(context_size);
|
|
|
- for (int32_t i = 0; i != context_size; ++i) {
|
|
|
- static_cast<int32_t *>(decoder_input)[i] = blank_id;
|
|
|
- }
|
|
|
-
|
|
|
- ncnn::Mat decoder_out = model->RunDecoder(decoder_input);
|
|
|
-
|
|
|
- ncnn::Mat hx;
|
|
|
- ncnn::Mat cx;
|
|
|
-
|
|
|
- int32_t num_tokens = hyp.size();
|
|
|
- int32_t num_processed = 0;
|
|
|
-
|
|
|
- std::vector<ncnn::Mat> states;
|
|
|
- ncnn::Mat encoder_out;
|
|
|
-
|
|
|
+ int num_tokens = 0;
|
|
|
while (!stop) {
|
|
|
- while (feature_extractor.NumFramesReady() - num_processed >= segment) {
|
|
|
- ncnn::Mat features = feature_extractor.GetFrames(num_processed, segment);
|
|
|
- num_processed += offset;
|
|
|
-
|
|
|
- std::tie(encoder_out, states) = model->RunEncoder(features, states);
|
|
|
-
|
|
|
- GreedySearch(model.get(), encoder_out, &decoder_out, &hyp);
|
|
|
- }
|
|
|
-
|
|
|
- if (hyp.size() != num_tokens) {
|
|
|
- num_tokens = hyp.size();
|
|
|
- std::string text;
|
|
|
- for (int32_t i = context_size; i != hyp.size(); ++i) {
|
|
|
- text += sym[hyp[i]];
|
|
|
- }
|
|
|
- fprintf(stderr, "%s\n", text.c_str());
|
|
|
+ recognizer.Decode();
|
|
|
+ auto result = recognizer.GetResult();
|
|
|
+ if (result.text.size() != num_tokens) {
|
|
|
+ fprintf(stderr, "%s\n", result.text.c_str());
|
|
|
}
|
|
|
|
|
|
Pa_Sleep(20); // sleep for 20ms
|