|
@@ -172,11 +172,7 @@ class Recognizer::Impl {
|
|
void DecodeStream(Stream *s) const {
|
|
void DecodeStream(Stream *s) const {
|
|
int32_t segment = model_->Segment();
|
|
int32_t segment = model_->Segment();
|
|
int32_t offset = model_->Offset();
|
|
int32_t offset = model_->Offset();
|
|
- bool has_context_graph = false;
|
|
|
|
|
|
|
|
- if (!has_context_graph && s->GetContextGraph()) {
|
|
|
|
- has_context_graph = true;
|
|
|
|
- }
|
|
|
|
ncnn::Mat features = s->GetFrames(s->GetNumProcessedFrames(), segment);
|
|
ncnn::Mat features = s->GetFrames(s->GetNumProcessedFrames(), segment);
|
|
s->GetNumProcessedFrames() += offset;
|
|
s->GetNumProcessedFrames() += offset;
|
|
std::vector<ncnn::Mat> states = s->GetStates();
|
|
std::vector<ncnn::Mat> states = s->GetStates();
|
|
@@ -184,7 +180,7 @@ class Recognizer::Impl {
|
|
ncnn::Mat encoder_out;
|
|
ncnn::Mat encoder_out;
|
|
std::tie(encoder_out, states) = model_->RunEncoder(features, states);
|
|
std::tie(encoder_out, states) = model_->RunEncoder(features, states);
|
|
|
|
|
|
- if (has_context_graph) {
|
|
|
|
|
|
+ if (s->GetContextGraph()) {
|
|
decoder_->Decode(encoder_out, s, &s->GetResult());
|
|
decoder_->Decode(encoder_out, s, &s->GetResult());
|
|
} else {
|
|
} else {
|
|
decoder_->Decode(encoder_out, &s->GetResult());
|
|
decoder_->Decode(encoder_out, &s->GetResult());
|
|
@@ -216,7 +212,7 @@ class Recognizer::Impl {
|
|
}
|
|
}
|
|
// Caution: We need to keep the decoder output state
|
|
// Caution: We need to keep the decoder output state
|
|
ncnn::Mat decoder_out = s->GetResult().decoder_out;
|
|
ncnn::Mat decoder_out = s->GetResult().decoder_out;
|
|
- s->SetResult(decoder_->GetEmptyResult());
|
|
|
|
|
|
+ s->SetResult(r);
|
|
s->GetResult().decoder_out = decoder_out;
|
|
s->GetResult().decoder_out = decoder_out;
|
|
|
|
|
|
// don't reset encoder state
|
|
// don't reset encoder state
|
|
@@ -284,8 +280,10 @@ class Recognizer::Impl {
|
|
int32_t number = sym_[word];
|
|
int32_t number = sym_[word];
|
|
tmp.push_back(number);
|
|
tmp.push_back(number);
|
|
} else {
|
|
} else {
|
|
- NCNN_LOGE("Cannot find ID for hotword %s at line: %s", word.c_str(),
|
|
|
|
- line.c_str());
|
|
|
|
|
|
+ NCNN_LOGE(
|
|
|
|
+ "Cannot find ID for hotword %s at line: %s. (Hint: words on the "
|
|
|
|
+ "same line are separated by spaces)",
|
|
|
|
+ word.c_str(), line.c_str());
|
|
exit(-1);
|
|
exit(-1);
|
|
}
|
|
}
|
|
}
|
|
}
|