Przeglądaj źródła

Support building for android (#35)

* Support building for android
Fangjun Kuang 2 lat temu
rodzic
commit
491c7b99b4

+ 5 - 3
.github/workflows/aarch64-linux-gnu.yaml

@@ -7,7 +7,7 @@ on:
       - master
     paths:
       - '.github/workflows/aarch64-linux-gnu.yaml'
-      - '.github/scripts/test.sh'
+      - '.github/scripts/run-test.sh'
       - 'CMakeLists.txt'
       - 'cmake/**'
       - 'sherpa-ncnn/csrc/*'
@@ -17,7 +17,7 @@ on:
       - master
     paths:
       - '.github/workflows/aarch64-linux-gnu.yaml'
-      - '.github/scripts/test.sh'
+      - '.github/scripts/run-test.sh'
       - 'CMakeLists.txt'
       - 'cmake/**'
       - 'sherpa-ncnn/csrc/*'
@@ -109,9 +109,11 @@ jobs:
           ./build-aarch64-linux-gnu.sh
 
           ls -lh build-aarch64-linux-gnu/bin
+          ls -lh build-aarch64-linux-gnu/lib
+
           file build-aarch64-linux-gnu/bin/sherpa-ncnn
 
-      - name: Run tests (English)
+      - name: Run tests
         shell: bash
         run: |
           export PATH=$GITHUB_WORKSPACE/toolchain/bin:$PATH

+ 5 - 3
.github/workflows/arm-linux-gnueabihf.yaml

@@ -7,7 +7,7 @@ on:
       - master
     paths:
       - '.github/workflows/arm-linux-gnueabihf.yaml'
-      - '.github/scripts/test.sh'
+      - '.github/scripts/run-test.sh'
       - 'CMakeLists.txt'
       - 'cmake/**'
       - 'sherpa-ncnn/csrc/*'
@@ -17,7 +17,7 @@ on:
       - master
     paths:
       - '.github/workflows/arm-linux-gnueabihf.yaml'
-      - '.github/scripts/test.sh'
+      - '.github/scripts/run-test.sh'
       - 'CMakeLists.txt'
       - 'cmake/**'
       - 'sherpa-ncnn/csrc/*'
@@ -111,9 +111,11 @@ jobs:
           ./build-arm-linux-gnueabihf.sh
 
           ls -lh build-arm-linux-gnueabihf/bin
+          ls -lh build-arm-linux-gnueabihf/lib
+
           file build-arm-linux-gnueabihf/bin/sherpa-ncnn
 
-      - name: Run tests (English)
+      - name: Run tests
         shell: bash
         run: |
           export PATH=$GITHUB_WORKSPACE/toolchain/bin:$PATH

+ 7 - 4
.github/workflows/linux-macos-windows.yaml

@@ -6,7 +6,7 @@ on:
       - master
     paths:
       - '.github/workflows/linux-macos-windows.yaml'
-      - '.github/scripts/test.sh'
+      - '.github/scripts/run-test.sh'
       - 'CMakeLists.txt'
       - 'cmake/**'
       - 'sherpa-ncnn/csrc/*'
@@ -15,7 +15,7 @@ on:
       - master
     paths:
       - '.github/workflows/linux-macos-windows.yaml'
-      - '.github/scripts/test.sh'
+      - '.github/scripts/run-test.sh'
       - 'CMakeLists.txt'
       - 'cmake/**'
       - 'sherpa-ncnn/csrc/*'
@@ -58,6 +58,9 @@ jobs:
           cd build
           make -j2
 
+          ls -lh lib
+          ls -lh bin
+
           ls -lh bin/sherpa-ncnn
           file bin/sherpa-ncnn
 
@@ -71,7 +74,7 @@ jobs:
           name: sherpa-ncnn-pre-built-binaries-os-${{ matrix.os }}
           path: ./build/bin
 
-      - name: Run tests for ubuntu/macos (English)
+      - name: Run tests for ubuntu/macos
         if: startsWith(matrix.os, 'ubuntu') || startsWith(matrix.os, 'macos')
         run: |
           export PATH=$PWD/build/bin:$PATH
@@ -88,7 +91,7 @@ jobs:
 
           ls -lh ./bin/Release/sherpa-ncnn.exe
 
-      - name: Run tests for windows (English)
+      - name: Run tests for windows
         if: startsWith(matrix.os, 'windows')
         shell: bash
         run: |

+ 11 - 7
CMakeLists.txt

@@ -1,7 +1,9 @@
 cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
 project(sherpa-ncnn)
 
-set(SHERPA_NCNN_VERSION "0.1")
+set(SHERPA_NCNN_VERSION_MAJOR "1")
+set(SHERPA_NCNN_VERSION_MINOR "0")
+set(SHERPA_NCNN_VERSION "${SHERPA_NCNN_VERSION_MAJOR}.${SHERPA_NCNN_VERSION_MINOR}")
 
 set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
 set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
@@ -20,11 +22,8 @@ endif()
 set(CMAKE_INSTALL_RPATH ${SHERPA_NCNN_RPATH_ORIGIN})
 set(CMAKE_BUILD_RPATH ${SHERPA_NCNN_RPATH_ORIGIN})
 
-set(BUILD_SHARED_LIBS OFF)
-if(WIN32)
-  message(STATUS "Set BUILD_SHARED_LIBS to OFF for Windows")
-  set(BUILD_SHARED_LIBS OFF CACHE BOOL "" FORCE)
-endif()
+option(BUILD_SHARED_LIBS "Whether to build shared libraries" OFF)
+message(STATUS "BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS}")
 
 if(NOT CMAKE_BUILD_TYPE)
   message(STATUS "No CMAKE_BUILD_TYPE given, default to Release")
@@ -41,8 +40,13 @@ list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake)
 set(CMAKE_CXX_STANDARD 11 CACHE STRING "The C++ version to be used.")
 set(CMAKE_CXX_EXTENSIONS OFF)
 
+option(SHERPA_NCNN_ENABLE_PORTAUDIO "Whether to build with portaudio" ON)
+
 include(kaldi-native-fbank)
 include(ncnn)
-include(portaudio)
+
+if(SHERPA_NCNN_ENABLE_PORTAUDIO)
+  include(portaudio)
+endif()
 
 add_subdirectory(sherpa-ncnn)

+ 7 - 2
build-aarch64-linux-gnu.sh

@@ -5,6 +5,11 @@ set -x
 dir=build-aarch64-linux-gnu
 mkdir -p $dir
 cd $dir
-cmake -DCMAKE_BUILD_TYPE=Release -DCMAKE_TOOLCHAIN_FILE=../toolchains/aarch64-linux-gnu.toolchain.cmake ..
+cmake \
+  -DCMAKE_INSTALL_PREFIX=./install \
+  -DCMAKE_BUILD_TYPE=Release \
+  -DCMAKE_TOOLCHAIN_FILE=../toolchains/aarch64-linux-gnu.toolchain.cmake \
+  ..
+
 make VERBOSE=1 -j4
-make install
+make install/strip

+ 36 - 0
build-android-arm64-v8a.sh

@@ -0,0 +1,36 @@
+#!/usr/bin/env bash
+set -e
+
+# see https://github.com/Tencent/ncnn/wiki/how-to-build#build-for-android
+
+dir=build-android-arm64-v8a
+
+mkdir -p $dir
+cd $dir
+
+if [ -z $ANDROID_NDK ]; then
+  ANDROID_NDK=/ceph-fj/fangjun/software/android-sdk/ndk/21.0.6113669
+  # or use
+  # ANDROID_NDK=/ceph-fj/fangjun/software/android-ndk
+  #
+  # Inside the $ANDROID_NDK directory, you can find a binary ndk-build
+  # and some other files like the file "build/cmake/android.toolchain.cmake"
+fi
+
+if [ ! -d $ANDROID_NDK ]; then
+  echo Please set the environment variable ANDROID_NDK before you run this script
+  exit 1
+fi
+
+echo "ANDROID_NDK: $ANDROID_NDK"
+sleep 1
+
+cmake -DCMAKE_TOOLCHAIN_FILE="$ANDROID_NDK/build/cmake/android.toolchain.cmake" \
+    -DCMAKE_BUILD_TYPE=Release \
+    -DBUILD_SHARED_LIBS=ON \
+    -DSHERPA_NCNN_ENABLE_PORTAUDIO=OFF \
+    -DCMAKE_INSTALL_PREFIX=./install \
+    -DANDROID_ABI="arm64-v8a" \
+    -DANDROID_PLATFORM=android-21 ..
+make VERBOSE=1 -j4
+make install/strip

