|
@@ -40,31 +40,20 @@ namespace sherpa_ncnn {
|
|
|
class SherpaNcnn {
|
|
|
public:
|
|
|
SherpaNcnn(AAssetManager *mgr, const ModelConfig &model_config,
|
|
|
- const knf::FbankOptions &fbank_config, const std::string &tokens)
|
|
|
+ const knf::FbankOptions &fbank_config)
|
|
|
: model_(Model::Create(mgr, model_config)),
|
|
|
- feature_extractor_(fbank_config),
|
|
|
- sym_(mgr, tokens) {
|
|
|
- // Initialize decoder_output
|
|
|
- int32_t context_size = model_->ContextSize();
|
|
|
- int32_t blank_id = 0;
|
|
|
-
|
|
|
- ncnn::Mat decoder_input(context_size);
|
|
|
- for (int32_t i = 0; i != context_size; ++i) {
|
|
|
- static_cast<int32_t *>(decoder_input)[i] = blank_id;
|
|
|
- }
|
|
|
-
|
|
|
- decoder_out_ = model_->RunDecoder(decoder_input);
|
|
|
-
|
|
|
- hyp_.resize(context_size, 0);
|
|
|
+ feature_extractor_(std::make_unique<FeatureExtractor>(fbank_config)),
|
|
|
+ sym_(mgr, model_config.tokens) {
|
|
|
+ Reset();
|
|
|
}
|
|
|
|
|
|
void DecodeSamples(float sample_rate, const float *samples, int32_t n) {
|
|
|
- feature_extractor_.AcceptWaveform(sample_rate, samples, n);
|
|
|
+ feature_extractor_->AcceptWaveform(sample_rate, samples, n);
|
|
|
Decode();
|
|
|
}
|
|
|
|
|
|
void InputFinished() {
|
|
|
- feature_extractor_.InputFinished();
|
|
|
+ feature_extractor_->InputFinished();
|
|
|
Decode();
|
|
|
}
|
|
|
|
|
@@ -79,15 +68,33 @@ class SherpaNcnn {
|
|
|
return text;
|
|
|
}
|
|
|
|
|
|
+ void Reset() {
|
|
|
+ feature_extractor_->Reset();
|
|
|
+ num_processed_ = 0;
|
|
|
+ states_.clear();
|
|
|
+
|
|
|
+ int32_t context_size = model_->ContextSize();
|
|
|
+ int32_t blank_id = 0;
|
|
|
+
|
|
|
+ ncnn::Mat decoder_input(context_size);
|
|
|
+ for (int32_t i = 0; i != context_size; ++i) {
|
|
|
+ static_cast<int32_t *>(decoder_input)[i] = blank_id;
|
|
|
+ }
|
|
|
+
|
|
|
+ decoder_out_ = model_->RunDecoder(decoder_input);
|
|
|
+
|
|
|
+ hyp_.resize(context_size, 0);
|
|
|
+ }
|
|
|
+
|
|
|
private:
|
|
|
void Decode() {
|
|
|
int32_t segment = model_->Segment();
|
|
|
int32_t offset = model_->Offset();
|
|
|
|
|
|
ncnn::Mat encoder_out;
|
|
|
- while (feature_extractor_.NumFramesReady() - num_processed_ >= segment) {
|
|
|
+ while (feature_extractor_->NumFramesReady() - num_processed_ >= segment) {
|
|
|
ncnn::Mat features =
|
|
|
- feature_extractor_.GetFrames(num_processed_, segment);
|
|
|
+ feature_extractor_->GetFrames(num_processed_, segment);
|
|
|
num_processed_ += offset;
|
|
|
|
|
|
std::tie(encoder_out, states_) = model_->RunEncoder(features, states_);
|
|
@@ -98,7 +105,7 @@ class SherpaNcnn {
|
|
|
|
|
|
private:
|
|
|
std::unique_ptr<Model> model_;
|
|
|
- FeatureExtractor feature_extractor_;
|
|
|
+ std::unique_ptr<FeatureExtractor> feature_extractor_;
|
|
|
sherpa_ncnn::SymbolTable sym_;
|
|
|
|
|
|
std::vector<int32_t> hyp_;
|
|
@@ -150,9 +157,18 @@ static ModelConfig GetModelConfig(JNIEnv *env, jobject config) {
|
|
|
model_config.joiner_bin = p;
|
|
|
env->ReleaseStringUTFChars(s, p);
|
|
|
|
|
|
+ fid = env->GetFieldID(cls, "tokens", "Ljava/lang/String;");
|
|
|
+ s = (jstring)env->GetObjectField(config, fid);
|
|
|
+ p = env->GetStringUTFChars(s, nullptr);
|
|
|
+ model_config.tokens = p;
|
|
|
+ env->ReleaseStringUTFChars(s, p);
|
|
|
+
|
|
|
fid = env->GetFieldID(cls, "numThreads", "I");
|
|
|
model_config.num_threads = env->GetIntField(config, fid);
|
|
|
|
|
|
+ fid = env->GetFieldID(cls, "useGPU", "Z");
|
|
|
+ model_config.use_vulkan_compute = env->GetBooleanField(config, fid);
|
|
|
+
|
|
|
return model_config;
|
|
|
}
|
|
|
|
|
@@ -165,92 +181,92 @@ static knf::FbankOptions GetFbankOptions(JNIEnv *env, jobject opts) {
|
|
|
|
|
|
knf::FbankOptions fbank_opts;
|
|
|
|
|
|
- fid = env->GetFieldID(cls, "use_energy", "Z");
|
|
|
+ fid = env->GetFieldID(cls, "useEnergy", "Z");
|
|
|
fbank_opts.use_energy = env->GetBooleanField(opts, fid);
|
|
|
|
|
|
- fid = env->GetFieldID(cls, "energy_floor", "F");
|
|
|
+ fid = env->GetFieldID(cls, "energyFloor", "F");
|
|
|
fbank_opts.energy_floor = env->GetFloatField(opts, fid);
|
|
|
|
|
|
- fid = env->GetFieldID(cls, "raw_energy", "Z");
|
|
|
+ fid = env->GetFieldID(cls, "rawEnergy", "Z");
|
|
|
fbank_opts.raw_energy = env->GetBooleanField(opts, fid);
|
|
|
|
|
|
- fid = env->GetFieldID(cls, "htk_compat", "Z");
|
|
|
+ fid = env->GetFieldID(cls, "htkCompat", "Z");
|
|
|
fbank_opts.htk_compat = env->GetBooleanField(opts, fid);
|
|
|
|
|
|
- fid = env->GetFieldID(cls, "use_log_fbank", "Z");
|
|
|
+ fid = env->GetFieldID(cls, "useLogFbank", "Z");
|
|
|
fbank_opts.use_log_fbank = env->GetBooleanField(opts, fid);
|
|
|
|
|
|
- fid = env->GetFieldID(cls, "use_power", "Z");
|
|
|
+ fid = env->GetFieldID(cls, "usePower", "Z");
|
|
|
fbank_opts.use_power = env->GetBooleanField(opts, fid);
|
|
|
|
|
|
- fid = env->GetFieldID(cls, "frame_opts",
|
|
|
+ fid = env->GetFieldID(cls, "frameOpts",
|
|
|
"Lcom/k2fsa/sherpa/ncnn/FrameExtractionOptions;");
|
|
|
|
|
|
jobject frame_opts = env->GetObjectField(opts, fid);
|
|
|
jclass frame_opts_cls = env->GetObjectClass(frame_opts);
|
|
|
|
|
|
- fid = env->GetFieldID(frame_opts_cls, "samp_freq", "F");
|
|
|
+ fid = env->GetFieldID(frame_opts_cls, "sampFreq", "F");
|
|
|
fbank_opts.frame_opts.samp_freq = env->GetFloatField(frame_opts, fid);
|
|
|
|
|
|
- fid = env->GetFieldID(frame_opts_cls, "frame_shift_ms", "F");
|
|
|
+ fid = env->GetFieldID(frame_opts_cls, "frameShiftMs", "F");
|
|
|
fbank_opts.frame_opts.frame_shift_ms = env->GetFloatField(frame_opts, fid);
|
|
|
|
|
|
- fid = env->GetFieldID(frame_opts_cls, "frame_length_ms", "F");
|
|
|
+ fid = env->GetFieldID(frame_opts_cls, "frameLengthMs", "F");
|
|
|
fbank_opts.frame_opts.frame_length_ms = env->GetFloatField(frame_opts, fid);
|
|
|
|
|
|
fid = env->GetFieldID(frame_opts_cls, "dither", "F");
|
|
|
fbank_opts.frame_opts.dither = env->GetFloatField(frame_opts, fid);
|
|
|
|
|
|
- fid = env->GetFieldID(frame_opts_cls, "preemph_coeff", "F");
|
|
|
+ fid = env->GetFieldID(frame_opts_cls, "preemphCoeff", "F");
|
|
|
fbank_opts.frame_opts.preemph_coeff = env->GetFloatField(frame_opts, fid);
|
|
|
|
|
|
- fid = env->GetFieldID(frame_opts_cls, "remove_dc_offset", "Z");
|
|
|
+ fid = env->GetFieldID(frame_opts_cls, "removeDcOffset", "Z");
|
|
|
fbank_opts.frame_opts.remove_dc_offset =
|
|
|
env->GetBooleanField(frame_opts, fid);
|
|
|
|
|
|
- fid = env->GetFieldID(frame_opts_cls, "window_type", "Ljava/lang/String;");
|
|
|
+ fid = env->GetFieldID(frame_opts_cls, "windowType", "Ljava/lang/String;");
|
|
|
jstring window_type = (jstring)env->GetObjectField(frame_opts, fid);
|
|
|
const char *p_window_type = env->GetStringUTFChars(window_type, nullptr);
|
|
|
fbank_opts.frame_opts.window_type = p_window_type;
|
|
|
env->ReleaseStringUTFChars(window_type, p_window_type);
|
|
|
|
|
|
- fid = env->GetFieldID(frame_opts_cls, "round_to_power_of_two", "Z");
|
|
|
+ fid = env->GetFieldID(frame_opts_cls, "roundToPowerOfTwo", "Z");
|
|
|
fbank_opts.frame_opts.round_to_power_of_two =
|
|
|
env->GetBooleanField(frame_opts, fid);
|
|
|
|
|
|
- fid = env->GetFieldID(frame_opts_cls, "blackman_coeff", "F");
|
|
|
+ fid = env->GetFieldID(frame_opts_cls, "blackmanCoeff", "F");
|
|
|
fbank_opts.frame_opts.blackman_coeff = env->GetFloatField(frame_opts, fid);
|
|
|
|
|
|
- fid = env->GetFieldID(frame_opts_cls, "snip_edges", "Z");
|
|
|
+ fid = env->GetFieldID(frame_opts_cls, "snipEdges", "Z");
|
|
|
fbank_opts.frame_opts.snip_edges = env->GetBooleanField(frame_opts, fid);
|
|
|
|
|
|
- fid = env->GetFieldID(frame_opts_cls, "max_feature_vectors", "I");
|
|
|
+ fid = env->GetFieldID(frame_opts_cls, "maxFeatureVectors", "I");
|
|
|
fbank_opts.frame_opts.max_feature_vectors = env->GetIntField(frame_opts, fid);
|
|
|
|
|
|
- fid = env->GetFieldID(cls, "mel_opts",
|
|
|
+ fid = env->GetFieldID(cls, "melOpts",
|
|
|
"Lcom/k2fsa/sherpa/ncnn/MelBanksOptions;");
|
|
|
jobject mel_opts = env->GetObjectField(opts, fid);
|
|
|
jclass mel_opts_cls = env->GetObjectClass(mel_opts);
|
|
|
|
|
|
- fid = env->GetFieldID(mel_opts_cls, "num_bins", "I");
|
|
|
+ fid = env->GetFieldID(mel_opts_cls, "numBins", "I");
|
|
|
fbank_opts.mel_opts.num_bins = env->GetIntField(mel_opts, fid);
|
|
|
|
|
|
- fid = env->GetFieldID(mel_opts_cls, "low_freq", "F");
|
|
|
+ fid = env->GetFieldID(mel_opts_cls, "lowFreq", "F");
|
|
|
fbank_opts.mel_opts.low_freq = env->GetFloatField(mel_opts, fid);
|
|
|
|
|
|
- fid = env->GetFieldID(mel_opts_cls, "high_freq", "F");
|
|
|
+ fid = env->GetFieldID(mel_opts_cls, "highFreq", "F");
|
|
|
fbank_opts.mel_opts.high_freq = env->GetFloatField(mel_opts, fid);
|
|
|
|
|
|
- fid = env->GetFieldID(mel_opts_cls, "vtln_low", "F");
|
|
|
+ fid = env->GetFieldID(mel_opts_cls, "vtlnLow", "F");
|
|
|
fbank_opts.mel_opts.vtln_low = env->GetFloatField(mel_opts, fid);
|
|
|
|
|
|
- fid = env->GetFieldID(mel_opts_cls, "vtln_high", "F");
|
|
|
+ fid = env->GetFieldID(mel_opts_cls, "vtlnHigh", "F");
|
|
|
fbank_opts.mel_opts.vtln_high = env->GetFloatField(mel_opts, fid);
|
|
|
|
|
|
- fid = env->GetFieldID(mel_opts_cls, "debug_mel", "Z");
|
|
|
+ fid = env->GetFieldID(mel_opts_cls, "debugMel", "Z");
|
|
|
fbank_opts.mel_opts.debug_mel = env->GetBooleanField(mel_opts, fid);
|
|
|
|
|
|
- fid = env->GetFieldID(mel_opts_cls, "htk_mode", "Z");
|
|
|
+ fid = env->GetFieldID(mel_opts_cls, "htkMode", "Z");
|
|
|
fbank_opts.mel_opts.htk_mode = env->GetBooleanField(mel_opts, fid);
|
|
|
|
|
|
return fbank_opts;
|
|
@@ -261,7 +277,7 @@ static knf::FbankOptions GetFbankOptions(JNIEnv *env, jobject opts) {
|
|
|
SHERPA_EXTERN_C
|
|
|
JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_ncnn_SherpaNcnn_new(
|
|
|
JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _model_config,
|
|
|
- jobject _fbank_config, jstring tokens) {
|
|
|
+ jobject _fbank_config) {
|
|
|
AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
|
|
|
if (!mgr) {
|
|
|
NCNN_LOGE("Failed to get asset manager: %p", mgr);
|
|
@@ -273,10 +289,7 @@ JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_ncnn_SherpaNcnn_new(
|
|
|
knf::FbankOptions fbank_opts =
|
|
|
sherpa_ncnn::GetFbankOptions(env, _fbank_config);
|
|
|
|
|
|
- const char *p_tokens = env->GetStringUTFChars(tokens, nullptr);
|
|
|
- auto model =
|
|
|
- new sherpa_ncnn::SherpaNcnn(mgr, model_config, fbank_opts, p_tokens);
|
|
|
- env->ReleaseStringUTFChars(tokens, p_tokens);
|
|
|
+ auto model = new sherpa_ncnn::SherpaNcnn(mgr, model_config, fbank_opts);
|
|
|
|
|
|
return (jlong)model;
|
|
|
}
|
|
@@ -287,6 +300,12 @@ JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_ncnn_SherpaNcnn_delete(
|
|
|
delete reinterpret_cast<sherpa_ncnn::SherpaNcnn *>(ptr);
|
|
|
}
|
|
|
|
|
|
+SHERPA_EXTERN_C
|
|
|
+JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_ncnn_SherpaNcnn_reset(
|
|
|
+ JNIEnv *env, jobject /*obj*/, jlong ptr) {
|
|
|
+ reinterpret_cast<sherpa_ncnn::SherpaNcnn *>(ptr)->Reset();
|
|
|
+}
|
|
|
+
|
|
|
SHERPA_EXTERN_C
|
|
|
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_ncnn_SherpaNcnn_decodeSamples(
|
|
|
JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples,
|