Эх сурвалжийг харах

Add Python API (#47)

* Begin to wrap sherpa-ncnn to Python

* Add Python API

* build shared libs by default

* Add Recognizer API
Fangjun Kuang 2 жил өмнө
parent
commit
2b5300e975

+ 4 - 0
.gitignore

@@ -1,2 +1,6 @@
 build
 build-*/
+__pycache__
+sherpa_ncnn.egg-info/
+run*.sh
+dist/

+ 10 - 7
CMakeLists.txt

@@ -3,7 +3,7 @@ project(sherpa-ncnn)
 
 set(SHERPA_NCNN_VERSION_MAJOR "1")
 set(SHERPA_NCNN_VERSION_MINOR "0")
-set(SHERPA_NCNN_VERSION "${SHERPA_NCNN_VERSION_MAJOR}.${SHERPA_NCNN_VERSION_MINOR}")
+set(SHERPA_NCNN_VERSION "1.0")
 
 set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
 set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
@@ -23,7 +23,11 @@ set(CMAKE_INSTALL_RPATH ${SHERPA_NCNN_RPATH_ORIGIN})
 set(CMAKE_BUILD_RPATH ${SHERPA_NCNN_RPATH_ORIGIN})
 
 option(BUILD_SHARED_LIBS "Whether to build shared libraries" OFF)
+option(SHERPA_NCNN_ENABLE_PYTHON "Whether to build Python" OFF)
+option(SHERPA_NCNN_ENABLE_PORTAUDIO "Whether to build with portaudio" ON)
+
 message(STATUS "BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS}")
+message(STATUS "SHERPA_NCNN_ENABLE_PYTHON ${SHERPA_NCNN_ENABLE_PYTHON}")
 
 if(NOT CMAKE_BUILD_TYPE)
   message(STATUS "No CMAKE_BUILD_TYPE given, default to Release")
@@ -34,14 +38,8 @@ message(STATUS "CMAKE_BUILD_TYPE: ${CMAKE_BUILD_TYPE}")
 set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.")
 set(CMAKE_CXX_EXTENSIONS OFF)
 
-list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules)
 list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake)
 
-set(CMAKE_CXX_STANDARD 11 CACHE STRING "The C++ version to be used.")
-set(CMAKE_CXX_EXTENSIONS OFF)
-
-option(SHERPA_NCNN_ENABLE_PORTAUDIO "Whether to build with portaudio" ON)
-
 include(kaldi-native-fbank)
 include(ncnn)
 
@@ -49,4 +47,9 @@ if(SHERPA_NCNN_ENABLE_PORTAUDIO)
   include(portaudio)
 endif()
 
+if(SHERPA_NCNN_ENABLE_PYTHON)
+  include(pybind11)
+endif()
+
+
 add_subdirectory(sherpa-ncnn)

+ 6 - 0
MANIFEST.in

@@ -0,0 +1,6 @@
+include LICENSE
+include README.md
+include CMakeLists.txt
+exclude pyproject.toml
+recursive-include sherpa-ncnn *.*
+recursive-include cmake *.*

+ 0 - 0
cmake/__init__.py


+ 133 - 0
cmake/cmake_extension.py

