@@ -23,7 +23,7 @@
 #include "net.h" // NOLINT
 #include "sherpa-ncnn/csrc/decode.h"
 #include "sherpa-ncnn/csrc/features.h"
-#include "sherpa-ncnn/csrc/lstm-model.h"
+#include "sherpa-ncnn/csrc/model.h"
 #include "sherpa-ncnn/csrc/symbol-table.h"
 #include "sherpa-ncnn/csrc/wave-reader.h"
 
@@ -89,29 +89,35 @@ https://huggingface.co/csukuangfj/sherpa-ncnn-2022-09-05
     return 0;
   }
 
+  sherpa_ncnn::ModelConfig config;
+
   std::string tokens = argv[1];
-  std::string encoder_param = argv[2];
-  std::string encoder_bin = argv[3];
-  std::string decoder_param = argv[4];
-  std::string decoder_bin = argv[5];
-  std::string joiner_param = argv[6];
-  std::string joiner_bin = argv[7];
+
+  config.encoder_param = argv[2];
+  config.encoder_bin = argv[3];
+  config.decoder_param = argv[4];
+  config.decoder_bin = argv[5];
+  config.joiner_param = argv[6];
+  config.joiner_bin = argv[7];
+
   std::string wav_filename = argv[8];
 
-  int32_t num_threads = 4;
+  config.num_threads = 4;
   if (argc == 10) {
-    num_threads = atoi(argv[9]);
+    config.num_threads = atoi(argv[9]);
   }
 
   float expected_sampling_rate = 16000;
 
   sherpa_ncnn::SymbolTable sym(tokens);
 
-  std::cout << "number of threads: " << num_threads << "\n";
+  std::cout << config.ToString() << "\n";
 
-  sherpa_ncnn::LstmModel model(encoder_param, encoder_bin, decoder_param,
-                               decoder_bin, joiner_param, joiner_bin,
-                               num_threads);
+  auto model = sherpa_ncnn::Model::Create(config);
+  if (!model) {
+    std::cout << "Failed to create a model\n";
+    exit(EXIT_FAILURE);
+  }
 
   std::vector<float> samples =
       sherpa_ncnn::ReadWave(wav_filename, expected_sampling_rate);
@@ -132,11 +138,11 @@ https://huggingface.co/csukuangfj/sherpa-ncnn-2022-09-05
 
   feature_extractor.InputFinished();
 
-  int32_t segment = 9;
-  int32_t offset = 4;
+  int32_t segment = model->Segment();
+  int32_t offset = model->Offset();
 
-  int32_t context_size = model.ContextSize();
-  int32_t blank_id = model.BlankId();
+  int32_t context_size = model->ContextSize();
+  int32_t blank_id = model->BlankId();
 
   std::vector<int32_t> hyp(context_size, blank_id);
 
@@ -145,19 +151,19 @@ https://huggingface.co/csukuangfj/sherpa-ncnn-2022-09-05
     static_cast<int32_t *>(decoder_input)[i] = blank_id;
   }
 
-  ncnn::Mat decoder_out = model.RunDecoder(decoder_input);
+  ncnn::Mat decoder_out = model->RunDecoder(decoder_input);
 
-  ncnn::Mat hx;
-  ncnn::Mat cx;
+  std::vector<ncnn::Mat> states;
+  ncnn::Mat encoder_out;
 
   int32_t num_processed = 0;
   while (feature_extractor.NumFramesReady() - num_processed >= segment) {
     ncnn::Mat features = feature_extractor.GetFrames(num_processed, segment);
     num_processed += offset;
 
-    ncnn::Mat encoder_out = model.RunEncoder(features, &hx, &cx);
+    std::tie(encoder_out, states) = model->RunEncoder(features, states);
 
-    GreedySearch(model, encoder_out, &decoder_out, &hyp);
+    GreedySearch(model.get(), encoder_out, &decoder_out, &hyp);
   }
 
   std::string text;
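
For reference, a rough sketch of the `Model` interface that the call sites above imply. This is only an inference from the diff, not the actual contents of `sherpa-ncnn/csrc/model.h`: the member names (`Create`, `RunEncoder`, `RunDecoder`, `Segment`, `Offset`, `ContextSize`, `BlankId`, `ModelConfig::ToString`) come from the changed lines, but the exact signatures, return types, and virtual/abstract structure are assumptions.

```cpp
// Hypothetical sketch inferred from the diff above; the real header may differ.
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "net.h"  // NOLINT  (for ncnn::Mat)

namespace sherpa_ncnn {

struct ModelConfig {
  std::string encoder_param;
  std::string encoder_bin;
  std::string decoder_param;
  std::string decoder_bin;
  std::string joiner_param;
  std::string joiner_bin;
  int32_t num_threads = 4;

  std::string ToString() const;  // used for logging in the example
};

class Model {
 public:
  virtual ~Model() = default;

  // Factory selecting a concrete model from the config; the example checks
  // `if (!model)`, so a null return presumably signals an unsupported config.
  static std::unique_ptr<Model> Create(const ModelConfig &config);

  // Runs the encoder on one chunk of features; `states` carries the recurrent
  // state across chunks (empty on the first call), matching the
  // `std::tie(encoder_out, states) = model->RunEncoder(features, states)` call.
  virtual std::pair<ncnn::Mat, std::vector<ncnn::Mat>> RunEncoder(
      ncnn::Mat &features, const std::vector<ncnn::Mat> &states) = 0;

  virtual ncnn::Mat RunDecoder(ncnn::Mat &decoder_input) = 0;

  // Chunk geometry and decoding constants queried by the example, replacing
  // the hard-coded segment/offset values of the LSTM-only code path.
  virtual int32_t Segment() const = 0;
  virtual int32_t Offset() const = 0;
  virtual int32_t ContextSize() const = 0;
  virtual int32_t BlankId() const = 0;
};

}  // namespace sherpa_ncnn
```

Because `Create` hands back an owning smart pointer, the example passes `model.get()` to `GreedySearch`, which keeps the decoding helpers working with a plain `Model *`.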