+ 31 - 0
build-android-armv7-eabi.sh

@@ -0,0 +1,31 @@
+#!/usr/bin/env bash
+set -x
+
+# see https://github.com/Tencent/ncnn/wiki/how-to-build#build-for-android
+
+dir=build-android-armv7-eabi
+
+mkdir -p $dir
+cd $dir
+
+if [ -z $ANDROID_NDK ]; then
+  ANDROID_NDK=/ceph-fj/fangjun/software/android-sdk/ndk/21.0.6113669
+  # or use
+  # ANDROID_NDK=/ceph-fj/fangjun/software/android-ndk
+  #
+  # Inside the $ANDROID_NDK directory, you can find a binary ndk-build
+  # and some other files like the file "build/cmake/android.toolchain.cmake"
+fi
+
+echo "ANDROID_NDK: $ANDROID_NDK"
+sleep 1
+
+cmake -DCMAKE_TOOLCHAIN_FILE="$ANDROID_NDK/build/cmake/android.toolchain.cmake" \
+    -DCMAKE_BUILD_TYPE=Release \
+    -DBUILD_SHARED_LIBS=ON \
+    -DSHERPA_NCNN_ENABLE_PORTAUDIO=OFF \
+    -DCMAKE_INSTALL_PREFIX=./install \
+    -DANDROID_ABI="armeabi-v7a" -DANDROID_ARM_NEON=ON \
+    -DANDROID_PLATFORM=android-21 ..
+make VERBOSE=1 -j4
+make install/strip

+ 31 - 0
build-android-x86-64.sh

@@ -0,0 +1,31 @@
+#!/usr/bin/env bash
+set -x
+
+# see https://github.com/Tencent/ncnn/wiki/how-to-build#build-for-android
+
+dir=build-android-x86-64
+
+mkdir -p $dir
+cd $dir
+
+if [ -z $ANDROID_NDK ]; then
+  ANDROID_NDK=/ceph-fj/fangjun/software/android-sdk/ndk/21.0.6113669
+  # or use
+  # ANDROID_NDK=/ceph-fj/fangjun/software/android-ndk
+  #
+  # Inside the $ANDROID_NDK directory, you can find a binary ndk-build
+  # and some other files like the file "build/cmake/android.toolchain.cmake"
+fi
+
+echo "ANDROID_NDK: $ANDROID_NDK"
+sleep 1
+
+cmake -DCMAKE_TOOLCHAIN_FILE="$ANDROID_NDK/build/cmake/android.toolchain.cmake" \
+    -DCMAKE_BUILD_TYPE=Release \
+    -DBUILD_SHARED_LIBS=ON \
+    -DSHERPA_NCNN_ENABLE_PORTAUDIO=OFF \
+    -DCMAKE_INSTALL_PREFIX=./install \
+    -DANDROID_ABI="x86_64" \
+    -DANDROID_PLATFORM=android-21 ..
+make VERBOSE=1 -j4
+make install/strip

+ 6 - 2
build-arm-linux-gnueabihf.sh

@@ -5,6 +5,10 @@ set -x
 dir=build-arm-linux-gnueabihf
 mkdir -p $dir
 cd $dir
-cmake -DCMAKE_BUILD_TYPE=Release -DCMAKE_TOOLCHAIN_FILE=../toolchains/arm-linux-gnueabihf.toolchain.cmake ..
+cmake \
+  -DCMAKE_INSTALL_PREFIX=./install \
+  -DCMAKE_BUILD_TYPE=Release \
+  -DCMAKE_TOOLCHAIN_FILE=../toolchains/arm-linux-gnueabihf.toolchain.cmake \
+  ..
 make VERBOSE=1 -j4
-make install
+make install/strip

+ 2 - 1
cmake/kaldi-native-fbank.cmake

@@ -23,12 +23,13 @@ function(download_kaldi_native_fbank)
   message(STATUS "kaldi-native-fbank is downloaded to ${kaldi_native_fbank_SOURCE_DIR}")
   message(STATUS "kaldi-native-fbank's binary dir is ${kaldi_native_fbank_BINARY_DIR}")
 
-  add_subdirectory(${kaldi_native_fbank_SOURCE_DIR} ${kaldi_native_fbank_BINARY_DIR} EXCLUDE_FROM_ALL)
+  add_subdirectory(${kaldi_native_fbank_SOURCE_DIR} ${kaldi_native_fbank_BINARY_DIR})
 
   target_include_directories(kaldi-native-fbank-core
     INTERFACE
       ${kaldi_native_fbank_SOURCE_DIR}/
   )
+  install(TARGETS kaldi-native-fbank-core DESTINATION lib)
 endfunction()
 
 download_kaldi_native_fbank()

+ 5 - 3
cmake/ncnn.cmake

@@ -7,7 +7,7 @@ function(download_ncnn)
 
   # If you don't have access to the internet, please download it to your
   # local drive and modify the following line according to your needs.
-  # set(ncnn_URL  "file:///ceph-fj/fangjun/sherpa-0.7.tar.gz")
+  # set(ncnn_URL  "file:///ceph-fj/fangjun/open-source/sherpa-ncnn/sherpa-0.7.tar.gz")
   set(ncnn_URL "https://github.com/csukuangfj/ncnn/archive/refs/tags/sherpa-0.7.tar.gz")
 
   set(ncnn_HASH "SHA256=fdf3cc29a43bfb3e2d7cdbbc98a7e69d0a3cc8922b67c47c4c2c8ac28125ae9c")
@@ -17,13 +17,14 @@ function(download_ncnn)
     URL_HASH          ${ncnn_HASH}
   )
 
-  set(NCNN_INSTALL_SDK OFF CACHE BOOL "" FORCE)
   set(NCNN_PIXEL OFF CACHE BOOL "" FORCE)
   set(NCNN_PIXEL_ROTATE OFF CACHE BOOL "" FORCE)
   set(NCNN_PIXEL_AFFINE OFF CACHE BOOL "" FORCE)
   set(NCNN_PIXEL_DRAWING OFF CACHE BOOL "" FORCE)
   set(NCNN_BUILD_BENCHMARK OFF CACHE BOOL "" FORCE)
 
+  set(NCNN_SHARED_LIB ${BUILD_SHARED_LIBS} CACHE BOOL "" FORCE)
+
   set(NCNN_BUILD_TOOLS OFF CACHE BOOL "" FORCE)
   set(NCNN_BUILD_EXAMPLES OFF CACHE BOOL "" FORCE)
   set(NCNN_BUILD_TESTS OFF CACHE BOOL "" FORCE)
@@ -151,7 +152,8 @@ function(download_ncnn)
   message(STATUS "ncnn is downloaded to ${ncnn_SOURCE_DIR}")
   message(STATUS "ncnn's binary dir is ${ncnn_BINARY_DIR}")
 
-  add_subdirectory(${ncnn_SOURCE_DIR} ${ncnn_BINARY_DIR} EXCLUDE_FROM_ALL)
+  add_subdirectory(${ncnn_SOURCE_DIR} ${ncnn_BINARY_DIR})
+  install(TARGETS ncnn DESTINATION lib)
 endfunction()
 
 download_ncnn()