@@ -0,0 +1,133 @@
+# Copyright (c)  2021-2022  Xiaomi Corporation (author: Fangjun Kuang)
+# flake8: noqa
+
+import os
+import platform
+import shutil
+import sys
+from pathlib import Path
+
+import setuptools
+from setuptools.command.build_ext import build_ext
+
+
+def is_for_pypi():
+    ans = os.environ.get("SHERPA_NCNN_IS_FOR_PYPI", None)
+    return ans is not None
+
+
+def is_macos():
+    return platform.system() == "Darwin"
+
+
+def is_windows():
+    return platform.system() == "Windows"
+
+
+try:
+    from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
+
+    class bdist_wheel(_bdist_wheel):
+        def finalize_options(self):
+            _bdist_wheel.finalize_options(self)
+            # In this case, the generated wheel has a name in the form
+            # sherpa-xxx-pyxx-none-any.whl
+            if is_for_pypi() and not is_macos():
+                self.root_is_pure = True
+            else:
+                # The generated wheel has a name ending with
+                # -linux_x86_64.whl
+                self.root_is_pure = False
+
+
+except ImportError:
+    bdist_wheel = None
+
+
+def cmake_extension(name, *args, **kwargs) -> setuptools.Extension:
+    kwargs["language"] = "c++"
+    sources = []
+    return setuptools.Extension(name, sources, *args, **kwargs)
+
+
+class BuildExtension(build_ext):
+    def build_extension(self, ext: setuptools.extension.Extension):
+        # build/temp.linux-x86_64-3.8
+        os.makedirs(self.build_temp, exist_ok=True)
+
+        # build/lib.linux-x86_64-3.8
+        os.makedirs(self.build_lib, exist_ok=True)
+
+        out_bin_dir = Path(self.build_lib).parent / "bin"
+        install_dir = Path(self.build_lib).resolve()
+
+        sherpa_ncnn_dir = Path(__file__).parent.parent.resolve()
+
+        cmake_args = os.environ.get("SHERPA_NCNN_CMAKE_ARGS", "")
+        make_args = os.environ.get("SHERPA_NCNN_MAKE_ARGS", "")
+        system_make_args = os.environ.get("MAKEFLAGS", "")
+
+        if cmake_args == "":
+            cmake_args = "-DCMAKE_BUILD_TYPE=Release"
+
+        extra_cmake_args = f" -DCMAKE_INSTALL_PREFIX={install_dir} "
+        extra_cmake_args += f" -DBUILD_SHARED_LIBS=ON "
+        extra_cmake_args += f" -DSHERPA_NCNN_ENABLE_PYTHON=ON "
+        extra_cmake_args += f" -DSHERPA_NCNN_ENABLE_PORTAUDIO=ON "
+
+        if "PYTHON_EXECUTABLE" not in cmake_args:
+            print(f"Setting PYTHON_EXECUTABLE to {sys.executable}")
+            cmake_args += f" -DPYTHON_EXECUTABLE={sys.executable}"
+
+        cmake_args += extra_cmake_args
+
+        if is_windows():
+            build_cmd = f"""
+         cmake {cmake_args} -B {self.build_temp} -S {sherpa_ncnn_dir}
+         cmake --build {self.build_temp} --target install --config Release -- -m
+            """
+            print(f"build command is:\n{build_cmd}")
+            ret = os.system(
+                f"cmake {cmake_args} -B {self.build_temp} -S {sherpa_ncnn_dir}"
+            )
+            if ret != 0:
+                raise Exception("Failed to configure sherpa")
+
+            ret = os.system(
+                f"cmake --build {self.build_temp} --target install --config Release -- -m"  # noqa
+            )
+            if ret != 0:
+                raise Exception("Failed to build and install sherpa")
+        else:
+            if make_args == "" and system_make_args == "":
+                print("for fast compilation, run:")
+                print('export SHERPA_NCNN_MAKE_ARGS="-j"; python setup.py install')
+                print('Setting make_args to "-j4"')
+                make_args = "-j4"
+
+            build_cmd = f"""
+                cd {self.build_temp}
+
+                cmake {cmake_args} {sherpa_ncnn_dir}
+
+                make {make_args} install/strip
+            """
+            print(f"build command is:\n{build_cmd}")
+
+            ret = os.system(build_cmd)
+            if ret != 0:
+                raise Exception(
+                    "\nBuild sherpa-ncnn failed. Please check the error message.\n"
+                    "You can ask for help by creating an issue on GitHub.\n"
+                    "\nClick:\n\thttps://github.com/k2-fsa/sherpa-ncnn/issues/new\n"  # noqa
+                )
+
+        suffix = ".exe" if is_windows() else ""
+        # Remember to also change setup.py
+        binaries = ["sherpa-ncnn"]
+        binaries += ["sherpa-ncnn-microphone"]
+
+        for f in binaries:
+            src_file = install_dir / "bin" / (f + suffix)
+            print(f"Copying {src_file} to {out_bin_dir}/")
+            shutil.copy(f"{src_file}", f"{out_bin_dir}/")

+ 21 - 0
cmake/pybind11.cmake

@@ -0,0 +1,21 @@
+function(download_pybind11)
+  include(FetchContent)
+
+  set(pybind11_URL  "https://github.com/pybind/pybind11/archive/refs/tags/v2.10.2.tar.gz")
+  set(pybind11_HASH "SHA256=93bd1e625e43e03028a3ea7389bba5d3f9f2596abc074b068e70f4ef9b1314ae")
+
+  FetchContent_Declare(pybind11
+    URL               ${pybind11_URL}
+    URL_HASH          ${pybind11_HASH}
+  )
+
+  FetchContent_GetProperties(pybind11)
+  if(NOT pybind11_POPULATED)
+    message(STATUS "Downloading pybind11 from ${pybind11_URL}")
+    FetchContent_Populate(pybind11)
+  endif()
+  message(STATUS "pybind11 is downloaded to ${pybind11_SOURCE_DIR}")
+  add_subdirectory(${pybind11_SOURCE_DIR} ${pybind11_BINARY_DIR} EXCLUDE_FROM_ALL)
+endfunction()
+
+download_pybind11()

+ 84 - 0
setup.py

