|
@@ -61,6 +61,10 @@ class SherpaNcnn {
|
|
|
return result.text;
|
|
|
}
|
|
|
|
|
|
+ bool IsEndpoint() { return recognizer_.IsEndpoint(); }
|
|
|
+
|
|
|
+ void Reset() { return recognizer_.Reset(); }
|
|
|
+
|
|
|
private:
|
|
|
sherpa_ncnn::Recognizer recognizer_;
|
|
|
std::vector<float> tail_padding_;
|
|
@@ -126,6 +130,68 @@ static ModelConfig GetModelConfig(JNIEnv *env, jobject config) {
|
|
|
return model_config;
|
|
|
}
|
|
|
|
|
|
+static DecoderConfig GetDecoderConfig(JNIEnv *env, jobject config) {
|
|
|
+ DecoderConfig decoder_config;
|
|
|
+
|
|
|
+ jclass cls = env->GetObjectClass(config);
|
|
|
+
|
|
|
+ jfieldID fid = env->GetFieldID(cls, "method", "Ljava/lang/String;");
|
|
|
+ jstring s = (jstring)env->GetObjectField(config, fid);
|
|
|
+ const char *p = env->GetStringUTFChars(s, nullptr);
|
|
|
+ decoder_config.method = p;
|
|
|
+ env->ReleaseStringUTFChars(s, p);
|
|
|
+
|
|
|
+ fid = env->GetFieldID(cls, "numActivePaths", "I");
|
|
|
+ decoder_config.num_active_paths = env->GetIntField(config, fid);
|
|
|
+
|
|
|
+ fid = env->GetFieldID(cls, "useEndpoint", "Z");
|
|
|
+ decoder_config.use_endpoint = env->GetBooleanField(config, fid);
|
|
|
+
|
|
|
+ fid = env->GetFieldID(cls, "endpointConfig",
|
|
|
+ "Lcom/k2fsa/sherpa/ncnn/EndpointConfig;");
|
|
|
+ jobject endpoint_config = env->GetObjectField(config, fid);
|
|
|
+ jclass endpoint_config_cls = env->GetObjectClass(endpoint_config);
|
|
|
+
|
|
|
+ fid = env->GetFieldID(endpoint_config_cls, "rule1",
|
|
|
+ "Lcom/k2fsa/sherpa/ncnn/EndpointRule;");
|
|
|
+ jobject rule1 = env->GetObjectField(endpoint_config, fid);
|
|
|
+ jclass rule_class = env->GetObjectClass(rule1);
|
|
|
+
|
|
|
+ fid = env->GetFieldID(endpoint_config_cls, "rule2",
|
|
|
+ "Lcom/k2fsa/sherpa/ncnn/EndpointRule;");
|
|
|
+ jobject rule2 = env->GetObjectField(endpoint_config, fid);
|
|
|
+
|
|
|
+ fid = env->GetFieldID(endpoint_config_cls, "rule3",
|
|
|
+ "Lcom/k2fsa/sherpa/ncnn/EndpointRule;");
|
|
|
+ jobject rule3 = env->GetObjectField(endpoint_config, fid);
|
|
|
+
|
|
|
+ fid = env->GetFieldID(rule_class, "mustContainNonSilence", "Z");
|
|
|
+ decoder_config.endpoint_config.rule1.must_contain_nonsilence =
|
|
|
+ env->GetBooleanField(rule1, fid);
|
|
|
+ decoder_config.endpoint_config.rule2.must_contain_nonsilence =
|
|
|
+ env->GetBooleanField(rule2, fid);
|
|
|
+ decoder_config.endpoint_config.rule3.must_contain_nonsilence =
|
|
|
+ env->GetBooleanField(rule3, fid);
|
|
|
+
|
|
|
+ fid = env->GetFieldID(rule_class, "minTrailingSilence", "F");
|
|
|
+ decoder_config.endpoint_config.rule1.min_trailing_silence =
|
|
|
+ env->GetFloatField(rule1, fid);
|
|
|
+ decoder_config.endpoint_config.rule2.min_trailing_silence =
|
|
|
+ env->GetFloatField(rule2, fid);
|
|
|
+ decoder_config.endpoint_config.rule3.min_trailing_silence =
|
|
|
+ env->GetFloatField(rule3, fid);
|
|
|
+
|
|
|
+ fid = env->GetFieldID(rule_class, "minUtteranceLength", "F");
|
|
|
+ decoder_config.endpoint_config.rule1.min_utterance_length =
|
|
|
+ env->GetFloatField(rule1, fid);
|
|
|
+ decoder_config.endpoint_config.rule2.min_utterance_length =
|
|
|
+ env->GetFloatField(rule2, fid);
|
|
|
+ decoder_config.endpoint_config.rule3.min_utterance_length =
|
|
|
+ env->GetFloatField(rule3, fid);
|
|
|
+
|
|
|
+ return decoder_config;
|
|
|
+}
|
|
|
+
|
|
|
static knf::FbankOptions GetFbankOptions(JNIEnv *env, jobject opts) {
|
|
|
jclass cls = env->GetObjectClass(opts);
|
|
|
jfieldID fid;
|
|
@@ -231,16 +297,18 @@ static knf::FbankOptions GetFbankOptions(JNIEnv *env, jobject opts) {
|
|
|
SHERPA_EXTERN_C
|
|
|
JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_ncnn_SherpaNcnn_new(
|
|
|
JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _model_config,
|
|
|
- jobject _fbank_config) {
|
|
|
+ jobject _decoder_config, jobject _fbank_config) {
|
|
|
AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
|
|
|
if (!mgr) {
|
|
|
NCNN_LOGE("Failed to get asset manager: %p", mgr);
|
|
|
}
|
|
|
|
|
|
- sherpa_ncnn::ModelConfig model_config =
|
|
|
- sherpa_ncnn::GetModelConfig(env, _model_config);
|
|
|
+ auto model_config = sherpa_ncnn::GetModelConfig(env, _model_config);
|
|
|
+ auto decoder_config = sherpa_ncnn::GetDecoderConfig(env, _decoder_config);
|
|
|
|
|
|
- sherpa_ncnn::DecoderConfig decoder_config;
|
|
|
+ NCNN_LOGE("------model_config------\n%s\n", model_config.ToString().c_str());
|
|
|
+ NCNN_LOGE("------decoder_config------\n%s\n",
|
|
|
+ decoder_config.ToString().c_str());
|
|
|
|
|
|
knf::FbankOptions fbank_opts =
|
|
|
sherpa_ncnn::GetFbankOptions(env, _fbank_config);
|
|
@@ -259,7 +327,17 @@ JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_ncnn_SherpaNcnn_delete(
|
|
|
|
|
|
SHERPA_EXTERN_C
|
|
|
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_ncnn_SherpaNcnn_reset(
|
|
|
- JNIEnv *env, jobject /*obj*/, jlong ptr) {}
|
|
|
+ JNIEnv *env, jobject /*obj*/, jlong ptr) {
|
|
|
+ auto model = reinterpret_cast<sherpa_ncnn::SherpaNcnn *>(ptr);
|
|
|
+ model->Reset();
|
|
|
+}
|
|
|
+
|
|
|
+SHERPA_EXTERN_C
|
|
|
+JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_ncnn_SherpaNcnn_isEndpoint(
|
|
|
+ JNIEnv *env, jobject /*obj*/, jlong ptr) {
|
|
|
+ auto model = reinterpret_cast<sherpa_ncnn::SherpaNcnn *>(ptr);
|
|
|
+ return model->IsEndpoint();
|
|
|
+}
|
|
|
|
|
|
SHERPA_EXTERN_C
|
|
|
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_ncnn_SherpaNcnn_decodeSamples(
|