+ 4 - 0
sherpa-ncnn/CMakeLists.txt

@@ -1 +1,5 @@
 add_subdirectory(csrc)
+
+if(DEFINED ANDROID_ABI)
+  add_subdirectory(jni)
+endif()

+ 33 - 11
sherpa-ncnn/csrc/CMakeLists.txt

@@ -12,18 +12,40 @@ set(sherpa_ncnn_core_srcs
 )
 add_library(sherpa-ncnn-core ${sherpa_ncnn_core_srcs})
 target_link_libraries(sherpa-ncnn-core kaldi-native-fbank-core ncnn)
+set_target_properties(sherpa-ncnn-core PROPERTIES VERSION ${SHERPA_NCNN_VERSION} SOVERSION ${SHERPA_NCNN_VERSION_MAJOR})
+install(TARGETS sherpa-ncnn-core DESTINATION lib)
 
-add_executable(sherpa-ncnn sherpa-ncnn.cc)
-target_link_libraries(sherpa-ncnn sherpa-ncnn-core)
+if(NOT DEFINED ANDROID_ABI)
+  add_executable(sherpa-ncnn sherpa-ncnn.cc)
+  target_link_libraries(sherpa-ncnn sherpa-ncnn-core)
+  install(TARGETS sherpa-ncnn DESTINATION bin)
 
-add_executable(sherpa-ncnn-microphone
-  sherpa-ncnn-microphone.cc
-  microphone.cc
-)
-if(BUILD_SHARED_LIBS)
-  set(PA_LIB portaudio)
-else()
-  set(PA_LIB portaudio_static)
+  if(SHERPA_NCNN_ENABLE_PORTAUDIO)
+    add_executable(sherpa-ncnn-microphone
+      sherpa-ncnn-microphone.cc
+      microphone.cc
+    )
+
+    if(BUILD_SHARED_LIBS)
+      set(PA_LIB portaudio)
+    else()
+      set(PA_LIB portaudio_static)
+    endif()
+
+    target_link_libraries(sherpa-ncnn-microphone ${PA_LIB} sherpa-ncnn-core)
+
+    install(TARGETS sherpa-ncnn-microphone DESTINATION bin)
+  endif()
 endif()
 
-target_link_libraries(sherpa-ncnn-microphone ${PA_LIB} sherpa-ncnn-core)
+set(hdrs
+  decode.h
+  features.h
+  model.h
+  symbol-table.h
+  wave-reader.h
+)
+
+install(FILES ${hdrs}
+  DESTINATION include/sherpa-ncnn/csrc
+)

+ 44 - 5
sherpa-ncnn/csrc/conv-emformer-model.cc

@@ -25,6 +25,20 @@ ConvEmformerModel::ConvEmformerModel(const ModelConfig &config)
   InitJoinerInputOutputIndexes();
 }
 