@@ -0,0 +1,84 @@
+#!/usr/bin/env python3
+
+import os
+import re
+import sys
+from pathlib import Path
+
+import setuptools
+
+from cmake.cmake_extension import (
+    BuildExtension,
+    bdist_wheel,
+    cmake_extension,
+    is_windows,
+)
+
+
+def read_long_description():
+    with open("README.md", encoding="utf8") as f:
+        readme = f.read()
+    return readme
+
+
+def get_package_version():
+    with open("CMakeLists.txt") as f:
+        content = f.read()
+
+    match = re.search(r"set\(SHERPA_NCNN_VERSION (.*)\)", content)
+    latest_version = match.group(1).strip('"')
+    return latest_version
+
+
+def get_binaries_to_install():
+    bin_dir = Path("build") / "bin"
+    bin_dir.mkdir(parents=True, exist_ok=True)
+    suffix = ".exe" if is_windows() else ""
+    # Remember to also change cmake/cmake_extension.py
+    binaries = ["sherpa-ncnn"]
+    binaries += ["sherpa-ncnn-microphone"]
+    exe = []
+    for f in binaries:
+        t = bin_dir / (f + suffix)
+        exe.append(str(t))
+    return exe
+
+
+package_name = "sherpa-ncnn"
+
+with open("sherpa-ncnn/python/sherpa_ncnn/__init__.py", "a") as f:
+    f.write(f"__version__ = '{get_package_version()}'\n")
+
+setuptools.setup(
+    name=package_name,
+    version=get_package_version(),
+    author="The sherpa-ncnn development team",
+    author_email="dpovey@gmail.com",
+    package_dir={
+        "sherpa_ncnn": "sherpa-ncnn/python/sherpa_ncnn",
+    },
+    data_files=[("bin", get_binaries_to_install())],
+    packages=["sherpa_ncnn"],
+    url="https://github.com/k2-fsa/sherpa-ncnn",
+    long_description=read_long_description(),
+    long_description_content_type="text/markdown",
+    ext_modules=[cmake_extension("_sherpa_ncnn")],
+    cmdclass={"build_ext": BuildExtension, "bdist_wheel": bdist_wheel},
+    zip_safe=False,
+    classifiers=[
+        "Programming Language :: C++",
+        "Programming Language :: Python",
+        "Topic :: Scientific/Engineering :: Artificial Intelligence",
+    ],
+    license="Apache licensed, as found in the LICENSE file",
+)
+
+with open("sherpa-ncnn/python/sherpa_ncnn/__init__.py", "r") as f:
+    lines = f.readlines()
+
+with open("sherpa-ncnn/python/sherpa_ncnn/__init__.py", "w") as f:
+    for line in lines:
+        if "__version__" in line:
+            # skip __version__ = "x.x.x"
+            continue
+        f.write(line)

+ 4 - 0
sherpa-ncnn/CMakeLists.txt

@@ -3,3 +3,7 @@ add_subdirectory(csrc)
 if(DEFINED ANDROID_ABI)
   add_subdirectory(jni)
 endif()
+
+if(SHERPA_NCNN_ENABLE_PYTHON)
+  add_subdirectory(python)
+endif()

+ 2 - 0
sherpa-ncnn/python/CMakeLists.txt

@@ -0,0 +1,2 @@
+add_subdirectory(csrc)
+# add_subdirectory(tests)

+ 13 - 0
sherpa-ncnn/python/csrc/CMakeLists.txt

@@ -0,0 +1,13 @@
+
+include_directories(${PROJECT_SOURCE_DIR})
+set(srcs
+  decode.cc
+  features.cc
+  mat-util.cc
+  model.cc
+  sherpa-ncnn.cc
+)
+
+pybind11_add_module(_sherpa_ncnn ${srcs})
+target_link_libraries(_sherpa_ncnn PRIVATE sherpa-ncnn-core)
+target_link_libraries(_sherpa_ncnn PRIVATE ncnn)

+ 46 - 0
sherpa-ncnn/python/csrc/decode.cc

