浏览代码

Speech recognition from microphone (#9)

* Refactor

* First working version

* It is working now

* update readme
Fangjun Kuang 2 年之前
父节点
当前提交
b89a6721ce

+ 1 - 0
CMakeLists.txt

@@ -40,5 +40,6 @@ list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake)
 
 include(kaldi-native-fbank)
 include(ncnn)
+include(portaudio)
 
 add_subdirectory(sherpa-ncnn)

+ 13 - 0
README.md

@@ -43,4 +43,17 @@ git clone https://huggingface.co/csukuangfj/sherpa-ncnn-2022-09-05
   ./sherpa-ncnn-2022-09-05/test_wavs/1089-134686-0001.wav
 ```
 
+To do speech recognition in real-time with a microphone, run:
+
+```bash
+./build/bin/sherpa-ncnn-microphone \
+  ./sherpa-ncnn-2022-09-05/tokens.txt \
+  ./sherpa-ncnn-2022-09-05/bar/encoder_jit_trace-iter-468000-avg-16-pnnx.ncnn.param \
+  ./sherpa-ncnn-2022-09-05/bar/encoder_jit_trace-iter-468000-avg-16-pnnx.ncnn.bin \
+  ./sherpa-ncnn-2022-09-05/bar/decoder_jit_trace-iter-468000-avg-16-pnnx.ncnn.param \
+  ./sherpa-ncnn-2022-09-05/bar/decoder_jit_trace-iter-468000-avg-16-pnnx.ncnn.bin \
+  ./sherpa-ncnn-2022-09-05/bar/joiner_jit_trace-iter-468000-avg-16-pnnx.ncnn.param \
+  ./sherpa-ncnn-2022-09-05/bar/joiner_jit_trace-iter-468000-avg-16-pnnx.ncnn.bin
+```
+
 [ncnn]: https://github.com/tencent/ncnn

+ 61 - 0
cmake/portaudio.cmake

@@ -0,0 +1,61 @@
+# Copyright     2020 Fangjun Kuang (csukuangfj@gmail.com)
+# See ../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Downloads the pinned PortAudio source archive with FetchContent and adds it
+# to the build through add_subdirectory().
+#
+# PA_BUILD_SHARED / PA_BUILD_STATIC are forced in the cache to follow
+# BUILD_SHARED_LIBS before the PortAudio subdirectory is added.
+function(download_portaudio)
+  if(CMAKE_VERSION VERSION_LESS 3.11)
+    # FetchContent is available since 3.11,
+    # we've copied it to ${CMAKE_SOURCE_DIR}/cmake/Modules
+    # so that it can be used in lower CMake versions.
+    message(STATUS "Use FetchContent provided by sherpa-ncnn")
+    list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules)
+  endif()
+
+  include(FetchContent)
+
+  # NOTE(review): the archive is fetched over plain http; integrity relies on
+  # the SHA256 passed as URL_HASH below.
+  set(portaudio_URL  "http://files.portaudio.com/archives/pa_stable_v190700_20210406.tgz")
+  set(portaudio_HASH "SHA256=47efbf42c77c19a05d22e627d42873e991ec0c1357219c0d74ce6a2948cb2def")
+  # Build exactly one flavour of PortAudio, matching BUILD_SHARED_LIBS.
+  if(BUILD_SHARED_LIBS)
+    set(PA_BUILD_SHARED ON CACHE BOOL "" FORCE)
+    set(PA_BUILD_STATIC OFF CACHE BOOL "" FORCE)
+  else()
+    set(PA_BUILD_SHARED OFF CACHE BOOL "" FORCE)
+    set(PA_BUILD_STATIC ON CACHE BOOL "" FORCE)
+  endif()
+
+  FetchContent_Declare(portaudio
+    URL               ${portaudio_URL}
+    URL_HASH          ${portaudio_HASH}
+  )
+
+  # Populate only once; repeated configure runs reuse the existing download.
+  FetchContent_GetProperties(portaudio)
+  if(NOT portaudio_POPULATED)
+    message(STATUS "Downloading portaudio")
+    FetchContent_Populate(portaudio)
+  endif()
+  message(STATUS "portaudio is downloaded to ${portaudio_SOURCE_DIR}")
+  message(STATUS "portaudio's binary dir is ${portaudio_BINARY_DIR}")
+
+  if(APPLE)
+    # Enable MACOSX_RPATH for the targets added below; presumably this
+    # silences CMake's CMP0042 policy warning on macOS -- TODO confirm.
+    set(CMAKE_MACOSX_RPATH ON)
+  endif()
+
+  # EXCLUDE_FROM_ALL keeps PortAudio's targets out of the default "all" target.
+  add_subdirectory(${portaudio_SOURCE_DIR} ${portaudio_BINARY_DIR} EXCLUDE_FROM_ALL)
+endfunction()
+
+download_portaudio()
+
+# Note
+# See http://portaudio.com/docs/v19-doxydocs/tutorial_start.html
+# for how to use portaudio

+ 16 - 8
sherpa-ncnn/csrc/CMakeLists.txt

@@ -1,18 +1,26 @@
 include_directories(${CMAKE_SOURCE_DIR})
 
-add_executable(online-fbank-test online-fbank-test.cc)
-target_link_libraries(online-fbank-test kaldi-native-fbank-core)
-
-add_executable(sherpa-ncnn
+set(sherpa_ncnn_core_srcs
   decode.cc
   features.cc
   lstm-model.cc
-  sherpa-ncnn.cc
   symbol-table.cc
   wave-reader.cc
 )
+add_library(sherpa-ncnn-core ${sherpa_ncnn_core_srcs})
+target_link_libraries(sherpa-ncnn-core kaldi-native-fbank-core ncnn)
+
+add_executable(sherpa-ncnn sherpa-ncnn.cc)
+target_link_libraries(sherpa-ncnn sherpa-ncnn-core)
 
-target_link_libraries(sherpa-ncnn
-  ncnn
-  kaldi-native-fbank-core
+add_executable(sherpa-ncnn-microphone
+  sherpa-ncnn-microphone.cc
+  microphone.cc
 )
+if(BUILD_SHARED_LIBS)
+  set(PA_LIB portaudio)
+else()
+  set(PA_LIB portaudio_static)
+endif()
+
+target_link_libraries(sherpa-ncnn-microphone ${PA_LIB} sherpa-ncnn-core)

+ 8 - 1
sherpa-ncnn/csrc/features.cc

@@ -36,16 +36,22 @@ FeatureExtractor::FeatureExtractor() {
 
+// Forwards `n` samples starting at `waveform` to fbank_->AcceptWaveform().
+// mutex_ is held for the whole call.
 void FeatureExtractor::AcceptWaveform(float sampling_rate,
                                       const float *waveform, int32_t n) {
+  std::lock_guard<std::mutex> lock(mutex_);
   fbank_->AcceptWaveform(sampling_rate, waveform, n);
 }
 
-void FeatureExtractor::InputFinished() { fbank_->InputFinished(); }
+// Forwards to fbank_->InputFinished() while holding mutex_.
+void FeatureExtractor::InputFinished() {
+  std::lock_guard<std::mutex> lock(mutex_);
+  fbank_->InputFinished();
+}
 
+// Returns fbank_->NumFramesReady(), read while holding mutex_.
 int32_t FeatureExtractor::NumFramesReady() const {
+  std::lock_guard<std::mutex> lock(mutex_);
   return fbank_->NumFramesReady();
 }
 
+// Returns fbank_->IsLastFrame(frame), read while holding mutex_.
 bool FeatureExtractor::IsLastFrame(int32_t frame) const {
+  std::lock_guard<std::mutex> lock(mutex_);
   return fbank_->IsLastFrame(frame);
 }
 
@@ -54,6 +60,7 @@ ncnn::Mat FeatureExtractor::GetFrames(int32_t frame_index, int32_t n) const {
     fprintf(stderr, "%d + %d > %d\n", frame_index, n, NumFramesReady());
     exit(-1);
   }
+  std::lock_guard<std::mutex> lock(mutex_);
 
   int32_t feature_dim = fbank_->Dim();
   ncnn::Mat features;

+ 2 - 0
sherpa-ncnn/csrc/features.h

@@ -20,6 +20,7 @@
 #define SHERPA_NCNN_CSRC_FEATURES_H_
 
 #include <memory>
+#include <mutex>
 
 #include "kaldi-native-fbank/csrc/online-feature.h"
 
@@ -64,6 +65,7 @@ class FeatureExtractor {
 
  private:
   std::unique_ptr<knf::OnlineFbank> fbank_;
+  mutable std::mutex mutex_;
   float expected_sampling_rate_ = 16000;
 };
 

+ 19 - 23
sherpa-ncnn/csrc/online-fbank-test.cc → sherpa-ncnn/csrc/microphone.cc

@@ -16,33 +16,29 @@
  * limitations under the License.
  */
 
-#include <iostream>
+#include "sherpa-ncnn/csrc/microphone.h"
 
-#include "kaldi-native-fbank/csrc/online-feature.h"
+#include <stdio.h>
+#include <stdlib.h>
 
-int main() {
-  knf::FbankOptions opts;
-  opts.frame_opts.dither = 0;
-  opts.mel_opts.num_bins = 10;
+#include "portaudio.h"
 
-  knf::OnlineFbank fbank(opts);
-  for (int32_t i = 0; i < 1600; ++i) {
-    float s = (i * i - i / 2) / 32767.;
-    fbank.AcceptWaveform(16000, &s, 1);
-  }
-
-  std::ostringstream os;
+namespace sherpa_ncnn {
 
-  int32_t n = fbank.NumFramesReady();
-  for (int32_t i = 0; i != n; ++i) {
-    const float *frame = fbank.GetFrame(i);
-    for (int32_t k = 0; k != opts.mel_opts.num_bins; ++k) {
-      os << frame[k] << ", ";
-    }
-    os << "\n";
+// Initializes the PortAudio library via Pa_Initialize(). On failure the
+// PortAudio error text is printed to stderr and the process exits with -1.
+Microphone::Microphone() {
+  PaError err = Pa_Initialize();
+  if (err != paNoError) {
+    fprintf(stderr, "portaudio error: %s\n", Pa_GetErrorText(err));
+    exit(-1);
   }
+}
 
-  std::cout << os.str() << "\n";
-
-  return 0;
+// Shuts the PortAudio library down via Pa_Terminate().
+//
+// A failure is reported on stderr but is not fatal: terminating the process
+// from a destructor would skip the cleanup of every other live object, and
+// would re-enter exit() if this object were ever destroyed during exit
+// processing.
+Microphone::~Microphone() {
+  PaError err = Pa_Terminate();
+  if (err != paNoError) {
+    fprintf(stderr, "portaudio error: %s\n", Pa_GetErrorText(err));
+  }
 }
+
+}  // namespace sherpa_ncnn

+ 34 - 0
sherpa-ncnn/csrc/microphone.h

@@ -0,0 +1,34 @@
+/**
+ * Copyright (c)  2022  Xiaomi Corporation (authors: Fangjun Kuang)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef SHERPA_NCNN_CSRC_MICROPHONE_H_
+#define SHERPA_NCNN_CSRC_MICROPHONE_H_
+
+namespace sherpa_ncnn {
+
+// Scope guard for the PortAudio library: the constructor calls
+// Pa_Initialize() and the destructor calls Pa_Terminate().
+//
+// Keep one instance alive for as long as any PortAudio stream is in use.
+class Microphone {
+ public:
+  Microphone();
+  ~Microphone();
+
+  // Every instance must pair exactly one Pa_Initialize() with one
+  // Pa_Terminate(); a copy would terminate without having initialized,
+  // so copying is disabled.
+  Microphone(const Microphone &) = delete;
+  Microphone &operator=(const Microphone &) = delete;
+};
+
+}  // namespace sherpa_ncnn
+
+#endif  // SHERPA_NCNN_CSRC_MICROPHONE_H_

+ 192 - 0
sherpa-ncnn/csrc/sherpa-ncnn-microphone.cc

@@ -0,0 +1,192 @@
+/**
+ * Copyright (c)  2022  Xiaomi Corporation (authors: Fangjun Kuang)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <signal.h>
+#include <stdio.h>
+#include <stdlib.h>
+
+#include <atomic>
+#include <string>
+#include <vector>
+
+#include "portaudio.h"
+#include "sherpa-ncnn/csrc/decode.h"
+#include "sherpa-ncnn/csrc/features.h"
+#include "sherpa-ncnn/csrc/lstm-model.h"
+#include "sherpa-ncnn/csrc/microphone.h"
+#include "sherpa-ncnn/csrc/symbol-table.h"
+
+// Shutdown flag. It is set from the SIGINT handler and polled by both the
+// PortAudio callback and the main loop, so it is an atomic rather than a
+// plain bool to keep those concurrent accesses well defined.
+std::atomic<bool> stop{false};
+
+// PortAudio stream callback: feeds the captured samples to the
+// FeatureExtractor passed through `user_data`.
+//
+// @param input_buffer       Captured samples; the stream is opened as mono
+//                           paFloat32, so this is read as an array of float.
+// @param frames_per_buffer  Number of frames in `input_buffer`.
+// @param user_data          Pointer to a sherpa_ncnn::FeatureExtractor.
+// @return paComplete once `stop` has been set, paContinue otherwise.
+//
+// The output buffer, timing info and status flags are not used.
+static int recordCallback(const void *input_buffer, void * /*outputBuffer*/,
+                          unsigned long frames_per_buffer,  // NOLINT
+                          const PaStreamCallbackTimeInfo * /*timeInfo*/,
+                          PaStreamCallbackFlags /*statusFlags*/,
+                          void *user_data) {
+  // Defensive: never hand a null or empty buffer to the feature extractor.
+  if (input_buffer != nullptr && frames_per_buffer > 0) {
+    auto *feature_extractor =
+        static_cast<sherpa_ncnn::FeatureExtractor *>(user_data);
+
+    feature_extractor->AcceptWaveform(
+        16000, static_cast<const float *>(input_buffer),
+        static_cast<int32_t>(frames_per_buffer));
+  }
+
+  return stop ? paComplete : paContinue;
+}
+// SIGINT handler: requests shutdown by setting `stop`.
+//
+// Only the flag is touched here. stdio functions such as fprintf are not
+// async-signal-safe, so no message is printed from inside the handler.
+static void handler(int /*sig*/) { stop = true; }
+
+// Real-time speech recognition from the default input device.
+//
+// Expects 7 positional arguments (tokens.txt followed by the .param/.bin
+// pairs of the encoder, decoder and joiner) and an optional 8th argument
+// giving the number of threads (default 4). Audio is captured through
+// PortAudio and decoded with greedy search; the transcript so far is printed
+// to stderr every time it grows. The loop runs until SIGINT is received.
+//
+// Returns 0 on normal shutdown (and after printing the usage text), and
+// EXIT_FAILURE when a PortAudio call fails.
+int main(int32_t argc, char *argv[]) {
+  if (argc != 8 && argc != 9) {
+    const char *usage = R"usage(
+Usage:
+  ./bin/sherpa-ncnn-microphone \
+    /path/to/tokens.txt \
+    /path/to/encoder.ncnn.param \
+    /path/to/encoder.ncnn.bin \
+    /path/to/decoder.ncnn.param \
+    /path/to/decoder.ncnn.bin \
+    /path/to/joiner.ncnn.param \
+    /path/to/joiner.ncnn.bin \
+    [num_threads]
+
+You can download pre-trained models from the following repository:
+https://huggingface.co/csukuangfj/sherpa-ncnn-2022-09-05
+)usage";
+    fprintf(stderr, "%s\n", usage);
+    fprintf(stderr, "argc, %d\n", argc);
+
+    return 0;
+  }
+  signal(SIGINT, handler);
+
+  const std::string tokens = argv[1];
+  const std::string encoder_param = argv[2];
+  const std::string encoder_bin = argv[3];
+  const std::string decoder_param = argv[4];
+  const std::string decoder_bin = argv[5];
+  const std::string joiner_param = argv[6];
+  const std::string joiner_bin = argv[7];
+
+  int32_t num_threads = 4;
+  if (argc == 9) {
+    num_threads = atoi(argv[8]);
+  }
+
+  sherpa_ncnn::SymbolTable sym(tokens);
+  fprintf(stderr, "Number of threads: %d\n", num_threads);
+
+  sherpa_ncnn::LstmModel model(encoder_param, encoder_bin, decoder_param,
+                               decoder_bin, joiner_param, joiner_bin,
+                               num_threads);
+
+  // Scope guard for PortAudio. The error paths below `return` instead of
+  // calling exit() so that this object's destructor always runs.
+  sherpa_ncnn::Microphone mic;
+
+  // Declared after `mic` and before the stream is opened: the stream callback
+  // writes into it, so it must outlive the stream.
+  sherpa_ncnn::FeatureExtractor feature_extractor;
+
+  PaDeviceIndex num_devices = Pa_GetDeviceCount();
+  fprintf(stderr, "num devices: %d\n", num_devices);
+
+  PaStreamParameters param;
+
+  param.device = Pa_GetDefaultInputDevice();
+  if (param.device == paNoDevice) {
+    fprintf(stderr, "No default input device found\n");
+    return EXIT_FAILURE;
+  }
+  fprintf(stderr, "Use default device: %d\n", param.device);
+
+  const PaDeviceInfo *info = Pa_GetDeviceInfo(param.device);
+  if (info == nullptr) {
+    fprintf(stderr, "Failed to get info for device: %d\n", param.device);
+    return EXIT_FAILURE;
+  }
+  fprintf(stderr, "  Name: %s\n", info->name);
+  fprintf(stderr, "  Max input channels: %d\n", info->maxInputChannels);
+
+  param.channelCount = 1;
+  param.sampleFormat = paFloat32;
+
+  param.suggestedLatency = info->defaultLowInputLatency;
+  param.hostApiSpecificStreamInfo = nullptr;
+  const float sample_rate = 16000;
+
+  PaStream *stream = nullptr;
+  PaError err = Pa_OpenStream(&stream, &param, nullptr, /* &outputParameters, */
+                              sample_rate,
+                              0,         // frames per buffer
+                              paClipOff, /* we won't output out of range samples
+                                            so don't bother clipping them */
+                              recordCallback,
+                              &feature_extractor  // userdata
+  );
+  if (err != paNoError) {
+    fprintf(stderr, "portaudio error: %s\n", Pa_GetErrorText(err));
+    return EXIT_FAILURE;
+  }
+
+  err = Pa_StartStream(stream);
+  if (err != paNoError) {
+    fprintf(stderr, "portaudio error: %s\n", Pa_GetErrorText(err));
+    Pa_CloseStream(stream);  // best effort; we are already failing
+    return EXIT_FAILURE;
+  }
+  // Only announce the start once the stream is known to be running.
+  fprintf(stderr, "Started\n");
+
+  // Each encoder call consumes `segment` feature frames and the read position
+  // then advances by `offset` frames, so consecutive chunks overlap.
+  const int32_t segment = 9;
+  const int32_t offset = 4;
+
+  const int32_t context_size = model.ContextSize();
+  const int32_t blank_id = model.BlankId();
+
+  // The hypothesis starts with `context_size` blanks; they are skipped when
+  // the text is printed.
+  std::vector<int32_t> hyp(context_size, blank_id);
+
+  ncnn::Mat decoder_input(context_size);
+  for (int32_t i = 0; i != context_size; ++i) {
+    static_cast<int32_t *>(decoder_input)[i] = blank_id;
+  }
+
+  ncnn::Mat decoder_out = model.RunDecoder(decoder_input);
+
+  // Encoder states carried across chunks.
+  ncnn::Mat hx;
+  ncnn::Mat cx;
+
+  size_t num_tokens = hyp.size();
+  int32_t num_processed = 0;
+
+  while (!stop) {
+    while (feature_extractor.NumFramesReady() - num_processed >= segment) {
+      ncnn::Mat features = feature_extractor.GetFrames(num_processed, segment);
+      num_processed += offset;
+
+      ncnn::Mat encoder_out = model.RunEncoder(features, &hx, &cx);
+
+      GreedySearch(model, encoder_out, &decoder_out, &hyp);
+    }
+
+    // Print the transcript only when new tokens have been emitted.
+    if (hyp.size() != num_tokens) {
+      num_tokens = hyp.size();
+      std::string text;
+      for (size_t i = static_cast<size_t>(context_size); i != hyp.size();
+           ++i) {
+        text += sym[hyp[i]];
+      }
+      fprintf(stderr, "%s\n", text.c_str());
+    }
+
+    Pa_Sleep(20);  // sleep for 20ms
+  }
+
+  err = Pa_CloseStream(stream);
+  if (err != paNoError) {
+    fprintf(stderr, "portaudio error: %s\n", Pa_GetErrorText(err));
+    return EXIT_FAILURE;
+  }
+
+  return 0;
+}

+ 1 - 1
sherpa-ncnn/csrc/sherpa-ncnn.cc

@@ -72,7 +72,7 @@ int main(int argc, char *argv[]) {
   if (argc < 9 || argc > 10) {
     const char *usage = R"usage(
 Usage:
-  ./sherpa-ncnn \
+  ./bin/sherpa-ncnn \
     /path/to/tokens.txt \
     /path/to/encoder.ncnn.param \
     /path/to/encoder.ncnn.bin \