+#if __ANDROID_API__ >= 9
+ConvEmformerModel::ConvEmformerModel(AAssetManager *mgr,
+                                     const ModelConfig &config)
+    : num_threads_(config.num_threads) {
+  InitEncoder(mgr, config.encoder_param, config.encoder_bin);
+  InitDecoder(mgr, config.decoder_param, config.decoder_bin);
+  InitJoiner(mgr, config.joiner_param, config.joiner_bin);
+
+  InitEncoderInputOutputIndexes();
+  InitDecoderInputOutputIndexes();
+  InitJoinerInputOutputIndexes();
+}
+#endif
+
 std::pair<ncnn::Mat, std::vector<ncnn::Mat>> ConvEmformerModel::RunEncoder(
     ncnn::Mat &features, const std::vector<ncnn::Mat> &states) {
   std::vector<ncnn::Mat> _states;
@@ -81,11 +95,7 @@ ncnn::Mat ConvEmformerModel::RunJoiner(ncnn::Mat &encoder_out,
   return joiner_out;
 }
 
-void ConvEmformerModel::InitEncoder(const std::string &encoder_param,
-                                    const std::string &encoder_bin) {
-  RegisterMetaDataLayer(encoder_);
-  InitNet(encoder_, encoder_param, encoder_bin);
-
+void ConvEmformerModel::InitEncoderPostProcessing() {
   // Now load parameters for member variables
   for (const auto *layer : encoder_.layers()) {
     if (layer->type == "SherpaMetaData" && layer->name == "sherpa_meta_data1") {
@@ -107,6 +117,13 @@ void ConvEmformerModel::InitEncoder(const std::string &encoder_param,
   }
 }
 
+void ConvEmformerModel::InitEncoder(const std::string &encoder_param,
+                                    const std::string &encoder_bin) {
+  RegisterMetaDataLayer(encoder_);
+  InitNet(encoder_, encoder_param, encoder_bin);
+  InitEncoderPostProcessing();
+}
+
 void ConvEmformerModel::InitDecoder(const std::string &decoder_param,
                                     const std::string &decoder_bin) {
   InitNet(decoder_, decoder_param, decoder_bin);
@@ -117,6 +134,28 @@ void ConvEmformerModel::InitJoiner(const std::string &joiner_param,
   InitNet(joiner_, joiner_param, joiner_bin);
 }
 
+#if __ANDROID_API__ >= 9
+void ConvEmformerModel::InitEncoder(AAssetManager *mgr,
+                                    const std::string &encoder_param,
+                                    const std::string &encoder_bin) {
+  RegisterMetaDataLayer(encoder_);
+  InitNet(mgr, encoder_, encoder_param, encoder_bin);
+  InitEncoderPostProcessing();
+}
+
+void ConvEmformerModel::InitDecoder(AAssetManager *mgr,
+                                    const std::string &decoder_param,
+                                    const std::string &decoder_bin) {
+  InitNet(mgr, decoder_, decoder_param, decoder_bin);
+}
+
+void ConvEmformerModel::InitJoiner(AAssetManager *mgr,
+                                   const std::string &joiner_param,
+                                   const std::string &joiner_bin) {
+  InitNet(mgr, joiner_, joiner_param, joiner_bin);
+}
+#endif
+
 std::vector<ncnn::Mat> ConvEmformerModel::GetEncoderInitStates() const {
   std::vector<ncnn::Mat> states;
   states.reserve(num_layers_ * 4);

+ 14 - 0
sherpa-ncnn/csrc/conv-emformer-model.h

@@ -17,6 +17,9 @@ namespace sherpa_ncnn {
 class ConvEmformerModel : public Model {
  public:
   explicit ConvEmformerModel(const ModelConfig &config);
+#if __ANDROID_API__ >= 9
+  ConvEmformerModel(AAssetManager *mgr, const ModelConfig &config);
+#endif
 
   std::pair<ncnn::Mat, std::vector<ncnn::Mat>> RunEncoder(
       ncnn::Mat &features, const std::vector<ncnn::Mat> &states) override;
@@ -46,6 +49,17 @@ class ConvEmformerModel : public Model {
   void InitJoiner(const std::string &joiner_param,
                   const std::string &joiner_bin);
 
+  void InitEncoderPostProcessing();
+
+#if __ANDROID_API__ >= 9
+  void InitEncoder(AAssetManager *mgr, const std::string &encoder_param,
+                   const std::string &encoder_bin);
+  void InitDecoder(AAssetManager *mgr, const std::string &decoder_param,
+                   const std::string &decoder_bin);
+  void InitJoiner(AAssetManager *mgr, const std::string &joiner_param,
+                  const std::string &joiner_bin);
+#endif
+
   std::vector<ncnn::Mat> GetEncoderInitStates() const;
 
   void InitEncoderInputOutputIndexes();

+ 1 - 7
sherpa-ncnn/csrc/features.cc

@@ -25,13 +25,7 @@
 
 namespace sherpa_ncnn {
 
-FeatureExtractor::FeatureExtractor() {
-  knf::FbankOptions opts;
-  opts.frame_opts.dither = 0;
-  opts.frame_opts.snip_edges = false;
-  opts.frame_opts.samp_freq = expected_sampling_rate_;
-
-  opts.mel_opts.num_bins = 80;
+FeatureExtractor::FeatureExtractor(const knf::FbankOptions &opts) {
   fbank_ = std::make_unique<knf::OnlineFbank>(opts);
 }
 

+ 1 - 2
sherpa-ncnn/csrc/features.h

@@ -32,7 +32,7 @@ namespace sherpa_ncnn {
 
 class FeatureExtractor {
  public:
-  FeatureExtractor();
+  explicit FeatureExtractor(const knf::FbankOptions &fbank_opts);
 
   /**
      @param sampling_rate The sampling_rate of the input waveform. Should match
@@ -66,7 +66,6 @@ class FeatureExtractor {
  private:
   std::unique_ptr<knf::OnlineFbank> fbank_;
   mutable std::mutex mutex_;
-  float expected_sampling_rate_ = 16000;
 };
 
 }  // namespace sherpa_ncnn

+ 34 - 0
sherpa-ncnn/csrc/lstm-model.cc

@@ -33,6 +33,19 @@ LstmModel::LstmModel(const ModelConfig &config)
   InitJoinerInputOutputIndexes();
 }
 
+#if __ANDROID_API__ >= 9
+LstmModel::LstmModel(AAssetManager *mgr, const ModelConfig &config)
+    : num_threads_(config.num_threads) {
+  InitEncoder(mgr, config.encoder_param, config.encoder_bin);
+  InitDecoder(mgr, config.decoder_param, config.decoder_bin);
+  InitJoiner(mgr, config.joiner_param, config.joiner_bin);
+
+  InitEncoderInputOutputIndexes();
+  InitDecoderInputOutputIndexes();
+  InitJoinerInputOutputIndexes();
+}
+#endif
+
 std::pair<ncnn::Mat, std::vector<ncnn::Mat>> LstmModel::RunEncoder(
     ncnn::Mat &features, const std::vector<ncnn::Mat> &states) {
   ncnn::Mat hx;
@@ -109,6 +122,27 @@ void LstmModel::InitJoiner(const std::string &joiner_param,
   InitNet(joiner_, joiner_param, joiner_bin);
 }
 
+#if __ANDROID_API__ >= 9
+void LstmModel::InitEncoder(AAssetManager *mgr,
+                            const std::string &encoder_param,
+                            const std::string &encoder_bin) {
+  encoder_.opt.use_packing_layout = false;
+  encoder_.opt.use_fp16_storage = false;
+  InitNet(mgr, encoder_, encoder_param, encoder_bin);
+}
+
+void LstmModel::InitDecoder(AAssetManager *mgr,
+                            const std::string &decoder_param,
+                            const std::string &decoder_bin) {
+  InitNet(mgr, decoder_, decoder_param, decoder_bin);
+}
+
+void LstmModel::InitJoiner(AAssetManager *mgr, const std::string &joiner_param,
+                           const std::string &joiner_bin) {
+  InitNet(mgr, joiner_, joiner_param, joiner_bin);
+}
+#endif
+
 std::vector<ncnn::Mat> LstmModel::GetEncoderInitStates() const {
   int32_t num_encoder_layers = 12;
   int32_t d_model = 512;

+ 12 - 0
sherpa-ncnn/csrc/lstm-model.h

@@ -31,6 +31,9 @@ namespace sherpa_ncnn {
 class LstmModel : public Model {
  public:
   explicit LstmModel(const ModelConfig &config);
+#if __ANDROID_API__ >= 9
+  LstmModel(AAssetManager *mgr, const ModelConfig &config);
+#endif
 
   /** Run the encoder network.
    *
@@ -73,6 +76,15 @@ class LstmModel : public Model {
   void InitJoiner(const std::string &joiner_param,
                   const std::string &joiner_bin);
 
+#if __ANDROID_API__ >= 9
+  void InitEncoder(AAssetManager *mgr, const std::string &encoder_param,
+                   const std::string &encoder_bin);
+  void InitDecoder(AAssetManager *mgr, const std::string &decoder_param,
+                   const std::string &decoder_bin);
+  void InitJoiner(AAssetManager *mgr, const std::string &joiner_param,
+                  const std::string &joiner_bin);
+#endif
+
   std::vector<ncnn::Mat> GetEncoderInitStates() const;
 
   void InitEncoderInputOutputIndexes();

+ 39 - 0
sherpa-ncnn/csrc/model.cc

@@ -85,6 +85,21 @@ void Model::InitNet(ncnn::Net &net, const std::string &param,
   }
 }
 
+#if __ANDROID_API__ >= 9
+void Model::InitNet(AAssetManager *mgr, ncnn::Net &net,
+                    const std::string &param, const std::string &bin) {
+  if (net.load_param(mgr, param.c_str())) {
+    NCNN_LOGE("failed to load %s", param.c_str());
+    exit(-1);
+  }
+
+  if (net.load_model(mgr, bin.c_str())) {
+    NCNN_LOGE("failed to load %s", bin.c_str());
+    exit(-1);
+  }
+}
+#endif
+
 std::unique_ptr<Model> Model::Create(const ModelConfig &config) {
   // 1. Load the encoder network
   // 2. If the encoder network has LSTM layers, we assume it is a LstmModel
@@ -112,4 +127,28 @@ std::unique_ptr<Model> Model::Create(const ModelConfig &config) {
   return nullptr;
 }
 
+#if __ANDROID_API__ >= 9
+std::unique_ptr<Model> Model::Create(AAssetManager *mgr,
+                                     const ModelConfig &config) {
+  ncnn::Net net;
+  RegisterMetaDataLayer(net);
+
+  auto ret = net.load_param(mgr, config.encoder_param.c_str());
+  if (ret != 0) {
+    NCNN_LOGE("Failed to load %s", config.encoder_param.c_str());
+    return nullptr;
+  }
+
+  if (IsLstmModel(net)) {
+    return std::make_unique<LstmModel>(mgr, config);
+  }
+
+  if (IsConvEmformerModel(net)) {
+    return std::make_unique<ConvEmformerModel>(mgr, config);
+  }
+
+  return nullptr;
+}
+#endif
+
 }  // namespace sherpa_ncnn

+ 13 - 2
sherpa-ncnn/csrc/model.h

@@ -46,8 +46,10 @@ class Model {
   /** Create a model from a config. */
   static std::unique_ptr<Model> Create(const ModelConfig &config);
 
-  static void InitNet(ncnn::Net &net, const std::string &param,
-                      const std::string &bin);
+#if __ANDROID_API__ >= 9
+  static std::unique_ptr<Model> Create(AAssetManager *mgr,
+                                       const ModelConfig &config);
+#endif
 
   /** Run the encoder network.
    *
@@ -94,6 +96,15 @@ class Model {
   // Advance the feature extractor by this number of frames after
   // running the encoder network
   virtual int32_t Offset() const = 0;
+
+ protected:
+  static void InitNet(ncnn::Net &net, const std::string &param,
+                      const std::string &bin);
+
+#if __ANDROID_API__ >= 9
+  static void InitNet(AAssetManager *mgr, ncnn::Net &net,
+                      const std::string &param, const std::string &bin);
+#endif
 };
 
 }  // namespace sherpa_ncnn

+ 8 - 2
sherpa-ncnn/csrc/sherpa-ncnn-microphone.cc

@@ -95,9 +95,16 @@ https://huggingface.co/csukuangfj/sherpa-ncnn-2022-09-05
     exit(EXIT_FAILURE);
   }
 
+  float sample_rate = 16000;
   sherpa_ncnn::Microphone mic;
 
-  sherpa_ncnn::FeatureExtractor feature_extractor;
+  knf::FbankOptions fbank_opts;
+  fbank_opts.frame_opts.dither = 0;
+  fbank_opts.frame_opts.snip_edges = false;
+  fbank_opts.frame_opts.samp_freq = sample_rate;
+  fbank_opts.mel_opts.num_bins = 80;
+
+  sherpa_ncnn::FeatureExtractor feature_extractor(fbank_opts);
 
   PaDeviceIndex num_devices = Pa_GetDeviceCount();
   fprintf(stderr, "Num devices: %d\n", num_devices);
@@ -120,7 +127,6 @@ https://huggingface.co/csukuangfj/sherpa-ncnn-2022-09-05
 
   param.suggestedLatency = info->defaultLowInputLatency;
   param.hostApiSpecificStreamInfo = nullptr;
-  float sample_rate = 16000;
 
   PaStream *stream;
   PaError err =

+ 13 - 43
sherpa-ncnn/csrc/sherpa-ncnn.cc

@@ -27,47 +27,6 @@
 #include "sherpa-ncnn/csrc/symbol-table.h"
 #include "sherpa-ncnn/csrc/wave-reader.h"
 
-/** Compute fbank features of the input wave filename.
- *
- * @param wav_filename. Path to a mono wave file.
- * @param expected_sampling_rate  Expected sampling rate of the input wave file.
- * @return Return a mat of shape (num_frames, feature_dim).
- *         Note: ans.w == feature_dim; ans.h == num_frames
- *
- */
-static ncnn::Mat ComputeFeatures(const std::string &wav_filename,
-                                 float expected_sampling_rate) {
-  std::vector<float> samples =
-      sherpa_ncnn::ReadWave(wav_filename, expected_sampling_rate);
-
-  float duration = samples.size() / expected_sampling_rate;
-
-  std::cout << "wav filename: " << wav_filename << "\n";
-  std::cout << "wav duration (s): " << duration << "\n";
-
-  knf::FbankOptions opts;
-  opts.frame_opts.dither = 0;
-  opts.frame_opts.snip_edges = false;
-  opts.frame_opts.samp_freq = expected_sampling_rate;
-
-  opts.mel_opts.num_bins = 80;
-
-  knf::OnlineFbank fbank(opts);
-  fbank.AcceptWaveform(expected_sampling_rate, samples.data(), samples.size());
-  fbank.InputFinished();
-
-  int32_t feature_dim = 80;
-  ncnn::Mat features;
-  features.create(feature_dim, fbank.NumFramesReady());
-
-  for (int32_t i = 0; i != fbank.NumFramesReady(); ++i) {
-    const float *f = fbank.GetFrame(i);
-    std::copy(f, f + feature_dim, features.row(i));
-  }
-
-  return features;
-}
-
 int main(int argc, char *argv[]) {
   if (argc < 9 || argc > 10) {
     const char *usage = R"usage(
@@ -119,15 +78,26 @@ https://huggingface.co/csukuangfj/sherpa-ncnn-2022-09-05
     exit(EXIT_FAILURE);
   }
 
+  bool is_ok = false;
   std::vector<float> samples =
-      sherpa_ncnn::ReadWave(wav_filename, expected_sampling_rate);
+      sherpa_ncnn::ReadWave(wav_filename, expected_sampling_rate, &is_ok);
+  if (!is_ok) {
+    fprintf(stderr, "Failed to read %s\n", wav_filename.c_str());
+    exit(-1);
+  }
 
   float duration = samples.size() / expected_sampling_rate;
 
   std::cout << "wav filename: " << wav_filename << "\n";
   std::cout << "wav duration (s): " << duration << "\n";
 
-  sherpa_ncnn::FeatureExtractor feature_extractor;
+  knf::FbankOptions fbank_opts;
+  fbank_opts.frame_opts.dither = 0;
+  fbank_opts.frame_opts.snip_edges = false;
+  fbank_opts.frame_opts.samp_freq = expected_sampling_rate;
+  fbank_opts.mel_opts.num_bins = 80;
+
+  sherpa_ncnn::FeatureExtractor feature_extractor(fbank_opts);
   feature_extractor.AcceptWaveform(expected_sampling_rate, samples.data(),
                                    samples.size());
 

+ 28 - 0
sherpa-ncnn/csrc/symbol-table.cc

@@ -21,11 +21,39 @@
 #include <cassert>
 #include <fstream>
 #include <sstream>
+#include <strstream>
+
+#if __ANDROID_API__ >= 9
+#include "android/asset_manager.h"
+#include "android/asset_manager_jni.h"
+#include "android/log.h"
+#endif
 
 namespace sherpa_ncnn {
 
 SymbolTable::SymbolTable(const std::string &filename) {
   std::ifstream is(filename);
+  Init(is);
+}
+
+#if __ANDROID_API__ >= 9
+SymbolTable::SymbolTable(AAssetManager *mgr, const std::string &filename) {
+  AAsset *asset = AAssetManager_open(mgr, filename.c_str(), AASSET_MODE_BUFFER);
+  if (!asset) {
+    __android_log_print(ANDROID_LOG_FATAL, "sherpa-ncnn",
+                        "SymbolTable: Load %s failed", filename.c_str());
+    exit(-1);
+  }
+
+  auto p = reinterpret_cast<const char *>(AAsset_getBuffer(asset));
+  size_t asset_length = AAsset_getLength(asset);
+  std::istrstream is(p, asset_length);
+  Init(is);
+  AAsset_close(asset);
+}
+#endif
+
+void SymbolTable::Init(std::istream &is) {
   std::string sym;
   int32_t id;
   while (is >> sym >> id) {

+ 12 - 0
sherpa-ncnn/csrc/symbol-table.h

@@ -22,6 +22,11 @@
 #include <string>
 #include <unordered_map>
 
+#if __ANDROID_API__ >= 9
+#include "android/asset_manager.h"
+#include "android/asset_manager_jni.h"
+#endif
+
 namespace sherpa_ncnn {
 
 /// It manages mapping between symbols and integer IDs.
@@ -36,6 +41,10 @@ class SymbolTable {
   /// Fields are separated by space(s).
   explicit SymbolTable(const std::string &filename);
 
+#if __ANDROID_API__ >= 9
+  SymbolTable(AAssetManager *mgr, const std::string &filename);
+#endif
+
   /// Return a string representation of this symbol table
   std::string ToString() const;
 
@@ -50,6 +59,9 @@ class SymbolTable {
   /// Return true if there is a given symbol in the symbol table.
   bool contains(const std::string &sym) const;
 
+ private:
+  void Init(std::istream &is);
+
  private:
   std::unordered_map<std::string, int32_t> sym2id_;
   std::unordered_map<int32_t, std::string> id2sym_;

+ 69 - 28
sherpa-ncnn/csrc/wave-reader.cc

@@ -31,18 +31,44 @@ namespace {
 // Note: We assume little endian here
 // TODO(fangjun): Support big endian
 struct WaveHeader {
-  void Validate() const {
-    //                    F F I R
-    assert(chunk_id == 0x46464952);
-    //                  E V A W
-    assert(format == 0x45564157);
-    assert(subchunk1_id == 0x20746d66);
-    assert(subchunk1_size == 16);  // 16 for PCM
-    assert(audio_format == 1);     // 1 for PCM
-    assert(num_channels == 1);     // we support only single channel for now
-    assert(byte_rate == sample_rate * num_channels * bits_per_sample / 8);
-    assert(block_align == num_channels * bits_per_sample / 8);
-    assert(bits_per_sample == 16);  // we support only 16 bits per sample
+  bool Validate() const {
+    //                 F F I R
+    if (chunk_id != 0x46464952) {
+      return false;
+    }
+    //               E V A W
+    if (format != 0x45564157) {
+      return false;
+    }
+
+    if (subchunk1_id != 0x20746d66) {
+      return false;
+    }
+
+    if (subchunk1_size != 16) {  // 16 for PCM
+      return false;
+    }
+
+    if (audio_format != 1) {  // 1 for PCM
+      return false;
+    }
+
+    if (num_channels != 1) {  // we support only single channel for now
+      return false;
+    }
+    if (byte_rate != (sample_rate * num_channels * bits_per_sample / 8)) {
+      return false;
+    }
+
+    if (block_align != (num_channels * bits_per_sample / 8)) {
+      return false;
+    }
+
+    if (bits_per_sample != 16) {  // we support only 16 bits per sample
+      return false;
+    }
+
+    return true;
   }
 
   // See
@@ -79,46 +105,61 @@ static_assert(sizeof(WaveHeader) == 44, "");
 
 // Read a mono-channel wave file.
 // Return its samples normalized to the range [-1, 1).
-std::vector<float> ReadWaveImpl(std::istream &is, float *sample_rate) {
+std::vector<float> ReadWaveImpl(std::istream &is, float expected_sample_rate,
+                                bool *is_ok) {
   WaveHeader header;
   is.read(reinterpret_cast<char *>(&header), sizeof(header));
-  assert(static_cast<bool>(is));
-  header.Validate();
+  if (!is) {
+    *is_ok = false;
+    return {};
+  }
 
-  header.SeekToDataChunk(is);
+  if (!header.Validate()) {
+    *is_ok = false;
+    return {};
+  }
 
-  assert(static_cast<bool>(is));
+  header.SeekToDataChunk(is);
+  if (!is) {
+    *is_ok = false;
+    return {};
+  }
 
-  *sample_rate = header.sample_rate;
+  if (expected_sample_rate != header.sample_rate) {
+    *is_ok = false;
+    return {};
+  }
 
   // header.subchunk2_size contains the number of bytes in the data.
   // We assume each sample contains two bytes, so it is divided by 2 here
   std::vector<int16_t> samples(header.subchunk2_size / 2);
 
   is.read(reinterpret_cast<char *>(samples.data()), header.subchunk2_size);
-
-  assert(static_cast<bool>(is));
+  if (!is) {
+    *is_ok = false;
+    return {};
+  }
 
   std::vector<float> ans(samples.size());
   for (int32_t i = 0; i != ans.size(); ++i) {
     ans[i] = samples[i] / 32768.;
   }
 
+  *is_ok = true;
   return ans;
 }
 
 }  // namespace
 
 std::vector<float> ReadWave(const std::string &filename,
-                            float expected_sample_rate) {
+                            float expected_sample_rate, bool *is_ok) {
   std::ifstream is(filename, std::ifstream::binary);
-  float sample_rate;
-  auto samples = ReadWaveImpl(is, &sample_rate);
-  if (expected_sample_rate != sample_rate) {
-    std::cerr << "Expected sample rate: " << expected_sample_rate
-              << ". Given: " << sample_rate << ".\n";
-    exit(-1);
-  }
+  return ReadWave(is, expected_sample_rate, is_ok);
+}
+
+std::vector<float> ReadWave(std::istream &is, float expected_sample_rate,
+                            bool *is_ok) {
+  auto samples = ReadWaveImpl(is, expected_sample_rate, is_ok);
   return samples;
 }
 

+ 5 - 1
sherpa-ncnn/csrc/wave-reader.h

@@ -30,11 +30,15 @@ namespace sherpa_ncnn {
     @param filename Path to a wave file. It MUST be single channel, PCM encoded.
     @param expected_sample_rate  Expected sample rate of the wave file. If the
                                sample rate doesn't match, *is_ok is set to
                                false and an empty vector is returned.
+    @param is_ok On return it is true if the reading succeeded; false otherwise.
 
     @return Return wave samples normalized to the range [-1, 1).
  */
 std::vector<float> ReadWave(const std::string &filename,
-                            float expected_sample_rate);
+                            float expected_sample_rate, bool *is_ok);
+
+std::vector<float> ReadWave(std::istream &is, float expected_sample_rate,
+                            bool *is_ok);
 
 }  // namespace sherpa_ncnn
 

+ 5 - 0
sherpa-ncnn/jni/CMakeLists.txt

@@ -0,0 +1,5 @@
+# JNI bindings. jni.cc includes project headers as "sherpa-ncnn/csrc/...", so
+# the project root is added as a private include directory of this target
+# only (PROJECT_SOURCE_DIR stays correct when built as a subproject).
+add_library(sherpa-ncnn-jni jni.cc)
+target_include_directories(sherpa-ncnn-jni PRIVATE "${PROJECT_SOURCE_DIR}")
+target_link_libraries(sherpa-ncnn-jni PRIVATE sherpa-ncnn-core)
+install(TARGETS sherpa-ncnn-jni DESTINATION lib)

+ 352 - 0
sherpa-ncnn/jni/jni.cc

@@ -0,0 +1,352 @@
+/**
+ * Copyright (c)  2022  Xiaomi Corporation (authors: Fangjun Kuang)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// TODO(fangjun): Add documentation to functions/methods in this file
+// and also show how to use them with kotlin, possibly with java.
+
+// If you use ndk, you can find "jni.h" inside
+// android-ndk/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/include
+#include "jni.h"  // NOLINT
+
+#include <strstream>
+
+#include "android/asset_manager.h"
+#include "android/asset_manager_jni.h"
+#include "sherpa-ncnn/csrc/decode.h"
+#include "sherpa-ncnn/csrc/features.h"
+#include "sherpa-ncnn/csrc/model.h"
+#include "sherpa-ncnn/csrc/symbol-table.h"
+#include "sherpa-ncnn/csrc/wave-reader.h"
+
+#define SHERPA_EXTERN_C extern "C"
+
+namespace sherpa_ncnn {
+
+// Owns a model, a feature extractor and a symbol table, and drives greedy
+// search decoding over the audio samples that are fed to it.
+class SherpaNcnn {
+ public:
+  // Creates the model and the symbol table through `mgr`, then runs the
+  // decoder once on a context filled with the blank id (0) so that
+  // decoder_out_ is valid before any audio arrives.
+  SherpaNcnn(AAssetManager *mgr, const ModelConfig &model_config,
+             const knf::FbankOptions &fbank_config, const std::string &tokens)
+      : model_(Model::Create(mgr, model_config)),
+        feature_extractor_(fbank_config),
+        sym_(mgr, tokens) {
+    const int32_t kBlankId = 0;
+    const int32_t num_context = model_->ContextSize();
+
+    ncnn::Mat blank_context(num_context);
+    for (int32_t k = 0; k != num_context; ++k) {
+      static_cast<int32_t *>(blank_context)[k] = kBlankId;
+    }
+    decoder_out_ = model_->RunDecoder(blank_context);
+
+    // The hypothesis starts with `num_context` zero entries; GetText() skips
+    // exactly that many leading entries.
+    hyp_.resize(num_context, 0);
+  }
+
+  // Feeds `n` samples to the feature extractor and decodes whatever frames
+  // became available.
+  void DecodeSamples(float sample_rate, const float *samples, int32_t n) {
+    feature_extractor_.AcceptWaveform(sample_rate, samples, n);
+    Decode();
+  }
+
+  // Tells the feature extractor that no more samples will come and decodes
+  // the frames that are ready afterwards.
+  void InputFinished() {
+    feature_extractor_.InputFinished();
+    Decode();
+  }
+
+  // Returns the concatenation of the symbols of all entries of hyp_ that
+  // follow the first ContextSize() entries.
+  std::string GetText() const {
+    std::string result;
+
+    int32_t pos = model_->ContextSize();
+    const int32_t num_tokens = static_cast<int32_t>(hyp_.size());
+    while (pos != num_tokens) {
+      result += sym_[hyp_[pos]];
+      ++pos;
+    }
+    return result;
+  }
+
+ private:
+  // Runs encoder + greedy search for as long as at least Segment() unprocessed
+  // frames are ready; each round advances num_processed_ by Offset().
+  void Decode() {
+    const int32_t segment_len = model_->Segment();
+    const int32_t step = model_->Offset();
+
+    ncnn::Mat encoder_out;
+    while (feature_extractor_.NumFramesReady() - num_processed_ >=
+           segment_len) {
+      ncnn::Mat chunk = feature_extractor_.GetFrames(num_processed_,
+                                                     segment_len);
+      num_processed_ += step;
+
+      std::tie(encoder_out, states_) = model_->RunEncoder(chunk, states_);
+
+      GreedySearch(model_.get(), encoder_out, &decoder_out_, &hyp_);
+    }
+  }
+
+  std::unique_ptr<Model> model_;
+  FeatureExtractor feature_extractor_;
+  sherpa_ncnn::SymbolTable sym_;
+
+  // Decoding state carried across calls to Decode().
+  std::vector<int32_t> hyp_;
+  ncnn::Mat decoder_out_;
+  std::vector<ncnn::Mat> states_;
+
+  // number of processed frames
+  int32_t num_processed_ = 0;
+};
+
+// Reads the java.lang.String field `name` of `obj` (whose class is `cls`)
+// into a std::string.
+//
+// Returns an empty string when the field holds null or when
+// GetStringUTFChars() fails; constructing a std::string from a null pointer
+// would be undefined behavior.
+static std::string GetStringField(JNIEnv *env, jobject obj, jclass cls,
+                                  const char *name) {
+  jfieldID fid = env->GetFieldID(cls, name, "Ljava/lang/String;");
+  jstring s = static_cast<jstring>(env->GetObjectField(obj, fid));
+  if (s == nullptr) {
+    return std::string();
+  }
+
+  const char *p = env->GetStringUTFChars(s, nullptr);
+  if (p == nullptr) {
+    return std::string();
+  }
+
+  std::string ans(p);
+  env->ReleaseStringUTFChars(s, p);
+  return ans;
+}
+
+// Builds a ModelConfig from the fields of the Java/Kotlin object `config`.
+//
+// The string fields encoderParam, encoderBin, decoderParam, decoderBin,
+// joinerParam and joinerBin and the int field numThreads are read through
+// JNI. A null string field yields an empty string.
+static ModelConfig GetModelConfig(JNIEnv *env, jobject config) {
+  ModelConfig model_config;
+
+  jclass cls = env->GetObjectClass(config);
+
+  model_config.encoder_param = GetStringField(env, config, cls, "encoderParam");
+  model_config.encoder_bin = GetStringField(env, config, cls, "encoderBin");
+
+  model_config.decoder_param = GetStringField(env, config, cls, "decoderParam");
+  model_config.decoder_bin = GetStringField(env, config, cls, "decoderBin");
+
+  model_config.joiner_param = GetStringField(env, config, cls, "joinerParam");
+  model_config.joiner_bin = GetStringField(env, config, cls, "joinerBin");
+
+  jfieldID fid = env->GetFieldID(cls, "numThreads", "I");
+  model_config.num_threads = env->GetIntField(config, fid);
+
+  return model_config;
+}
+
+// Builds a knf::FbankOptions from the fields of the Java/Kotlin object `opts`.
+//
+// The top-level fields are read from `opts`; the nested objects stored in its
+// `frame_opts` and `mel_opts` fields supply fbank_opts.frame_opts and
+// fbank_opts.mel_opts. If a nested object (or the `window_type` string) is
+// null, an error is logged and the corresponding values keep whatever the
+// default-constructed knf::FbankOptions holds, instead of passing null to JNI.
+static knf::FbankOptions GetFbankOptions(JNIEnv *env, jobject opts) {
+  // https://docs.oracle.com/javase/7/docs/technotes/guides/jni/spec/types.html
+  // https://courses.cs.washington.edu/courses/cse341/99wi/java/tutorial/native1.1/implementing/field.html
+
+  // Typed field readers. "Z", "F" and "I" are the JNI type signatures of
+  // boolean, float and int.
+  auto get_bool = [env](jobject obj, jclass c, const char *name) -> jboolean {
+    return env->GetBooleanField(obj, env->GetFieldID(c, name, "Z"));
+  };
+  auto get_float = [env](jobject obj, jclass c, const char *name) -> jfloat {
+    return env->GetFloatField(obj, env->GetFieldID(c, name, "F"));
+  };
+  auto get_int = [env](jobject obj, jclass c, const char *name) -> jint {
+    return env->GetIntField(obj, env->GetFieldID(c, name, "I"));
+  };
+
+  knf::FbankOptions fbank_opts;
+  jclass cls = env->GetObjectClass(opts);
+
+  fbank_opts.use_energy = get_bool(opts, cls, "use_energy");
+  fbank_opts.energy_floor = get_float(opts, cls, "energy_floor");
+  fbank_opts.raw_energy = get_bool(opts, cls, "raw_energy");
+  fbank_opts.htk_compat = get_bool(opts, cls, "htk_compat");
+  fbank_opts.use_log_fbank = get_bool(opts, cls, "use_log_fbank");
+  fbank_opts.use_power = get_bool(opts, cls, "use_power");
+
+  // Nested frame extraction options.
+  jfieldID fid = env->GetFieldID(
+      cls, "frame_opts", "Lcom/k2fsa/sherpa/ncnn/FrameExtractionOptions;");
+  jobject frame_opts = env->GetObjectField(opts, fid);
+  if (frame_opts == nullptr) {
+    NCNN_LOGE("frame_opts is null. Using default frame extraction options");
+  } else {
+    jclass frame_opts_cls = env->GetObjectClass(frame_opts);
+    auto &fo = fbank_opts.frame_opts;
+
+    fo.samp_freq = get_float(frame_opts, frame_opts_cls, "samp_freq");
+    fo.frame_shift_ms = get_float(frame_opts, frame_opts_cls, "frame_shift_ms");
+    fo.frame_length_ms =
+        get_float(frame_opts, frame_opts_cls, "frame_length_ms");
+    fo.dither = get_float(frame_opts, frame_opts_cls, "dither");
+    fo.preemph_coeff = get_float(frame_opts, frame_opts_cls, "preemph_coeff");
+    fo.remove_dc_offset =
+        get_bool(frame_opts, frame_opts_cls, "remove_dc_offset");
+
+    fid = env->GetFieldID(frame_opts_cls, "window_type", "Ljava/lang/String;");
+    jstring window_type =
+        static_cast<jstring>(env->GetObjectField(frame_opts, fid));
+    if (window_type == nullptr) {
+      NCNN_LOGE("window_type is null. Using the default window type");
+    } else {
+      const char *p_window_type = env->GetStringUTFChars(window_type, nullptr);
+      if (p_window_type != nullptr) {
+        fo.window_type = p_window_type;
+        env->ReleaseStringUTFChars(window_type, p_window_type);
+      }
+    }
+
+    fo.round_to_power_of_two =
+        get_bool(frame_opts, frame_opts_cls, "round_to_power_of_two");
+    fo.blackman_coeff = get_float(frame_opts, frame_opts_cls, "blackman_coeff");
+    fo.snip_edges = get_bool(frame_opts, frame_opts_cls, "snip_edges");
+    fo.max_feature_vectors =
+        get_int(frame_opts, frame_opts_cls, "max_feature_vectors");
+  }
+
+  // Nested mel bank options.
+  fid = env->GetFieldID(cls, "mel_opts",
+                        "Lcom/k2fsa/sherpa/ncnn/MelBanksOptions;");
+  jobject mel_opts = env->GetObjectField(opts, fid);
+  if (mel_opts == nullptr) {
+    NCNN_LOGE("mel_opts is null. Using default mel bank options");
+  } else {
+    jclass mel_opts_cls = env->GetObjectClass(mel_opts);
+    auto &mo = fbank_opts.mel_opts;
+
+    mo.num_bins = get_int(mel_opts, mel_opts_cls, "num_bins");
+    mo.low_freq = get_float(mel_opts, mel_opts_cls, "low_freq");
+    mo.high_freq = get_float(mel_opts, mel_opts_cls, "high_freq");
+    mo.vtln_low = get_float(mel_opts, mel_opts_cls, "vtln_low");
+    mo.vtln_high = get_float(mel_opts, mel_opts_cls, "vtln_high");
+    mo.debug_mel = get_bool(mel_opts, mel_opts_cls, "debug_mel");
+    mo.htk_mode = get_bool(mel_opts, mel_opts_cls, "htk_mode");
+  }
+
+  return fbank_opts;
+}
+
+}  // namespace sherpa_ncnn
+
+// Creates a sherpa_ncnn::SherpaNcnn from the given asset manager, model
+// config, fbank config and tokens path, and returns its address as a jlong
+// handle.
+//
+// Returns 0 (and logs an error) when the native asset manager or the tokens
+// string cannot be obtained, rather than constructing the object from a null
+// pointer.
+SHERPA_EXTERN_C
+JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_ncnn_SherpaNcnn_new(
+    JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _model_config,
+    jobject _fbank_config, jstring tokens) {
+  AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
+  if (!mgr) {
+    NCNN_LOGE("Failed to get asset manager: %p", mgr);
+    return 0;
+  }
+
+  if (!tokens) {
+    NCNN_LOGE("tokens is null");
+    return 0;
+  }
+
+  sherpa_ncnn::ModelConfig model_config =
+      sherpa_ncnn::GetModelConfig(env, _model_config);
+
+  knf::FbankOptions fbank_opts =
+      sherpa_ncnn::GetFbankOptions(env, _fbank_config);
+
+  const char *p_tokens = env->GetStringUTFChars(tokens, nullptr);
+  if (!p_tokens) {
+    // GetStringUTFChars() failed; an exception is pending in the JVM.
+    NCNN_LOGE("Failed to read the tokens string");
+    return 0;
+  }
+
+  auto model =
+      new sherpa_ncnn::SherpaNcnn(mgr, model_config, fbank_opts, p_tokens);
+  env->ReleaseStringUTFChars(tokens, p_tokens);
+
+  return (jlong)model;
+}
+
+// Destroys the sherpa_ncnn::SherpaNcnn object whose address is stored in
+// `ptr` (presumably a handle obtained from the matching `new` entry point).
+SHERPA_EXTERN_C
+JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_ncnn_SherpaNcnn_delete(
+    JNIEnv * /*env*/, jobject /*obj*/, jlong ptr) {
+  auto *self = reinterpret_cast<sherpa_ncnn::SherpaNcnn *>(ptr);
+  delete self;
+}
+
+// Passes the contents of the Java float array `samples` to
+// SherpaNcnn::DecodeSamples() on the object addressed by `ptr`.
+//
+// If the array elements cannot be obtained, an error is logged and the call
+// returns without decoding.
+SHERPA_EXTERN_C
+JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_ncnn_SherpaNcnn_decodeSamples(
+    JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples,
+    jfloat sample_rate) {
+  auto model = reinterpret_cast<sherpa_ncnn::SherpaNcnn *>(ptr);
+
+  jfloat *p = env->GetFloatArrayElements(samples, nullptr);
+  if (p == nullptr) {
+    // GetFloatArrayElements() returns NULL on failure; there is nothing to
+    // decode and nothing to release.
+    NCNN_LOGE("Failed to get the elements of the samples array");
+    return;
+  }
+  jsize n = env->GetArrayLength(samples);
+
+  model->DecodeSamples(sample_rate, p, n);
+
+  // JNI_ABORT: the samples were only read, so no copy-back is needed.
+  env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);
+}
+
+// Forwards to SherpaNcnn::InputFinished() on the object addressed by `ptr`.
+SHERPA_EXTERN_C
+JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_ncnn_SherpaNcnn_inputFinished(
+    JNIEnv * /*env*/, jobject /*obj*/, jlong ptr) {
+  auto *self = reinterpret_cast<sherpa_ncnn::SherpaNcnn *>(ptr);
+  self->InputFinished();
+}
+
+// Returns the result of SherpaNcnn::GetText() for the object addressed by
+// `ptr`, converted to a Java string.
+//
+// For passing a C++ string to Java, see
+// https://stackoverflow.com/questions/11621449/send-c-string-to-java-via-jni
+SHERPA_EXTERN_C
+JNIEXPORT jstring JNICALL Java_com_k2fsa_sherpa_ncnn_SherpaNcnn_getText(
+    JNIEnv *env, jobject /*obj*/, jlong ptr) {
+  auto *self = reinterpret_cast<sherpa_ncnn::SherpaNcnn *>(ptr);
+  const std::string text = self->GetText();
+  return env->NewStringUTF(text.c_str());
+}
+
+// Reads the wave file stored as asset `filename` and returns its samples as
+// a Java float array.
+//
+// Returns nullptr when the asset manager or the filename cannot be obtained,
+// the asset cannot be opened or fully read, sherpa_ncnn::ReadWave() reports
+// failure through is_ok, or the result array cannot be allocated.
+SHERPA_EXTERN_C
+JNIEXPORT jfloatArray JNICALL
+Java_com_k2fsa_sherpa_ncnn_WaveReader_00024Companion_readWave(
+    JNIEnv *env, jclass /*cls*/, jobject asset_manager, jstring filename,
+    jfloat expected_sample_rate) {
+  AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
+  if (!mgr) {
+    NCNN_LOGE("Failed to get asset manager: %p", mgr);
+    return nullptr;
+  }
+
+  const char *p_filename = env->GetStringUTFChars(filename, nullptr);
+  if (!p_filename) {
+    NCNN_LOGE("Failed to read the filename string");
+    return nullptr;
+  }
+
+  AAsset *asset = AAssetManager_open(mgr, p_filename, AASSET_MODE_BUFFER);
+  if (!asset) {
+    NCNN_LOGE("Failed to open asset: %s", p_filename);
+  }
+  // The filename is not needed once the asset has been opened (or not).
+  env->ReleaseStringUTFChars(filename, p_filename);
+  if (!asset) {
+    return nullptr;
+  }
+
+  // Keep the signed result so that an error value is not turned into a huge
+  // unsigned size.
+  const auto asset_length = AAsset_getLength(asset);
+  if (asset_length <= 0) {
+    NCNN_LOGE("Invalid asset length");
+    AAsset_close(asset);
+    return nullptr;
+  }
+
+  // AAsset_read() may return fewer bytes than requested, so keep reading
+  // until the whole asset is in the buffer or a read fails.
+  std::vector<char> buffer(static_cast<size_t>(asset_length));
+  size_t num_read = 0;
+  while (num_read < buffer.size()) {
+    int n = AAsset_read(asset, buffer.data() + num_read,
+                        buffer.size() - num_read);
+    if (n <= 0) {
+      break;
+    }
+    num_read += static_cast<size_t>(n);
+  }
+  AAsset_close(asset);
+
+  if (num_read != buffer.size()) {
+    NCNN_LOGE("Failed to read the whole asset");
+    return nullptr;
+  }
+
+  std::istrstream is(buffer.data(), buffer.size());
+
+  bool is_ok = false;
+  std::vector<float> samples =
+      sherpa_ncnn::ReadWave(is, expected_sample_rate, &is_ok);
+
+  if (!is_ok) {
+    return nullptr;
+  }
+
+  jfloatArray ans = env->NewFloatArray(static_cast<jsize>(samples.size()));
+  if (!ans) {
+    // Allocation failed; an exception is pending in the JVM.
+    return nullptr;
+  }
+  env->SetFloatArrayRegion(ans, 0, static_cast<jsize>(samples.size()),
+                           samples.data());
+  return ans;
+}