|
@@ -18,7 +18,7 @@
|
|
|
*/
|
|
|
|
|
|
#include "sherpa-ncnn/csrc/recognizer.h"
|
|
|
-#include <iostream>
|
|
|
+
|
|
|
#include <fstream>
|
|
|
#include <memory>
|
|
|
#include <string>
|
|
@@ -29,6 +29,14 @@
|
|
|
#include "sherpa-ncnn/csrc/greedy-search-decoder.h"
|
|
|
#include "sherpa-ncnn/csrc/modified-beam-search-decoder.h"
|
|
|
|
|
|
+#if __ANDROID_API__ >= 9
|
|
|
+#include <strstream>
|
|
|
+
|
|
|
+#include "android/asset_manager.h"
|
|
|
+#include "android/asset_manager_jni.h"
|
|
|
+#include "android/log.h"
|
|
|
+#endif
|
|
|
+
|
|
|
namespace sherpa_ncnn {
|
|
|
|
|
|
static RecognitionResult Convert(const DecoderResult &src,
|
|
@@ -76,12 +84,10 @@ std::string RecognizerConfig::ToString() const {
|
|
|
os << "feat_config=" << feat_config.ToString() << ", ";
|
|
|
os << "model_config=" << model_config.ToString() << ", ";
|
|
|
os << "decoder_config=" << decoder_config.ToString() << ", ";
|
|
|
- os << "max_active_paths=" << max_active_paths << ", ";
|
|
|
os << "endpoint_config=" << endpoint_config.ToString() << ", ";
|
|
|
os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ", ";
|
|
|
os << "hotwords_file=\"" << hotwords_file << "\", ";
|
|
|
- os << "hotwrods_score=" << hotwords_score << ", ";
|
|
|
- os << "decoding_method=\"" << decoding_method << "\")";
|
|
|
+ os << "hotwrods_score=" << hotwords_score << ")";
|
|
|
|
|
|
return os.str();
|
|
|
}
|
|
@@ -98,29 +104,9 @@ class Recognizer::Impl {
|
|
|
} else if (config.decoder_config.method == "modified_beam_search") {
|
|
|
decoder_ = std::make_unique<ModifiedBeamSearchDecoder>(
|
|
|
model_.get(), config.decoder_config.num_active_paths);
|
|
|
- std::vector<int32_t> tmp;
|
|
|
- /*each line in hotwords file is a string which is segmented by space*/
|
|
|
- std::ifstream file(config_.hotwords_file);
|
|
|
- if (file) {
|
|
|
- std::string line;
|
|
|
- std::string word;
|
|
|
- while (std::getline(file, line)) {
|
|
|
- std::istringstream iss(line);
|
|
|
- while(iss >> word){
|
|
|
- if (sym_.contains(word)) {
|
|
|
- int number = sym_[word];
|
|
|
- tmp.push_back(number);
|
|
|
- } else {
|
|
|
- NCNN_LOGE("hotword %s can't find id. line: %s", word.c_str(), line.c_str());
|
|
|
- exit(-1);
|
|
|
- }
|
|
|
- }
|
|
|
- hotwords_.push_back(tmp);
|
|
|
- tmp.clear();
|
|
|
- }
|
|
|
- } else {
|
|
|
- NCNN_LOGE("open file failed: %s, hotwords will not be used",
|
|
|
- config_.hotwords_file.c_str());
|
|
|
+
|
|
|
+ if (!config_.hotwords_file.empty()) {
|
|
|
+ InitHotwords();
|
|
|
}
|
|
|
} else {
|
|
|
NCNN_LOGE("Unsupported method: %s", config.decoder_config.method.c_str());
|
|
@@ -139,29 +125,9 @@ class Recognizer::Impl {
|
|
|
} else if (config.decoder_config.method == "modified_beam_search") {
|
|
|
decoder_ = std::make_unique<ModifiedBeamSearchDecoder>(
|
|
|
model_.get(), config.decoder_config.num_active_paths);
|
|
|
- std::vector<int32_t> tmp;
|
|
|
- /*each line in hotwords file is a string which is segmented by space*/
|
|
|
- std::ifstream file(config_.hotwords_file);
|
|
|
- if (file) {
|
|
|
- std::string line;
|
|
|
- std::string word;
|
|
|
- while (std::getline(file, line)) {
|
|
|
- std::istringstream iss(line);
|
|
|
- while(iss >> word){
|
|
|
- if (sym_.contains(word)) {
|
|
|
- int number = sym_[word];
|
|
|
- tmp.push_back(number);
|
|
|
- } else {
|
|
|
- NCNN_LOGE("hotword %s can't find id. line: %s", word.c_str(), line.c_str());
|
|
|
- exit(-1);
|
|
|
- }
|
|
|
- }
|
|
|
- hotwords_.push_back(tmp);
|
|
|
- tmp.clear();
|
|
|
- }
|
|
|
- } else {
|
|
|
- NCNN_LOGE("open file failed: %s, hotwords will not be used",
|
|
|
- config_.hotwords_file.c_str());
|
|
|
+
|
|
|
+ if (!config_.hotwords_file.empty()) {
|
|
|
+ InitHotwords(mgr);
|
|
|
}
|
|
|
} else {
|
|
|
NCNN_LOGE("Unsupported method: %s", config.decoder_config.method.c_str());
|
|
@@ -171,27 +137,30 @@ class Recognizer::Impl {
|
|
|
#endif
|
|
|
|
|
|
+  // Creates a new Stream configured with config_.feat_config.
+  //
+  // When hotwords_ is empty, the stream receives the decoder's empty result
+  // and the model's initial encoder states.
+  //
+  // Otherwise a ContextGraph is built from hotwords_ and
+  // config_.hotwords_score and handed to the stream. If the stream reports a
+  // context graph, the context_state of every hypothesis in the empty result
+  // is set to the root of that graph before the result is stored in the
+  // stream.
std::unique_ptr<Stream> CreateStream() const {
|
|
|
- if(hotwords_.empty()) {
|
|
|
+ if (hotwords_.empty()) {
|
|
|
auto stream = std::make_unique<Stream>(config_.feat_config);
|
|
|
stream->SetResult(decoder_->GetEmptyResult());
|
|
|
stream->SetStates(model_->GetEncoderInitStates());
|
|
|
return stream;
|
|
|
} else {
|
|
|
auto r = decoder_->GetEmptyResult();
|
|
|
+
|
|
|
auto context_graph =
|
|
|
std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score);
|
|
|
+
|
|
|
auto stream =
|
|
|
std::make_unique<Stream>(config_.feat_config, context_graph);
|
|
|
- if (config_.decoder_config.method == "modified_beam_search" &&
|
|
|
- nullptr != stream->GetContextGraph()) {
|
|
|
- std::cout<<"create contexts stream"<<std::endl;
|
|
|
+
|
|
|
+ if (stream->GetContextGraph()) {
|
|
|
// r.hyps has only one element.
|
|
|
for (auto it = r.hyps.begin(); it != r.hyps.end(); ++it) {
|
|
|
it->second.context_state = stream->GetContextGraph()->Root();
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
stream->SetResult(r);
|
|
|
stream->SetStates(model_->GetEncoderInitStates());
|
|
|
+
|
|
|
return stream;
|
|
|
}
|
|
|
}
|
|
@@ -239,8 +208,8 @@ class Recognizer::Impl {
|
|
|
|
|
|
void Reset(Stream *s) const {
|
|
|
auto r = decoder_->GetEmptyResult();
|
|
|
- if (config_.decoding_method == "modified_beam_search" &&
|
|
|
- nullptr != s->GetContextGraph()) {
|
|
|
+
|
|
|
+ if (s->GetContextGraph()) {
|
|
|
for (auto it = r.hyps.begin(); it != r.hyps.end(); ++it) {
|
|
|
it->second.context_state = s->GetContextGraph()->Root();
|
|
|
}
|
|
@@ -271,6 +240,60 @@ class Recognizer::Impl {
|
|
|
|
|
|
+  // Returns a raw, non-owning pointer to the model held in model_.
const Model *GetModel() const { return model_.get(); }
|
|
|
|
|
|
+ private:
|
|
|
+#if __ANDROID_API__ >= 9
|
|
|
+  // Loads hotwords from an Android asset.
+  //
+  // @param mgr  Asset manager used to open config_.hotwords_file.
+  //
+  // Logs a fatal message and terminates the process via exit(-1) when the
+  // asset cannot be opened or when its buffer cannot be obtained.
+  void InitHotwords(AAssetManager *mgr) {
+    AAsset *asset = AAssetManager_open(mgr, config_.hotwords_file.c_str(),
+                                       AASSET_MODE_BUFFER);
+    if (!asset) {
+      __android_log_print(ANDROID_LOG_FATAL, "sherpa-ncnn",
+                          "hotwords_file: Load %s failed",
+                          config_.hotwords_file.c_str());
+      exit(-1);
+    }
+
+    auto p = reinterpret_cast<const char *>(AAsset_getBuffer(asset));
+    if (!p) {
+      // AAsset_getBuffer() returns NULL on failure; passing a null pointer to
+      // std::istrstream would be undefined behaviour.
+      __android_log_print(ANDROID_LOG_FATAL, "sherpa-ncnn",
+                          "hotwords_file: Read %s failed",
+                          config_.hotwords_file.c_str());
+      AAsset_close(asset);
+      exit(-1);
+    }
+
+    size_t asset_length = AAsset_getLength(asset);
+    std::istrstream is(p, asset_length);
+    InitHotwords(is);
+
+    // Closed only after parsing because `is` reads straight from the buffer
+    // obtained from the asset.
+    AAsset_close(asset);
+  }
|
|
|
+#endif
|
|
|
+
|
|
|
+  // Loads hotwords from the regular file named by config_.hotwords_file.
+  // Each line of that file contains space-separated words.
+  //
+  // Logs an error and terminates the process via exit(-1) when the file
+  // cannot be opened.
+  void InitHotwords() {
+    std::ifstream fin(config_.hotwords_file);
+    if (fin) {
+      InitHotwords(fin);
+      return;
+    }
+
+    NCNN_LOGE("Open hotwords file failed: %s", config_.hotwords_file.c_str());
+    exit(-1);
+  }
|
|
|
+
|
|
|
+  // Parses hotwords from `is` and appends them to hotwords_.
+  //
+  // Every line is one hotword written as whitespace-separated words. Each
+  // word is looked up in sym_ and the resulting IDs, in order, form one entry
+  // of hotwords_. Lines without any word are skipped.
+  //
+  // @param is  Input stream positioned at the start of the hotwords text.
+  //
+  // Logs an error and terminates the process via exit(-1) when a word is not
+  // present in sym_.
+  void InitHotwords(std::istream &is) {
+    std::string line;
+    std::string word;
+
+    while (std::getline(is, line)) {
+      // Use a fresh vector for every line: once a vector has been moved into
+      // hotwords_ its contents are unspecified, so it must not be reused
+      // without being reset.
+      std::vector<int32_t> ids;
+
+      std::istringstream iss(line);
+      while (iss >> word) {
+        if (!sym_.contains(word)) {
+          NCNN_LOGE("Cannot find ID for hotword %s at line: %s", word.c_str(),
+                    line.c_str());
+          exit(-1);
+        }
+
+        ids.push_back(sym_[word]);
+      }
+
+      // Blank lines are skipped rather than stored as empty hotwords, so
+      // hotwords_ stays empty when the input defines no hotword at all.
+      if (ids.empty()) {
+        continue;
+      }
+
+      hotwords_.push_back(std::move(ids));
+    }
+  }
|
|
|
+
|
|
|
private:
|
|
|
RecognizerConfig config_;
|
|
|
std::unique_ptr<Model> model_;
|
|
@@ -294,7 +317,6 @@ std::unique_ptr<Stream> Recognizer::CreateStream() const {
|
|
|
return impl_->CreateStream();
|
|
|
}
|
|
|
|
|
|
-
|
|
|
+// Delegates to impl_->IsReady(s) and returns its result.
bool Recognizer::IsReady(Stream *s) const { return impl_->IsReady(s); }
|
|
|
|
|
|
+// Delegates to impl_->DecodeStream(s).
void Recognizer::DecodeStream(Stream *s) const { impl_->DecodeStream(s); }
|