|
@@ -18,7 +18,8 @@
|
|
|
*/
|
|
|
|
|
|
#include "sherpa-ncnn/csrc/recognizer.h"
|
|
|
-
|
|
|
+#include <iostream>
|
|
|
+#include <fstream>
|
|
|
#include <memory>
|
|
|
#include <string>
|
|
|
#include <utility>
|
|
@@ -75,8 +76,12 @@ std::string RecognizerConfig::ToString() const {
|
|
|
os << "feat_config=" << feat_config.ToString() << ", ";
|
|
|
os << "model_config=" << model_config.ToString() << ", ";
|
|
|
os << "decoder_config=" << decoder_config.ToString() << ", ";
|
|
|
+ os << "max_active_paths=" << max_active_paths << ", ";
|
|
|
os << "endpoint_config=" << endpoint_config.ToString() << ", ";
|
|
|
- os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ")";
|
|
|
+ os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ", ";
|
|
|
+ os << "hotwords_file=\"" << hotwords_file << "\", ";
|
|
|
+ os << "hotwrods_score=" << hotwords_score << ", ";
|
|
|
+ os << "decoding_method=\"" << decoding_method << "\")";
|
|
|
|
|
|
return os.str();
|
|
|
}
|
|
@@ -93,6 +98,30 @@ class Recognizer::Impl {
|
|
|
} else if (config.decoder_config.method == "modified_beam_search") {
|
|
|
decoder_ = std::make_unique<ModifiedBeamSearchDecoder>(
|
|
|
model_.get(), config.decoder_config.num_active_paths);
|
|
|
+ std::vector<int32_t> tmp;
|
|
|
+ /*each line in hotwords file is a string which is segmented by space*/
|
|
|
+ std::ifstream file(config_.hotwords_file);
|
|
|
+ if (file) {
|
|
|
+ std::string line;
|
|
|
+ std::string word;
|
|
|
+ while (std::getline(file, line)) {
|
|
|
+ std::istringstream iss(line);
|
|
|
+ while(iss >> word){
|
|
|
+ if (sym_.contains(word)) {
|
|
|
+ int number = sym_[word];
|
|
|
+ tmp.push_back(number);
|
|
|
+ } else {
|
|
|
+ NCNN_LOGE("hotword %s can't find id. line: %s", word.c_str(), line.c_str());
|
|
|
+ exit(-1);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ hotwords_.push_back(tmp);
|
|
|
+ tmp.clear();
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ NCNN_LOGE("open file failed: %s, hotwords will not be used",
|
|
|
+ config_.hotwords_file.c_str());
|
|
|
+ }
|
|
|
} else {
|
|
|
NCNN_LOGE("Unsupported method: %s", config.decoder_config.method.c_str());
|
|
|
exit(-1);
|
|
@@ -110,6 +139,30 @@ class Recognizer::Impl {
|
|
|
} else if (config.decoder_config.method == "modified_beam_search") {
|
|
|
decoder_ = std::make_unique<ModifiedBeamSearchDecoder>(
|
|
|
model_.get(), config.decoder_config.num_active_paths);
|
|
|
+ std::vector<int32_t> tmp;
|
|
|
+ /*each line in hotwords file is a string which is segmented by space*/
|
|
|
+ std::ifstream file(config_.hotwords_file);
|
|
|
+ if (file) {
|
|
|
+ std::string line;
|
|
|
+ std::string word;
|
|
|
+ while (std::getline(file, line)) {
|
|
|
+ std::istringstream iss(line);
|
|
|
+ while(iss >> word){
|
|
|
+ if (sym_.contains(word)) {
|
|
|
+ int number = sym_[word];
|
|
|
+ tmp.push_back(number);
|
|
|
+ } else {
|
|
|
+ NCNN_LOGE("hotword %s can't find id. line: %s", word.c_str(), line.c_str());
|
|
|
+ exit(-1);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ hotwords_.push_back(tmp);
|
|
|
+ tmp.clear();
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ NCNN_LOGE("open file failed: %s, hotwords will not be used",
|
|
|
+ config_.hotwords_file.c_str());
|
|
|
+ }
|
|
|
} else {
|
|
|
NCNN_LOGE("Unsupported method: %s", config.decoder_config.method.c_str());
|
|
|
exit(-1);
|
|
@@ -118,10 +171,29 @@ class Recognizer::Impl {
|
|
|
#endif
|
|
|
|
|
|
std::unique_ptr<Stream> CreateStream() const {
|
|
|
- auto stream = std::make_unique<Stream>(config_.feat_config);
|
|
|
- stream->SetResult(decoder_->GetEmptyResult());
|
|
|
- stream->SetStates(model_->GetEncoderInitStates());
|
|
|
- return stream;
|
|
|
+ if(hotwords_.empty()) {
|
|
|
+ auto stream = std::make_unique<Stream>(config_.feat_config);
|
|
|
+ stream->SetResult(decoder_->GetEmptyResult());
|
|
|
+ stream->SetStates(model_->GetEncoderInitStates());
|
|
|
+ return stream;
|
|
|
+ } else {
|
|
|
+ auto r = decoder_->GetEmptyResult();
|
|
|
+ auto context_graph =
|
|
|
+ std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score);
|
|
|
+ auto stream =
|
|
|
+ std::make_unique<Stream>(config_.feat_config, context_graph);
|
|
|
+ if (config_.decoder_config.method == "modified_beam_search" &&
|
|
|
+ nullptr != stream->GetContextGraph()) {
|
|
|
+ std::cout<<"create contexts stream"<<std::endl;
|
|
|
+ // r.hyps has only one element.
|
|
|
+ for (auto it = r.hyps.begin(); it != r.hyps.end(); ++it) {
|
|
|
+ it->second.context_state = stream->GetContextGraph()->Root();
|
|
|
+ }
|
|
|
+ }
|
|
|
+ stream->SetResult(r);
|
|
|
+ stream->SetStates(model_->GetEncoderInitStates());
|
|
|
+ return stream;
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
bool IsReady(Stream *s) const {
|
|
@@ -131,16 +203,24 @@ class Recognizer::Impl {
|
|
|
void DecodeStream(Stream *s) const {
|
|
|
int32_t segment = model_->Segment();
|
|
|
int32_t offset = model_->Offset();
|
|
|
+ bool has_context_graph = false;
|
|
|
|
|
|
+ if (!has_context_graph && s->GetContextGraph()) {
|
|
|
+ has_context_graph = true;
|
|
|
+ }
|
|
|
ncnn::Mat features = s->GetFrames(s->GetNumProcessedFrames(), segment);
|
|
|
s->GetNumProcessedFrames() += offset;
|
|
|
std::vector<ncnn::Mat> states = s->GetStates();
|
|
|
|
|
|
ncnn::Mat encoder_out;
|
|
|
std::tie(encoder_out, states) = model_->RunEncoder(features, states);
|
|
|
- s->SetStates(states);
|
|
|
|
|
|
- decoder_->Decode(encoder_out, &s->GetResult());
|
|
|
+ if (has_context_graph) {
|
|
|
+ decoder_->Decode(encoder_out, s, &s->GetResult());
|
|
|
+ } else {
|
|
|
+ decoder_->Decode(encoder_out, &s->GetResult());
|
|
|
+ }
|
|
|
+ s->SetStates(states);
|
|
|
}
|
|
|
|
|
|
bool IsEndpoint(Stream *s) const {
|
|
@@ -158,6 +238,13 @@ class Recognizer::Impl {
|
|
|
}
|
|
|
|
|
|
void Reset(Stream *s) const {
|
|
|
+ auto r = decoder_->GetEmptyResult();
|
|
|
+ if (config_.decoding_method == "modified_beam_search" &&
|
|
|
+ nullptr != s->GetContextGraph()) {
|
|
|
+ for (auto it = r.hyps.begin(); it != r.hyps.end(); ++it) {
|
|
|
+ it->second.context_state = s->GetContextGraph()->Root();
|
|
|
+ }
|
|
|
+ }
|
|
|
// Caution: We need to keep the decoder output state
|
|
|
ncnn::Mat decoder_out = s->GetResult().decoder_out;
|
|
|
s->SetResult(decoder_->GetEmptyResult());
|
|
@@ -190,6 +277,7 @@ class Recognizer::Impl {
|
|
|
std::unique_ptr<Decoder> decoder_;
|
|
|
Endpoint endpoint_;
|
|
|
SymbolTable sym_;
|
|
|
+ std::vector<std::vector<int32_t>> hotwords_;
|
|
|
};
|
|
|
|
|
|
Recognizer::Recognizer(const RecognizerConfig &config)
|
|
@@ -206,6 +294,7 @@ std::unique_ptr<Stream> Recognizer::CreateStream() const {
|
|
|
return impl_->CreateStream();
|
|
|
}
|
|
|
|
|
|
+
|
|
|
bool Recognizer::IsReady(Stream *s) const { return impl_->IsReady(s); }
|
|
|
|
|
|
void Recognizer::DecodeStream(Stream *s) const { impl_->DecodeStream(s); }
|