@@ -0,0 +1,46 @@
+/**
+ * Copyright (c)  2022  Xiaomi Corporation (authors: Fangjun Kuang)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "sherpa-ncnn/python/csrc/decode.h"
+
+#include "sherpa-ncnn/csrc/decode.h"
+#include "sherpa-ncnn/csrc/model.h"
+#include "sherpa-ncnn/python/csrc/mat-util.h"
+
+namespace sherpa_ncnn {
+
+static void PybindGreedySearch(py::module *m) {
+  m->def(
+      "greedy_search",
+      [](Model *model, py::array _encoder_out, py::array _decoder_out,
+         std::vector<int32_t> hyp)
+          -> std::pair<py::array, std::vector<int32_t>> {
+        ncnn::Mat encoder_out = ArrayToMat(_encoder_out);
+        ncnn::Mat decoder_out = ArrayToMat(_decoder_out);
+
+        GreedySearch(model, encoder_out, &decoder_out, &hyp);
+
+        return {MatToArray(decoder_out), hyp};
+      },
+      py::arg("model"), py::arg("encoder_out"), py::arg("decoder_out"),
+      py::arg("hyp"));
+}
+
+void PybindDecode(py::module *m) { PybindGreedySearch(m); }
+
+}  // namespace sherpa_ncnn

+ 30 - 0
sherpa-ncnn/python/csrc/decode.h

@@ -0,0 +1,30 @@
+/**
+ * Copyright (c)  2022  Xiaomi Corporation (authors: Fangjun Kuang)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef SHERPA_NCNN_PYTHON_CSRC_DECODE_H_
+#define SHERPA_NCNN_PYTHON_CSRC_DECODE_H_
+
+#include "sherpa-ncnn/python/csrc/sherpa-ncnn.h"
+
+namespace sherpa_ncnn {
+
+void PybindDecode(py::module *m);
+
+}  // namespace sherpa_ncnn
+
+#endif  // SHERPA_NCNN_PYTHON_CSRC_DECODE_H_

+ 56 - 0
sherpa-ncnn/python/csrc/features.cc

@@ -0,0 +1,56 @@
+/**
+ * Copyright (c)  2022  Xiaomi Corporation (authors: Fangjun Kuang)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "sherpa-ncnn/csrc/features.h"
+
+#include "sherpa-ncnn/python/csrc/mat-util.h"
+#include "sherpa-ncnn/python/csrc/sherpa-ncnn.h"
+
+namespace sherpa_ncnn {
+
+void PybindFeatures(py::module *m) {
+  using PyClass = FeatureExtractor;
+
+  py::class_<PyClass>(*m, "FeatureExtractor")
+      .def(py::init([](int32_t feature_dim,
+                       float sample_rate) -> std::unique_ptr<PyClass> {
+             knf::FbankOptions fbank_opts;
+             fbank_opts.frame_opts.dither = 0;
+             fbank_opts.frame_opts.snip_edges = false;
+             fbank_opts.frame_opts.samp_freq = sample_rate;
+             fbank_opts.mel_opts.num_bins = feature_dim;
+
+             return std::make_unique<PyClass>(fbank_opts);
+           }),
+           py::arg("feature_dim"), py::arg("sample_rate"))
+      .def("accept_waveform",
+           [](PyClass &self, float sample_rate, py::array_t<float> waveform) {
+             self.AcceptWaveform(sample_rate, waveform.data(), waveform.size());
+           })
+      .def("input_finished", &PyClass::InputFinished)
+      .def_property_readonly("num_frames_ready", &PyClass::NumFramesReady)
+      .def("is_last_frame", &PyClass::IsLastFrame, py::arg("frame"))
+      .def("get_frames",
+           [](PyClass &self, int32_t frame_index, int32_t n) -> py::array {
+             ncnn::Mat frames = self.GetFrames(frame_index, n);
+             return MatToArray(frames);
+           })
+      .def("reset", &PyClass::Reset);
+}
+
+}  // namespace sherpa_ncnn

+ 30 - 0
sherpa-ncnn/python/csrc/features.h

@@ -0,0 +1,30 @@
+/**
+ * Copyright (c)  2022  Xiaomi Corporation (authors: Fangjun Kuang)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef SHERPA_NCNN_PYTHON_CSRC_FEATURES_H_
+#define SHERPA_NCNN_PYTHON_CSRC_FEATURES_H_
+
+#include "sherpa-ncnn/python/csrc/sherpa-ncnn.h"
+
+namespace sherpa_ncnn {
+
+void PybindFeatures(py::module *m);
+
+}  // namespace sherpa_ncnn
+
+#endif  // SHERPA_NCNN_PYTHON_CSRC_FEATURES_H_

+ 97 - 0
sherpa-ncnn/python/csrc/mat-util.cc

@@ -0,0 +1,97 @@
+/**
+ * Copyright (c)  2022  Xiaomi Corporation (authors: Fangjun Kuang)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "sherpa-ncnn/python/csrc/mat-util.h"
+
+namespace sherpa_ncnn {
+
+struct KeepMatAlive {
+  explicit KeepMatAlive(ncnn::Mat m) : m(m) {}
+
+  ncnn::Mat m;
+};
+
+py::array_t<float> MatToArray(ncnn::Mat m) {
+  std::vector<py::ssize_t> shape;
+  std::vector<py::ssize_t> strides;
+  if (m.dims == 1) {
+    shape.push_back(m.w);
+    strides.push_back(m.elemsize);
+  } else if (m.dims == 2) {
+    shape.push_back(m.h);
+    shape.push_back(m.w);
+    strides.push_back(m.w * m.elemsize);
+    strides.push_back(m.elemsize);
+  } else if (m.dims == 3) {
+    shape.push_back(m.c);
+    shape.push_back(m.h);
+    shape.push_back(m.w);
+    strides.push_back(m.cstep * m.elemsize);
+    strides.push_back(m.w * m.elemsize);
+    strides.push_back(m.elemsize);
+  } else if (m.dims == 4) {
+    shape.push_back(m.c);
+    shape.push_back(m.d);
+    shape.push_back(m.h);
+    shape.push_back(m.w);
+    strides.push_back(m.cstep * m.elemsize);
+    strides.push_back(m.w * m.h * m.elemsize);
+    strides.push_back(m.w * m.elemsize);
+    strides.push_back(m.elemsize);
+  }
+
+  auto keep_mat_alive = new KeepMatAlive(m);
+  py::capsule handle(keep_mat_alive, [](void *p) {
+    delete reinterpret_cast<KeepMatAlive *>(p);
+  });
+
+  return py::array_t<float>(shape, strides, (float *)m.data, handle);
+}
+
+ncnn::Mat ArrayToMat(py::array array) {
+  py::buffer_info info = array.request();
+  size_t elemsize = info.itemsize;
+
+  ncnn::Mat ans;
+
+  if (info.ndim == 1) {
+    ans = ncnn::Mat((int)info.shape[0], info.ptr, elemsize);
+  } else if (info.ndim == 2) {
+    ans = ncnn::Mat((int)info.shape[1], (int)info.shape[0], info.ptr, elemsize);
+  } else if (info.ndim == 3) {
+    ans = ncnn::Mat((int)info.shape[2], (int)info.shape[1], (int)info.shape[0],
+                    info.ptr, elemsize);
+
+    // in ncnn, buffer to construct ncnn::Mat need align to ncnn::alignSize
+    // with (w * h * elemsize, 16) / elemsize, but the buffer from numpy not
+    // so we set the cstep as numpy's cstep
+    ans.cstep = (int)info.shape[2] * (int)info.shape[1];
+  } else if (info.ndim == 4) {
+    ans = ncnn::Mat((int)info.shape[3], (int)info.shape[2], (int)info.shape[1],
+                    (int)info.shape[0], info.ptr, elemsize);
+
+    // in ncnn, buffer to construct ncnn::Mat need align to ncnn::alignSize
+    // with (w * h * d elemsize, 16) / elemsize, but the buffer from numpy not
+    // so we set the cstep as numpy's cstep
+    ans.cstep = (int)info.shape[3] * (int)info.shape[2] * (int)info.shape[1];
+  }
+
+  return ans;
+}
+
+}  // namespace sherpa_ncnn

+ 37 - 0
sherpa-ncnn/python/csrc/mat-util.h

@@ -0,0 +1,37 @@
+/**
+ * Copyright (c)  2022  Xiaomi Corporation (authors: Fangjun Kuang)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef SHERPA_NCNN_PYTHON_CSRC_MAT_UTIL_H_
+#define SHERPA_NCNN_PYTHON_CSRC_MAT_UTIL_H_
+
+#include "mat.h"
+#include "sherpa-ncnn/python/csrc/sherpa-ncnn.h"
+
+namespace sherpa_ncnn {
+
+// Convert a ncnn::Mat to a numpy array. Data is shared.
+//
+// @param m It should be a float unpacked matrix
+py::array_t<float> MatToArray(ncnn::Mat m);
+
+// convert an array to a ncnn::Mat
+ncnn::Mat ArrayToMat(py::array array);
+
+}  // namespace sherpa_ncnn
+
+#endif  // SHERPA_NCNN_PYTHON_CSRC_MODEL_UTIL_H_

+ 144 - 0
sherpa-ncnn/python/csrc/model.cc

@@ -0,0 +1,144 @@
+/**
+ * Copyright (c)  2022  Xiaomi Corporation (authors: Fangjun Kuang)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "sherpa-ncnn/python/csrc/model.h"
+
+#include <memory>
+#include <string>
+
+#include "sherpa-ncnn/csrc/model.h"
+#include "sherpa-ncnn/python/csrc/mat-util.h"
+
+namespace sherpa_ncnn {
+
+const char *kModelConfigInitDoc = R"doc(
+Constructor for ModelConfig.
+
+Please refer to
+`<https://k2-fsa.github.io/sherpa/ncnn/pretrained_models/index.html>`_
+for download links about pre-trained models.
+
+Args:
+  encoder_param:
+    Path to encoder.ncnn.param.
+  encoder_bin:
+    Path to encoder.ncnn.bin.
+  decoder_param:
+    Path to decoder.ncnn.param.
+  decoder_bin:
+    Path to decoder.ncnn.bin.
+  joiner_param:
+    Path to joiner.ncnn.param.
+  joiner_bin:
+    Path to joiner.ncnn.bin.
+  num_threads:
+    Number of threads to use for neural network computation.
+)doc";
+
+static void PybindModelConfig(py::module *m) {
+  using PyClass = ModelConfig;
+  py::class_<PyClass>(*m, "ModelConfig")
+      .def(py::init([](const std::string &encoder_param,
+                       const std::string &encoder_bin,
+                       const std::string &decoder_param,
+                       const std::string &decoder_bin,
+                       const std::string &joiner_param,
+                       const std::string &joiner_bin,
+                       int32_t num_threads) -> std::unique_ptr<PyClass> {
+             auto ans = std::make_unique<PyClass>();
+             ans->encoder_param = encoder_param;
+             ans->encoder_bin = encoder_bin;
+             ans->decoder_param = decoder_param;
+             ans->decoder_bin = decoder_bin;
+             ans->joiner_param = joiner_param;
+             ans->joiner_bin = joiner_bin;
+
+             ans->use_vulkan_compute = false;
+
+             ans->encoder_opt.num_threads = num_threads;
+             ans->decoder_opt.num_threads = num_threads;
+             ans->joiner_opt.num_threads = num_threads;
+
+             return ans;
+           }),
+           py::arg("encoder_param"), py::arg("encoder_bin"),
+           py::arg("decoder_param"), py::arg("decoder_bin"),
+           py::arg("joiner_param"), py::arg("joiner_bin"),
+           py::arg("num_threads"), kModelConfigInitDoc);
+}
+
+void PybindModel(py::module *m) {
+  PybindModelConfig(m);
+
+  using PyClass = Model;
+  py::class_<PyClass>(*m, "Model")
+      .def_static("create", &PyClass::Create, py::arg("config"))
+      .def(
+          "run_encoder",
+          [](PyClass &self, py::array _features,
+             const std::vector<py::array> &_states)
+              -> std::pair<py::array, std::vector<py::array>> {
+            ncnn::Mat features = ArrayToMat(_features);
+
+            std::vector<ncnn::Mat> states;
+            states.reserve(_states.size());
+            for (const auto &s : _states) {
+              states.push_back(ArrayToMat(s));
+            }
+
+            ncnn::Mat encoder_out;
+            std::vector<ncnn::Mat> _next_states;
+
+            std::tie(encoder_out, _next_states) =
+                self.RunEncoder(features, states);
+
+            std::vector<py::array> next_states;
+            next_states.reserve(_next_states.size());
+            for (const auto &m : _next_states) {
+              next_states.push_back(MatToArray(m));
+            }
+
+            return std::make_pair(MatToArray(encoder_out), next_states);
+          },
+          py::arg("features"), py::arg("states"))
+      .def(
+          "run_decoder",
+          [](PyClass &self, py::array _decoder_input) -> py::array {
+            ncnn::Mat decoder_input = ArrayToMat(_decoder_input);
+            ncnn::Mat decoder_out = self.RunDecoder(decoder_input);
+            return MatToArray(decoder_out);
+          },
+          py::arg("decoder_input"))
+      .def(
+          "run_joiner",
+          [](PyClass &self, py::array _encoder_out,
+             py::array _decoder_out) -> py::array {
+            ncnn::Mat encoder_out = ArrayToMat(_encoder_out);
+            ncnn::Mat decoder_out = ArrayToMat(_decoder_out);
+            ncnn::Mat joiner_out = self.RunJoiner(encoder_out, decoder_out);
+
+            return MatToArray(joiner_out);
+          },
+          py::arg("encoder_out"), py::arg("decoder_out"))
+      .def_property_readonly("context_size", &PyClass::ContextSize)
+      .def_property_readonly("blank_id", &PyClass::BlankId)
+      .def_property_readonly("segment", &PyClass::Segment)
+      .def_property_readonly("offset", &PyClass::Offset);
+}
+
+}  // namespace sherpa_ncnn

+ 30 - 0
sherpa-ncnn/python/csrc/model.h

@@ -0,0 +1,30 @@
+/**
+ * Copyright (c)  2022  Xiaomi Corporation (authors: Fangjun Kuang)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef SHERPA_NCNN_PYTHON_CSRC_MODEL_H_
+#define SHERPA_NCNN_PYTHON_CSRC_MODEL_H_
+
+#include "sherpa-ncnn/python/csrc/sherpa-ncnn.h"
+
+namespace sherpa_ncnn {
+
+void PybindModel(py::module *m);
+
+}  // namespace sherpa_ncnn
+
+#endif  // SHERPA_NCNN_PYTHON_CSRC_MODEL_H_

+ 37 - 0
sherpa-ncnn/python/csrc/sherpa-ncnn.cc

@@ -0,0 +1,37 @@
+/**
+ * Copyright (c)  2022  Xiaomi Corporation (authors: Fangjun Kuang)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "sherpa-ncnn/python/csrc/sherpa-ncnn.h"
+
+#include "sherpa-ncnn/python/csrc/decode.h"
+#include "sherpa-ncnn/python/csrc/features.h"
+#include "sherpa-ncnn/python/csrc/model.h"
+
+namespace sherpa_ncnn {
+
+PYBIND11_MODULE(_sherpa_ncnn, m) {
+  m.doc() = "pybind11 binding of sherpa-ncnn";
+
+  PybindModel(&m);
+
+  PybindFeatures(&m);
+
+  PybindDecode(&m);
+}
+
+}  // namespace sherpa_ncnn

+ 28 - 0
sherpa-ncnn/python/csrc/sherpa-ncnn.h

@@ -0,0 +1,28 @@
+/**
+ * Copyright (c)  2022  Xiaomi Corporation (authors: Fangjun Kuang)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef SHERPA_NCNN_PYTHON_CSRC_SHERPA_NCNN_H_
+#define SHERPA_NCNN_PYTHON_CSRC_SHERPA_NCNN_H_
+
+#include "pybind11/numpy.h"
+#include "pybind11/pybind11.h"
+#include "pybind11/stl.h"
+
+namespace py = pybind11;
+
+#endif  // SHERPA_NCNN_PYTHON_CSRC_SHERPA_NCNN_H_

+ 3 - 0
sherpa-ncnn/python/sherpa_ncnn/__init__.py

@@ -0,0 +1,3 @@
+from _sherpa_ncnn import FeatureExtractor, Model, ModelConfig, greedy_search
+
+from .recognizer import Recognizer

+ 201 - 0
sherpa-ncnn/python/sherpa_ncnn/recognizer.py

@@ -0,0 +1,201 @@
+from pathlib import Path
+
+import numpy as np
+from _sherpa_ncnn import FeatureExtractor, Model, ModelConfig, greedy_search
+
+
+def _assert_file_exists(f: str):
+    assert Path(f).is_file(), f"{f} does not exist"
+
+
+def _read_tokens(tokens):
+    sym_table = {}
+    with open(tokens) as f:
+        for line in f:
+            sym, i = line.split()
+            sym = sym.replace("▁", " ")
+            sym_table[int(i)] = sym
+
+    return sym_table
+
+
+class Recognizer(object):
+    """A class for streaming speech recognition.
+
+    Please refer to
+    `<https://k2-fsa.github.io/sherpa/ncnn/pretrained_models/index.html>`_
+    to download pre-trained models for different languages, e.g., Chinese,
+    English, etc.
+
+    **Usage example**
+
+    .. code-block:: python3
+
+        import wave
+
+        import numpy as np
+        import sherpa_ncnn
+
+
+        def main():
+            recognizer = sherpa_ncnn.Recognizer(
+                tokens="./sherpa-ncnn-conv-emformer-transducer-2022-12-06/tokens.txt",
+                encoder_param="./sherpa-ncnn-conv-emformer-transducer-2022-12-06/encoder_jit_trace-pnnx.ncnn.param",
+                encoder_bin="./sherpa-ncnn-conv-emformer-transducer-2022-12-06/encoder_jit_trace-pnnx.ncnn.bin",
+                decoder_param="./sherpa-ncnn-conv-emformer-transducer-2022-12-06/decoder_jit_trace-pnnx.ncnn.param",
+                decoder_bin="./sherpa-ncnn-conv-emformer-transducer-2022-12-06/decoder_jit_trace-pnnx.ncnn.bin",
+                joiner_param="./sherpa-ncnn-conv-emformer-transducer-2022-12-06/joiner_jit_trace-pnnx.ncnn.param",
+                joiner_bin="./sherpa-ncnn-conv-emformer-transducer-2022-12-06/joiner_jit_trace-pnnx.ncnn.bin",
+                num_threads=4,
+            )
+
+            filename = "./sherpa-ncnn-conv-emformer-transducer-2022-12-06/test_wavs/1.wav"
+            with wave.open(filename) as f:
+                assert f.getframerate() == recognizer.sample_rate, (
+                    f.getframerate(),
+                    recognizer.sample_rate,
+                )
+                assert f.getnchannels() == 1, f.getnchannels()
+                assert f.getsampwidth() == 2, f.getsampwidth()  # it is in bytes
+                num_samples = f.getnframes()
+                samples = f.readframes(num_samples)
+                samples_int16 = np.frombuffer(samples, dtype=np.int16)
+                samples_float32 = samples_int16.astype(np.float32)
+
+                samples_float32 = samples_float32 / 32768
+
+            recognizer.accept_waveform(recognizer.sample_rate, samples_float32)
+
+            tail_paddings = np.zeros(int(recognizer.sample_rate * 0.5), dtype=np.float32)
+            recognizer.accept_waveform(recognizer.sample_rate, tail_paddings)
+
+            recognizer.input_finished()
+
+            print(recognizer.text)
+
+
+        if __name__ == "__main__":
+            main()
+    """
+
+    def __init__(
+        self,
+        tokens: str,
+        encoder_param: str,
+        encoder_bin: str,
+        decoder_param: str,
+        decoder_bin: str,
+        joiner_param: str,
+        joiner_bin: str,
+        num_threads: int = 4,
+    ):
+        """
+        Please refer to
+        `<https://k2-fsa.github.io/sherpa/ncnn/pretrained_models/index.html>`_
+        to download pre-trained models for different languages, e.g., Chinese,
+        English, etc.
+
+        Args:
+          tokens:
+            Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
+            columns::
+
+                symbol integer_id
+          encoder_param:
+            Path to ``encoder.ncnn.param``.
+          encoder_bin:
+            Path to ``encoder.ncnn.bin``.
+          decoder_param:
+            Path to ``decoder.ncnn.param``.
+          decoder_bin:
+            Path to ``decoder.ncnn.bin``.
+          joiner_param:
+            Path to ``joiner.ncnn.param``.
+          joiner_bin:
+            Path to ``joiner.ncnn.bin``.
+          num_threads:
+            Number of threads for neural network computation.
+        """
+        _assert_file_exists(tokens)
+        _assert_file_exists(encoder_param)
+        _assert_file_exists(encoder_bin)
+        _assert_file_exists(decoder_param)
+        _assert_file_exists(decoder_bin)
+        _assert_file_exists(joiner_param)
+        _assert_file_exists(joiner_bin)
+
+        assert num_threads > 0, num_threads
+
+        self.sym_table = _read_tokens(tokens)
+
+        model_config = ModelConfig(
+            encoder_param=encoder_param,
+            encoder_bin=encoder_bin,
+            decoder_param=decoder_param,
+            decoder_bin=decoder_bin,
+            joiner_param=joiner_param,
+            joiner_bin=joiner_bin,
+            num_threads=num_threads,
+        )
+
+        self.model = Model.create(model_config)
+        self.sample_rate = 16000
+
+        self.feature_extractor = FeatureExtractor(
+            feature_dim=80,
+            sample_rate=self.sample_rate,
+        )
+
+        self.num_processed = 0  # number of processed feature frames so far
+        self.states = []  # model state
+
+        self.hyp = [0] * self.model.context_size  # initial hypothesis
+
+        decoder_input = np.array(self.hyp, dtype=np.int32)
+        self.decoder_out = self.model.run_decoder(decoder_input)
+
+    def accept_waveform(self, sample_rate: float, waveform: np.array):
+        """Decode audio samples.
+
+        Args:
+          sample_rate:
+            Sample rate of the input audio samples. It should be 16000.
+          waveform:
+            A 1-D float32 array containing audio samples in the
+            range ``[-1, 1]``.
+        """
+        assert sample_rate == self.sample_rate, (sample_rate, self.sample_rate)
+        self.feature_extractor.accept_waveform(sample_rate, waveform)
+
+        self._decode()
+
+    def input_finished(self):
+        """Signal that no more audio samples are available."""
+        self.feature_extractor.input_finished()
+        self._decode()
+
+    @property
+    def text(self):
+        context_size = self.model.context_size
+        text = [self.sym_table[token] for token in self.hyp[context_size:]]
+        return "".join(text)
+
+    def _decode(self):
+        segment = self.model.segment
+        offset = self.model.offset
+
+        while self.feature_extractor.num_frames_ready - self.num_processed >= segment:
+            features = self.feature_extractor.get_frames(self.num_processed, segment)
+            self.num_processed += offset
+
+            encoder_out, self.states = self.model.run_encoder(
+                features=features,
+                states=self.states,
+            )
+
+            self.decoder_out, self.hyp = greedy_search(
+                model=self.model,
+                encoder_out=encoder_out,
+                decoder_out=self.decoder_out,
+                hyp=self.hyp,
+            )