ソースを参照

Refactor decoder and add endpointing (#42)

PF Luo 2 年 前
コミット
75661948e2

+ 17 - 0
.github/scripts/run-test.sh

@@ -173,6 +173,23 @@ for wave in ${waves[@]}; do
     $wave
 done
 
+log "Test beam-search"
+
+for wave in ${waves[@]}; do
+  time $EXE \
+    $repo/tokens.txt \
+    $repo/encoder_jit_trace-epoch-30-avg-10-pnnx.ncnn.param \
+    $repo/encoder_jit_trace-epoch-30-avg-10-pnnx.ncnn.bin \
+    $repo/decoder_jit_trace-epoch-30-avg-10-pnnx.ncnn.param \
+    $repo/decoder_jit_trace-epoch-30-avg-10-pnnx.ncnn.bin \
+    $repo/joiner_jit_trace-epoch-30-avg-10-pnnx.ncnn.param \
+    $repo/joiner_jit_trace-epoch-30-avg-10-pnnx.ncnn.bin \
+    $wave \
+    4 \
+    "modified_beam_search"
+done
+
+
 log "Test int8 models"
 
 for wave in ${waves[@]}; do

+ 2 - 0
.gitignore

@@ -4,3 +4,5 @@ __pycache__
 sherpa_ncnn.egg-info/
 run*.sh
 dist/
+.DS_Store
+tags

+ 10 - 6
sherpa-ncnn/csrc/CMakeLists.txt

@@ -2,11 +2,15 @@ include_directories(${CMAKE_SOURCE_DIR})
 
 set(sherpa_ncnn_core_srcs
   conv-emformer-model.cc
-  decode.cc
+  endpoint.cc
   features.cc
+  greedy-search-decoder.cc
+  hypothesis.cc
   lstm-model.cc
   meta-data.cc
   model.cc
+  modified-beam-search-decoder.cc
+  recognizer.cc
   symbol-table.cc
   wave-reader.cc
 )
@@ -15,11 +19,11 @@ target_link_libraries(sherpa-ncnn-core PUBLIC kaldi-native-fbank-core ncnn)
 install(TARGETS sherpa-ncnn-core DESTINATION lib)
 
 if(NOT SHERPA_NCNN_ENABLE_PYTHON)
-  add_executable(sherpa-ncnn sherpa-ncnn.cc)
-  target_link_libraries(sherpa-ncnn PRIVATE sherpa-ncnn-core)
-  install(TARGETS sherpa-ncnn DESTINATION bin)
-
   if(NOT DEFINED ANDROID_ABI)
+    add_executable(sherpa-ncnn sherpa-ncnn.cc)
+    target_link_libraries(sherpa-ncnn PRIVATE sherpa-ncnn-core)
+    install(TARGETS sherpa-ncnn DESTINATION bin)
+
     if(SHERPA_NCNN_ENABLE_PORTAUDIO)
       add_executable(sherpa-ncnn-microphone
         sherpa-ncnn-microphone.cc
@@ -39,9 +43,9 @@ if(NOT SHERPA_NCNN_ENABLE_PYTHON)
   endif()
 
   set(hdrs
-    decode.h
     features.h
     model.h
+    recognizer.h
     symbol-table.h
     wave-reader.h
   )

+ 0 - 49
sherpa-ncnn/csrc/decode.cc

@@ -1,49 +0,0 @@
-/**
- * Copyright (c)  2022  Xiaomi Corporation (authors: Fangjun Kuang)
- *
- * See LICENSE for clarification regarding multiple authors
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "sherpa-ncnn/csrc/decode.h"
-
-namespace sherpa_ncnn {
-
-void GreedySearch(Model *model, ncnn::Mat &encoder_out, ncnn::Mat *decoder_out,
-                  std::vector<int32_t> *hyp) {
-  int32_t context_size = 2;
-  int32_t blank_id = 0;  // hard-code it to 0
-  ncnn::Mat decoder_input(context_size);
-
-  for (int32_t t = 0; t != encoder_out.h; ++t) {
-    ncnn::Mat encoder_out_t(encoder_out.w, encoder_out.row(t));
-    ncnn::Mat joiner_out = model->RunJoiner(encoder_out_t, *decoder_out);
-
-    auto y = static_cast<int32_t>(std::distance(
-        static_cast<const float *>(joiner_out),
-        std::max_element(
-            static_cast<const float *>(joiner_out),
-            static_cast<const float *>(joiner_out) + joiner_out.w)));
-
-    if (y != blank_id) {
-      static_cast<int32_t *>(decoder_input)[0] = hyp->back();
-      static_cast<int32_t *>(decoder_input)[1] = y;
-      hyp->push_back(y);
-
-      *decoder_out = model->RunDecoder(decoder_input);
-    }
-  }
-}
-
-}  // namespace sherpa_ncnn

+ 0 - 45
sherpa-ncnn/csrc/decode.h

@@ -1,45 +0,0 @@
-/**
- * Copyright (c)  2022  Xiaomi Corporation (authors: Fangjun Kuang)
- *
- * See LICENSE for clarification regarding multiple authors
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef SHERPA_NCNN_CSRC_DECODE_H_
-#define SHERPA_NCNN_CSRC_DECODE_H_
-
-#include <vector>
-
-#include "net.h"  // NOLINT
-#include "sherpa-ncnn/csrc/model.h"
-
-namespace sherpa_ncnn {
-
-/**
- *
- * @param model  The neural network.
- * @param encoder_out  Its shape is (num_frames, encoder_out_dim).
- *                     encoder_out.w == encoder_out_dim
- *                     encoder_out.h == num_frames
- * @param decoder_out  Its shape is (1, decoder_out_dim).
- *                     decoder_out.w == decoder_out_dim
- *                     decoder_out.h == 1
- * @param hyp The recognition result. It is changed in place.
- */
-void GreedySearch(Model *model, ncnn::Mat &encoder_out, ncnn::Mat *decoder_out,
-                  std::vector<int32_t> *hyp);
-
-}  // namespace sherpa_ncnn
-
-#endif  // SHERPA_NCNN_CSRC_DECODE_H_

+ 51 - 0
sherpa-ncnn/csrc/endpoint.cc

@@ -0,0 +1,51 @@
+/**
+ * Copyright      2022  (authors: Pingfeng Luo)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "sherpa-ncnn/csrc/endpoint.h"
+
+#include <string>
+
+namespace sherpa_ncnn {
+
+static bool RuleActivated(const EndpointRule &rule,
+                          const std::string &rule_name,
+                          const float trailing_silence,
+                          const float utterance_length) {
+  bool contain_nonsilence = utterance_length > trailing_silence;
+  bool ans = (contain_nonsilence || !rule.must_contain_nonsilence) &&
+             trailing_silence >= rule.min_trailing_silence &&
+             utterance_length >= rule.min_utterance_length;
+  return ans;
+}
+
+bool Endpoint::IsEndpoint(const int num_frames_decoded,
+                          const int trailing_silence_frames,
+                          const float frame_shift_in_seconds) const {
+  float utterance_length = num_frames_decoded * frame_shift_in_seconds;
+  float trailing_silence = trailing_silence_frames * frame_shift_in_seconds;
+  if (RuleActivated(config_.rule1, "rule1", trailing_silence,
+                    utterance_length) ||
+      RuleActivated(config_.rule2, "rule2", trailing_silence,
+                    utterance_length) ||
+      RuleActivated(config_.rule3, "rule3", trailing_silence,
+                    utterance_length)) {
+    return true;
+  }
+  return false;
+}
+
+}  // namespace sherpa_ncnn

+ 75 - 0
sherpa-ncnn/csrc/endpoint.h

@@ -0,0 +1,75 @@
+/**
+ * Copyright      2022  (authors: Pingfeng Luo)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef SHERPA_NCNN_CSRC_ENDPOINT_H_
+#define SHERPA_NCNN_CSRC_ENDPOINT_H_
+
+#include <vector>
+
+namespace sherpa_ncnn {
+
+struct EndpointRule {
+  // If True, for this endpointing rule to apply there must
+  // be nonsilence in the best-path traceback.
+  // For decoding, a non-blank token is considered as non-silence
+  bool must_contain_nonsilence;
+  // This endpointing rule requires duration of trailing silence
+  // (in seconds) to be >= this value.
+  float min_trailing_silence;
+  // This endpointing rule requires utterance-length (in seconds)
+  // to be >= this value.
+  float min_utterance_length;
+
+  explicit EndpointRule(const bool must_contain_nonsilence = true,
+                        const float min_trailing_silence = 2.0,
+                        const float min_utterance_length = 0)
+      : must_contain_nonsilence(must_contain_nonsilence),
+        min_trailing_silence(min_trailing_silence),
+        min_utterance_length(min_utterance_length) {}
+};
+
+struct EndpointConfig {
+  // For default setting,
+  // rule1 times out after 2.4 seconds of silence, even if we decoded nothing.
+  // rule2 times out after 1.4 seconds of silence after decoding something.
+  // rule3 times out after the utterance is 20 seconds long, regardless of
+  // anything else.
+  EndpointRule rule1;
+  EndpointRule rule2;
+  EndpointRule rule3;
+
+  EndpointConfig()
+      : rule1(false, 2.4, 0), rule2(true, 1.4, 0), rule3(false, 0, 20) {}
+};
+
+class Endpoint {
+ public:
+  explicit Endpoint(const EndpointConfig &config) : config_(config) {}
+
+  /// This function returns true if this set of endpointing rules thinks we
+  /// should terminate decoding.
+  bool IsEndpoint(const int num_frames_decoded,
+                  const int trailing_silence_frames,
+                  const float frame_shift_in_seconds) const;
+
+ private:
+  EndpointConfig config_;
+};
+
+}  // namespace sherpa_ncnn
+
+#endif  // SHERPA_NCNN_CSRC_ENDPOINT_H_

+ 1 - 1
sherpa-ncnn/csrc/generate-int8-scale-table.cc

@@ -32,9 +32,9 @@
 #include "layer/innerproduct.h"
 #include "mat.h"
 #include "net.h"
-#include "sherpa-ncnn/csrc/decode.h"
 #include "sherpa-ncnn/csrc/features.h"
 #include "sherpa-ncnn/csrc/model.h"
+#include "sherpa-ncnn/csrc/recognizer.h"
 #include "sherpa-ncnn/csrc/wave-reader.h"
 
 static float compute_kl_divergence(const std::vector<float> &a,

+ 96 - 0
sherpa-ncnn/csrc/greedy-search-decoder.cc

@@ -0,0 +1,96 @@
+/**
+ * Copyright (c)  2022  Xiaomi Corporation (authors: Fangjun Kuang)
+ * Copyright (c)  2022                     (Pingfeng Luo)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "sherpa-ncnn/csrc/greedy-search-decoder.h"
+
+namespace sherpa_ncnn {
+
+void GreedySearchDecoder::AcceptWaveform(const int32_t sample_rate,
+                                         const float *input_buffer,
+                                         int32_t frames_per_buffer) {
+  feature_extractor_.AcceptWaveform(sample_rate, input_buffer,
+                                    frames_per_buffer);
+}
+
+void GreedySearchDecoder::BuildDecoderInput() {
+  for (int32_t i = 0; i != context_size_; ++i) {
+    static_cast<int32_t *>(decoder_input_)[i] =
+        *(result_.tokens.end() - context_size_ + i);
+  }
+}
+
+void GreedySearchDecoder::ResetResult() {
+  result_.tokens.clear();
+  result_.text.clear();
+  result_.num_trailing_blanks = 0;
+  for (int32_t i = 0; i != context_size_; ++i) {
+    result_.tokens.push_back(blank_id_);
+  }
+}
+
+void GreedySearchDecoder::Decode() {
+  while (feature_extractor_.NumFramesReady() - num_processed_ >= segment_) {
+    ncnn::Mat features = feature_extractor_.GetFrames(num_processed_, segment_);
+    std::tie(encoder_out_, encoder_state_) =
+        model_->RunEncoder(features, encoder_state_);
+
+    /* encoder_out_.w == encoder_out_dim, encoder_out_.h == num_frames. */
+    for (int32_t t = 0; t != encoder_out_.h; ++t) {
+      ncnn::Mat encoder_out_t(encoder_out_.w, encoder_out_.row(t));
+      ncnn::Mat joiner_out = model_->RunJoiner(encoder_out_t, decoder_out_);
+      auto joiner_out_ptr = joiner_out.row(0);
+
+      auto new_token = static_cast<int32_t>(std::distance(
+          joiner_out_ptr,
+          std::max_element(joiner_out_ptr, joiner_out_ptr + joiner_out.w)));
+
+      if (new_token != blank_id_) {
+        result_.tokens.push_back(new_token);
+        result_.text += (*sym_)[new_token];
+        BuildDecoderInput();
+        decoder_out_ = model_->RunDecoder(decoder_input_);
+        result_.num_trailing_blanks = 0;
+      } else {
+        ++result_.num_trailing_blanks;
+      }
+    }
+
+    num_processed_ += offset_;
+  }
+}
+
+RecognitionResult GreedySearchDecoder::GetResult() {
+  auto ans = result_;
+  if (config_.use_endpoint && IsEndpoint()) {
+    ResetResult();
+    endpoint_start_frame_ = num_processed_;
+  }
+  return ans;
+}
+
+void GreedySearchDecoder::InputFinished() {
+  feature_extractor_.InputFinished();
+}
+
+bool GreedySearchDecoder::IsEndpoint() const {
+  return endpoint_->IsEndpoint(num_processed_ - endpoint_start_frame_,
+                               result_.num_trailing_blanks * 4, 10 / 1000.0);
+}
+
+}  // namespace sherpa_ncnn

+ 90 - 0
sherpa-ncnn/csrc/greedy-search-decoder.h

@@ -0,0 +1,90 @@
+/**
+ * Copyright (c)  2022  Xiaomi Corporation (authors: Fangjun Kuang)
+ * Copyright (c)  2022                     (Pingfeng Luo)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef SHERPA_NCNN_CSRC_GREEDY_SEARCH_DECODER_H_
+#define SHERPA_NCNN_CSRC_GREEDY_SEARCH_DECODER_H_
+
+#include <memory>
+#include <vector>
+
+#include "sherpa-ncnn/csrc/features.h"
+#include "sherpa-ncnn/csrc/recognizer.h"
+
+namespace sherpa_ncnn {
+
+class GreedySearchDecoder : public Decoder {
+ public:
+  GreedySearchDecoder(const DecoderConfig &config, Model *model,
+                      const knf::FbankOptions &fbank_opts,
+                      const sherpa_ncnn::SymbolTable *sym,
+                      const Endpoint *endpoint)
+      : config_(config),
+        model_(model),
+        feature_extractor_(fbank_opts),
+        sym_(sym),
+        blank_id_(model_->BlankId()),
+        context_size_(model_->ContextSize()),
+        segment_(model->Segment()),
+        offset_(model_->Offset()),
+        decoder_input_(context_size_),
+        num_processed_(0),
+        endpoint_start_frame_(0),
+        endpoint_(endpoint) {
+    ResetResult();
+    BuildDecoderInput();
+    decoder_out_ = model_->RunDecoder(decoder_input_);
+  }
+
+  void AcceptWaveform(int32_t sample_rate, const float *input_buffer,
+                      int32_t frames_per_buffer) override;
+
+  void Decode() override;
+
+  RecognitionResult GetResult() override;
+
+  void ResetResult() override;
+
+  bool IsEndpoint() const override;
+
+  void InputFinished() override;
+
+ private:
+  void BuildDecoderInput();
+
+  const DecoderConfig config_;
+  Model *model_;
+  sherpa_ncnn::FeatureExtractor feature_extractor_;
+  const sherpa_ncnn::SymbolTable *sym_;
+  const int32_t blank_id_;
+  const int32_t context_size_;
+  const int32_t segment_;
+  const int32_t offset_;
+  ncnn::Mat encoder_out_;
+  std::vector<ncnn::Mat> encoder_state_;
+  ncnn::Mat decoder_input_;
+  ncnn::Mat decoder_out_;
+  int32_t num_processed_;
+  int32_t endpoint_start_frame_;
+  const Endpoint *endpoint_;
+  RecognitionResult result_;
+};
+
+}  // namespace sherpa_ncnn
+
+#endif  // SHERPA_NCNN_CSRC_GREEDY_SEARCH_DECODER_H_

+ 78 - 0
sherpa-ncnn/csrc/hypothesis.cc

@@ -0,0 +1,78 @@
+/**
+ * Copyright      2022  Xiaomi Corporation (authors: Fangjun Kuang)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "sherpa-ncnn/csrc/hypothesis.h"
+
+#include <algorithm>
+#include <utility>
+
+#include "sherpa-ncnn/csrc/math.h"
+
+namespace sherpa_ncnn {
+
+void Hypotheses::Add(Hypothesis hyp) {
+  auto key = hyp.Key();
+  auto it = hyps_dict_.find(key);
+  if (it == hyps_dict_.end()) {
+    hyps_dict_[key] = std::move(hyp);
+  } else {
+    it->second.log_prob = LogAdd<double>()(it->second.log_prob, hyp.log_prob);
+  }
+}
+
+Hypothesis Hypotheses::GetMostProbable(bool length_norm) const {
+  if (length_norm == false) {
+    return std::max_element(hyps_dict_.begin(), hyps_dict_.end(),
+                            [](const auto &left, auto &right) -> bool {
+                              return left.second.log_prob <
+                                     right.second.log_prob;
+                            })
+        ->second;
+  } else {
+    // for length_norm is true
+    return std::max_element(
+               hyps_dict_.begin(), hyps_dict_.end(),
+               [](const auto &left, const auto &right) -> bool {
+                 return left.second.log_prob / left.second.ys.size() <
+                        right.second.log_prob / right.second.ys.size();
+               })
+        ->second;
+  }
+}
+
+Hypothesis Hypotheses::GetLeastProbable(bool length_norm) const {
+  if (length_norm == false) {
+    return std::max_element(hyps_dict_.begin(), hyps_dict_.end(),
+                            [](const auto &left, auto &right) -> bool {
+                              return left.second.log_prob >
+                                     right.second.log_prob;
+                            })
+        ->second;
+  } else {
+    // for length_norm is true
+    return std::max_element(
+               hyps_dict_.begin(), hyps_dict_.end(),
+               [](const auto &left, const auto &right) -> bool {
+                 return left.second.log_prob / left.second.ys.size() >
+                        right.second.log_prob / right.second.ys.size();
+               })
+        ->second;
+  }
+}
+
+}  // namespace sherpa_ncnn

+ 134 - 0
sherpa-ncnn/csrc/hypothesis.h

@@ -0,0 +1,134 @@
+/**
+ * Copyright      2022  Xiaomi Corporation (authors: Fangjun Kuang)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef SHERPA_NCNN_CSRC_HYPOTHESIS_H_
+#define SHERPA_NCNN_CSRC_HYPOTHESIS_H_
+
+#include <sstream>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+namespace sherpa_ncnn {
+
+struct Hypothesis {
+  // The predicted tokens so far. Newly predicated tokens are appended.
+  std::vector<int32_t> ys;
+
+  // timestamps[i] contains the frame number after subsampling
+  // on which ys[i] is decoded.
+  std::vector<int32_t> timestamps;
+
+  // The total score of ys in log space.
+  double log_prob = 0;
+
+  int32_t num_trailing_blanks = 0;
+
+  Hypothesis() = default;
+  Hypothesis(const std::vector<int32_t> &ys, double log_prob)
+      : ys(ys), log_prob(log_prob) {}
+
+  // If two Hypotheses have the same `Key`, then they contain
+  // the same token sequence.
+  std::string Key() const {
+    // TODO(fangjun): Use a hash function?
+    std::ostringstream os;
+    std::string sep = "-";
+    for (auto i : ys) {
+      os << i << sep;
+      sep = "-";
+    }
+    return os.str();
+  }
+
+  // For debugging
+  std::string ToString() const {
+    std::ostringstream os;
+    os << "(" << Key() << ", " << log_prob << ")";
+    return os.str();
+  }
+};
+
+class Hypotheses {
+ public:
+  Hypotheses() = default;
+
+  explicit Hypotheses(std::vector<Hypothesis> hyps) {
+    for (auto &h : hyps) {
+      hyps_dict_[h.Key()] = std::move(h);
+    }
+  }
+
+  explicit Hypotheses(std::unordered_map<std::string, Hypothesis> hyps_dict)
+      : hyps_dict_(std::move(hyps_dict)) {}
+
+  // Add hyp to this object. If it already exists, its log_prob
+  // is updated with the given hyp using log-sum-exp.
+  void Add(Hypothesis hyp);
+
+  // Get the hyp that has the largest log_prob.
+  // If length_norm is true, hyp's log_prob are divided by
+  // len(hyp.ys) before comparison.
+  Hypothesis GetMostProbable(bool length_norm) const;
+
+  // Get the hyp that has the least log_prob.
+  // If length_norm is true, hyp's log_prob are divided by
+  // len(hyp.ys) before comparison.
+  Hypothesis GetLeastProbable(bool length_norm) const;
+
+  // Remove the given hyp from this object.
+  // It is *NOT* an error if hyp does not exist in this object.
+  void Remove(const Hypothesis &hyp) { hyps_dict_.erase(hyp.Key()); }
+
+  // Return a list of hyps contained in this object.
+  std::vector<Hypothesis> Vec() const {
+    std::vector<Hypothesis> ans;
+    ans.reserve(hyps_dict_.size());
+    for (const auto &p : hyps_dict_) {
+      ans.push_back(p.second);
+    }
+    return ans;
+  }
+
+  int32_t Size() const { return hyps_dict_.size(); }
+
+  std::string ToString() const {
+    std::ostringstream os;
+    for (const auto &p : hyps_dict_) {
+      os << p.second.ToString() << "\n";
+    }
+    return os.str();
+  }
+
+  auto begin() { return hyps_dict_.begin(); }
+  auto end() { return hyps_dict_.end(); }
+
+  const auto begin() const { return hyps_dict_.begin(); }
+  const auto end() const { return hyps_dict_.end(); }
+
+  void clear() { hyps_dict_.clear(); }
+
+ private:
+  using Map = std::unordered_map<std::string, Hypothesis>;
+  Map hyps_dict_;
+};
+
+}  // namespace sherpa_ncnn
+
+#endif  // SHERPA_NCNN_CSRC_HYPOTHESIS_H_

+ 120 - 0
sherpa-ncnn/csrc/math.h

@@ -0,0 +1,120 @@
+/**
+ * Copyright (c)  2022  Xiaomi Corporation (authors: Daniel Povey)
+ * Copyright (c)  2022                     (Pingfeng Luo)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+// This file is copied from k2/csrc/utils.h
+#ifndef SHERPA_NCNN_CSRC_MATH_H_
+#define SHERPA_NCNN_CSRC_MATH_H_
+
+#include <algorithm>
+#include <cassert>
+#include <cmath>
+#include <numeric>
+#include <vector>
+
+namespace sherpa_ncnn {
+
+// logf(FLT_EPSILON)
+#define SHERPA_NCNN_MIN_LOG_DIFF_FLOAT -15.9423847198486328125f
+
+// log(DBL_EPSILON)
+#define SHERPA_NCNN_MIN_LOG_DIFF_DOUBLE \
+  -36.0436533891171535515240975655615329742431640625
+
+template <typename T>
+struct LogAdd;
+
+template <>
+struct LogAdd<double> {
+  double operator()(double x, double y) const {
+    double diff;
+
+    if (x < y) {
+      diff = x - y;
+      x = y;
+    } else {
+      diff = y - x;
+    }
+    // diff is negative.  x is now the larger one.
+
+    if (diff >= SHERPA_NCNN_MIN_LOG_DIFF_DOUBLE) {
+      double res;
+      res = x + log1p(exp(diff));
+      return res;
+    }
+
+    return x;  // return the larger one.
+  }
+};
+
+template <>
+struct LogAdd<float> {
+  float operator()(float x, float y) const {
+    float diff;
+
+    if (x < y) {
+      diff = x - y;
+      x = y;
+    } else {
+      diff = y - x;
+    }
+    // diff is negative.  x is now the larger one.
+
+    if (diff >= SHERPA_NCNN_MIN_LOG_DIFF_FLOAT) {
+      float res;
+      res = x + log1pf(expf(diff));
+      return res;
+    }
+
+    return x;  // return the larger one.
+  }
+};
+
+template <class T>
+void log_softmax(T *input, int32_t input_len) {
+  assert(input);
+
+  T m = *std::max_element(input, input + input_len);
+
+  T sum = 0.0;
+  for (int32_t i = 0; i < input_len; i++) {
+    sum += exp(input[i] - m);
+  }
+
+  T offset = m + log(sum);
+  for (int32_t i = 0; i < input_len; i++) {
+    input[i] -= offset;
+  }
+}
+
+template <class T>
+std::vector<int32_t> topk_index(const T *vec, int32_t size, int32_t topk) {
+  std::vector<int32_t> vec_index(size);
+  std::iota(vec_index.begin(), vec_index.end(), 0);
+
+  std::sort(vec_index.begin(), vec_index.end(),
+            [vec](int32_t index_1, int32_t index_2) {
+              return vec[index_1] > vec[index_2];
+            });
+
+  int32_t k_num = std::min<int32_t>(size, topk);
+  std::vector<int32_t> index(vec_index.begin(), vec_index.begin() + k_num);
+  return index;
+}
+
+}  // namespace sherpa_ncnn
+#endif  // SHERPA_NCNN_CSRC_MATH_H_

+ 130 - 0
sherpa-ncnn/csrc/modified-beam-search-decoder.cc

@@ -0,0 +1,130 @@
+/**
+ * Copyright (c)  2022  Xiaomi Corporation (authors: Fangjun Kuang)
+ * Copyright (c)  2022                     (Pingfeng Luo)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "sherpa-ncnn/csrc/modified-beam-search-decoder.h"
+
+#include <string>
+#include <utility>
+
+#include "sherpa-ncnn/csrc/math.h"
+
+namespace sherpa_ncnn {
+
+void ModifiedBeamSearchDecoder::AcceptWaveform(const int32_t sample_rate,
+                                               const float *input_buffer,
+                                               int32_t frames_per_buffer) {
+  feature_extractor_.AcceptWaveform(sample_rate, input_buffer,
+                                    frames_per_buffer);
+}
+
+void ModifiedBeamSearchDecoder::BuildDecoderInput(Hypothesis hyp) {
+  decoder_input_.reshape(context_size_);
+  for (int32_t i = 0; i != context_size_; ++i) {
+    static_cast<int32_t *>(decoder_input_)[i] =
+        *(hyp.ys.end() - context_size_ + i);
+  }
+}
+
+void ModifiedBeamSearchDecoder::ResetResult() {
+  result_.text.clear();
+  std::vector<int32_t> blanks(context_size_, blank_id_);
+  Hypotheses blank_hyp({{blanks, 0}});
+  result_.hyps = std::move(blank_hyp);
+}
+
+void ModifiedBeamSearchDecoder::Decode() {
+  while (feature_extractor_.NumFramesReady() - num_processed_ >= segment_) {
+    ncnn::Mat features = feature_extractor_.GetFrames(num_processed_, segment_);
+    std::tie(encoder_out_, encoder_state_) =
+        model_->RunEncoder(features, encoder_state_);
+
+    Hypotheses cur = std::move(result_.hyps);
+    std::vector<Hypothesis> prev;
+    /* encoder_out_.w == encoder_out_dim, encoder_out_.h == num_frames. */
+    for (int32_t t = 0; t != encoder_out_.h; ++t) {
+      prev.clear();
+      prev.reserve(config_.num_active_paths);
+      for (auto &h : cur) {
+        prev.push_back(std::move(h.second));
+      }
+      cur.clear();
+
+      for (const auto &h : prev) {
+        ncnn::Mat encoder_out_t(encoder_out_.w, encoder_out_.row(t));
+        BuildDecoderInput(h);
+        decoder_out_ = model_->RunDecoder(decoder_input_);
+        ncnn::Mat joiner_out = model_->RunJoiner(encoder_out_t, decoder_out_);
+        auto joiner_out_ptr = joiner_out.row(0);
+        log_softmax(joiner_out_ptr, joiner_out.w);
+
+        // update active_paths
+        auto topk =
+            topk_index(joiner_out_ptr, joiner_out.w, config_.num_active_paths);
+        for (int i = 0; i != topk.size(); i++) {
+          Hypothesis new_hyp = h;
+          int32_t new_token = topk[i];
+          if (new_token != blank_id_) {
+            new_hyp.ys.push_back(new_token);
+            new_hyp.num_trailing_blanks = 0;
+          } else {
+            new_hyp.num_trailing_blanks += 1;
+          }
+          new_hyp.log_prob += joiner_out_ptr[new_token];
+          cur.Add(std::move(new_hyp));
+        }
+      }
+      while (cur.Size() > config_.num_active_paths) {
+        auto least_hyp = cur.GetLeastProbable(true);
+        cur.Remove(least_hyp);
+      }
+    }
+
+    num_processed_ += offset_;
+    result_.hyps = std::move(cur);
+  }
+}
+
+RecognitionResult ModifiedBeamSearchDecoder::GetResult() {
+  // return best result
+  auto best_hyp = result_.hyps.GetMostProbable(true);
+  std::string best_hyp_text;
+  for (const auto &token : best_hyp.ys) {
+    if (token != blank_id_) {
+      best_hyp_text += (*sym_)[token];
+    }
+  }
+  result_.text = std::move(best_hyp_text);
+  auto ans = result_;
+  result_.num_trailing_blanks = best_hyp.num_trailing_blanks;
+  if (config_.use_endpoint && IsEndpoint()) {
+    ResetResult();
+    endpoint_start_frame_ = num_processed_;
+  }
+  return ans;
+}
+
+void ModifiedBeamSearchDecoder::InputFinished() {
+  feature_extractor_.InputFinished();
+}
+
+bool ModifiedBeamSearchDecoder::IsEndpoint() const {
+  return endpoint_->IsEndpoint(num_processed_ - endpoint_start_frame_,
+                               result_.num_trailing_blanks * 4, 10 / 1000.0);
+}
+
+}  // namespace sherpa_ncnn

+ 88 - 0
sherpa-ncnn/csrc/modified-beam-search-decoder.h

@@ -0,0 +1,88 @@
+/**
+ * Copyright (c)  2022  Xiaomi Corporation (authors: Fangjun Kuang)
+ * Copyright (c)  2022                     (Pingfeng Luo)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef SHERPA_NCNN_CSRC_MODIFIED_BEAM_SEARCH_DECODER_H_
+#define SHERPA_NCNN_CSRC_MODIFIED_BEAM_SEARCH_DECODER_H_
+
+#include <memory>
+#include <vector>
+
+#include "sherpa-ncnn/csrc/features.h"
+#include "sherpa-ncnn/csrc/recognizer.h"
+
+namespace sherpa_ncnn {
+
+class ModifiedBeamSearchDecoder : public Decoder {
+ public:
+  ModifiedBeamSearchDecoder(const DecoderConfig &config, Model *model,
+                            const knf::FbankOptions &fbank_opts,
+                            const sherpa_ncnn::SymbolTable *sym,
+                            const Endpoint *endpoint)
+      : config_(config),
+        model_(model),
+        feature_extractor_(fbank_opts),
+        sym_(sym),
+        blank_id_(model_->BlankId()),
+        context_size_(model_->ContextSize()),
+        segment_(model->Segment()),
+        offset_(model_->Offset()),
+        decoder_input_(context_size_),
+        num_processed_(0),
+        endpoint_start_frame_(0),
+        endpoint_(endpoint) {
+    ResetResult();
+  }
+
+  void AcceptWaveform(int32_t sample_rate, const float *input_buffer,
+                      int32_t frames_per_buffer) override;
+
+  void Decode() override;
+
+  RecognitionResult GetResult() override;
+
+  void ResetResult() override;
+
+  bool IsEndpoint() const override;
+
+  void InputFinished() override;
+
+ private:
+  void BuildDecoderInput(Hypothesis hyp);
+
+  const DecoderConfig config_;
+  Model *model_;
+  sherpa_ncnn::FeatureExtractor feature_extractor_;
+  const sherpa_ncnn::SymbolTable *sym_;
+  const int32_t blank_id_;
+  const int32_t context_size_;
+  const int32_t segment_;
+  const int32_t offset_;
+  ncnn::Mat encoder_out_;
+  std::vector<ncnn::Mat> encoder_state_;
+  ncnn::Mat decoder_input_;
+  ncnn::Mat decoder_out_;
+  int32_t num_processed_;
+  int32_t endpoint_start_frame_;
+  const Endpoint *endpoint_;
+  RecognitionResult result_;
+};
+
+}  // namespace sherpa_ncnn
+
+#endif  // SHERPA_NCNN_CSRC_MODIFIED_BEAM_SEARCH_DECODER_H_

+ 69 - 0
sherpa-ncnn/csrc/recognizer.cc

@@ -0,0 +1,69 @@
+/**
+ * Copyright (c)  2022  Xiaomi Corporation (authors: Fangjun Kuang)
+ * Copyright (c)  2022                     (Pingfeng Luo)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "sherpa-ncnn/csrc/greedy-search-decoder.h"
+#include "sherpa-ncnn/csrc/modified-beam-search-decoder.h"
+
+namespace sherpa_ncnn {
+
+Recognizer::Recognizer(
+#if __ANDROID_API__ >= 9
+    AAssetManager *mgr,
+#endif
+    const DecoderConfig decoder_conf, const ModelConfig model_conf,
+    const knf::FbankOptions fbank_opts)
+    :
+#if __ANDROID_API__ >= 9
+      model_(Model::Create(mgr, model_conf)),
+      sym_(std::make_unique<SymbolTable>(mgr, model_conf.tokens)),
+#else
+      model_(Model::Create(model_conf)),
+      sym_(std::make_unique<SymbolTable>(model_conf.tokens)),
+#endif
+      endpoint_(std::make_unique<Endpoint>(decoder_conf.endpoint_config)) {
+  if (decoder_conf.method == "modified_beam_search") {
+    decoder_ = std::make_unique<ModifiedBeamSearchDecoder>(
+        decoder_conf, model_.get(), fbank_opts, sym_.get(), endpoint_.get());
+  } else if (decoder_conf.method == "greedy_search") {
+    decoder_ = std::make_unique<GreedySearchDecoder>(
+        decoder_conf, model_.get(), fbank_opts, sym_.get(), endpoint_.get());
+  } else {
+    NCNN_LOGE("Unsupported decoding method: %s\n", decoder_conf.method.c_str());
+    exit(-1);
+  }
+}
+
+void Recognizer::AcceptWaveform(int32_t sample_rate, const float *input_buffer,
+                                int32_t frames_per_buffer) {
+  decoder_->AcceptWaveform(sample_rate, input_buffer, frames_per_buffer);
+}
+
+void Recognizer::Decode() { decoder_->Decode(); }
+
+RecognitionResult Recognizer::GetResult() { return decoder_->GetResult(); }
+
+bool Recognizer::IsEndpoint() const { return decoder_->IsEndpoint(); }
+
+void Recognizer::InputFinished() { return decoder_->InputFinished(); }
+
+}  // namespace sherpa_ncnn

+ 105 - 0
sherpa-ncnn/csrc/recognizer.h

@@ -0,0 +1,105 @@
+/**
+ * Copyright (c)  2022  Xiaomi Corporation (authors: Fangjun Kuang)
+ * Copyright (c)  2022                     (Pingfeng Luo)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef SHERPA_NCNN_CSRC_RECOGNIZER_H_
+#define SHERPA_NCNN_CSRC_RECOGNIZER_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "sherpa-ncnn/csrc/endpoint.h"
+#include "sherpa-ncnn/csrc/features.h"
+#include "sherpa-ncnn/csrc/hypothesis.h"
+#include "sherpa-ncnn/csrc/model.h"
+#include "sherpa-ncnn/csrc/symbol-table.h"
+
+namespace sherpa_ncnn {
+
+struct RecognitionResult {
+  std::vector<int32_t> tokens;
+  std::string text;
+
+  int32_t num_trailing_blanks = 0;
+
+  // used only for modified_beam_search
+  Hypotheses hyps;
+};
+
+struct DecoderConfig {
+  std::string method = "modified_beam_search";
+
+  int32_t num_active_paths = 4;  // for modified beam search
+
+  bool use_endpoint = true;
+
+  EndpointConfig endpoint_config;
+};
+
+class Decoder {
+ public:
+  virtual ~Decoder() = default;
+
+  virtual void AcceptWaveform(int32_t sample_rate, const float *input_buffer,
+                              int32_t frames_per_buffer) = 0;
+
+  virtual void Decode() = 0;
+
+  virtual RecognitionResult GetResult() = 0;
+
+  virtual void ResetResult() = 0;
+
+  virtual void InputFinished() = 0;
+
+  virtual bool IsEndpoint() const = 0;
+};
+
+class Recognizer {
+ public:
+  /** Construct an instance of OnlineRecognizer.
+   */
+  Recognizer(
+#if __ANDROID_API__ >= 9
+      AAssetManager *mgr,
+#endif
+      const DecoderConfig decoder_conf, const ModelConfig model_conf,
+      const knf::FbankOptions fbank_opts);
+
+  ~Recognizer() = default;
+
+  void AcceptWaveform(int32_t sample_rate, const float *input_buffer,
+                      int32_t frames_per_buffer);
+
+  void Decode();
+
+  RecognitionResult GetResult();
+
+  void InputFinished();
+
+  bool IsEndpoint() const;
+
+ private:
+  std::unique_ptr<Model> model_;
+  std::unique_ptr<SymbolTable> sym_;
+  std::unique_ptr<Endpoint> endpoint_;
+  std::unique_ptr<Decoder> decoder_;
+};
+
+}  // namespace sherpa_ncnn
+#endif  // SHERPA_NCNN_CSRC_RECOGNIZER_H_

+ 39 - 78
sherpa-ncnn/csrc/sherpa-ncnn-microphone.cc

@@ -21,11 +21,8 @@
 #include <stdlib.h>
 
 #include "portaudio.h"  // NOLINT
-#include "sherpa-ncnn/csrc/decode.h"
-#include "sherpa-ncnn/csrc/features.h"
 #include "sherpa-ncnn/csrc/microphone.h"
-#include "sherpa-ncnn/csrc/model.h"
-#include "sherpa-ncnn/csrc/symbol-table.h"
+#include "sherpa-ncnn/csrc/recognizer.h"
 
 bool stop = false;
 
@@ -34,14 +31,14 @@ static int RecordCallback(const void *input_buffer, void * /*output_buffer*/,
                           const PaStreamCallbackTimeInfo * /*time_info*/,
                           PaStreamCallbackFlags /*status_flags*/,
                           void *user_data) {
-  auto feature_extractor =
-      reinterpret_cast<sherpa_ncnn::FeatureExtractor *>(user_data);
+  auto recognizer = reinterpret_cast<sherpa_ncnn::Recognizer *>(user_data);
 
-  feature_extractor->AcceptWaveform(
+  recognizer->AcceptWaveform(
       16000, reinterpret_cast<const float *>(input_buffer), frames_per_buffer);
 
   return stop ? paComplete : paContinue;
 }
+
 static void Handler(int sig) {
   stop = true;
   fprintf(stderr, "\nexiting...\n");
@@ -71,44 +68,42 @@ https://huggingface.co/csukuangfj/sherpa-ncnn-2022-09-05
   }
   signal(SIGINT, Handler);
 
-  sherpa_ncnn::ModelConfig config;
-
-  config.tokens = argv[1];
-  config.encoder_param = argv[2];
-  config.encoder_bin = argv[3];
-  config.decoder_param = argv[4];
-  config.decoder_bin = argv[5];
-  config.joiner_param = argv[6];
-  config.joiner_bin = argv[7];
-
-  int32_t num_threads = 4;
-  if (argc == 9) {
+  sherpa_ncnn::ModelConfig model_conf;
+  model_conf.tokens = argv[1];
+  model_conf.encoder_param = argv[2];
+  model_conf.encoder_bin = argv[3];
+  model_conf.decoder_param = argv[4];
+  model_conf.decoder_bin = argv[5];
+  model_conf.joiner_param = argv[6];
+  model_conf.joiner_bin = argv[7];
+  int num_threads = 4;
+  if (argc >= 9 && atoi(argv[8]) > 0) {
     num_threads = atoi(argv[8]);
   }
-
-  config.encoder_opt.num_threads = num_threads;
-  config.decoder_opt.num_threads = num_threads;
-  config.joiner_opt.num_threads = num_threads;
-
-  sherpa_ncnn::SymbolTable sym(config.tokens);
-  fprintf(stderr, "%s\n", config.ToString().c_str());
-
-  auto model = sherpa_ncnn::Model::Create(config);
-  if (!model) {
-    fprintf(stderr, "Failed to create a model\n");
-    exit(EXIT_FAILURE);
+  model_conf.encoder_opt.num_threads = num_threads;
+  model_conf.decoder_opt.num_threads = num_threads;
+  model_conf.joiner_opt.num_threads = num_threads;
+
+  fprintf(stderr, "%s\n", model_conf.ToString().c_str());
+
+  const float expected_sampling_rate = 16000;
+  sherpa_ncnn::DecoderConfig decoder_conf;
+  if (argc == 10) {
+    std::string method = argv[9];
+    if (method == "greedy_search" ||
+        method == "modified_beam_search") {
+      decoder_conf.method = method;
+    }
   }
-
-  float sample_rate = 16000;
-  sherpa_ncnn::Microphone mic;
-
   knf::FbankOptions fbank_opts;
   fbank_opts.frame_opts.dither = 0;
   fbank_opts.frame_opts.snip_edges = false;
-  fbank_opts.frame_opts.samp_freq = sample_rate;
+  fbank_opts.frame_opts.samp_freq = expected_sampling_rate;
   fbank_opts.mel_opts.num_bins = 80;
 
-  sherpa_ncnn::FeatureExtractor feature_extractor(fbank_opts);
+  sherpa_ncnn::Recognizer recognizer(decoder_conf, model_conf, fbank_opts);
+
+  sherpa_ncnn::Microphone mic;
 
   PaDeviceIndex num_devices = Pa_GetDeviceCount();
   fprintf(stderr, "Num devices: %d\n", num_devices);
@@ -131,6 +126,7 @@ https://huggingface.co/csukuangfj/sherpa-ncnn-2022-09-05
 
   param.suggestedLatency = info->defaultLowInputLatency;
   param.hostApiSpecificStreamInfo = nullptr;
+  const float sample_rate = 16000;
 
   PaStream *stream;
   PaError err =
@@ -139,7 +135,7 @@ https://huggingface.co/csukuangfj/sherpa-ncnn-2022-09-05
                     0,          // frames per buffer
                     paClipOff,  // we won't output out of range samples
                                 // so don't bother clipping them
-                    RecordCallback, &feature_extractor);
+                    RecordCallback, &recognizer);
   if (err != paNoError) {
     fprintf(stderr, "portaudio error: %s\n", Pa_GetErrorText(err));
     exit(EXIT_FAILURE);
@@ -153,47 +149,12 @@ https://huggingface.co/csukuangfj/sherpa-ncnn-2022-09-05
     exit(EXIT_FAILURE);
   }
 
-  int32_t segment = model->Segment();
-  int32_t offset = model->Offset();
-
-  int32_t context_size = model->ContextSize();
-  int32_t blank_id = model->BlankId();
-
-  std::vector<int32_t> hyp(context_size, blank_id);
-
-  ncnn::Mat decoder_input(context_size);
-  for (int32_t i = 0; i != context_size; ++i) {
-    static_cast<int32_t *>(decoder_input)[i] = blank_id;
-  }
-
-  ncnn::Mat decoder_out = model->RunDecoder(decoder_input);
-
-  ncnn::Mat hx;
-  ncnn::Mat cx;
-
-  int32_t num_tokens = hyp.size();
-  int32_t num_processed = 0;
-
-  std::vector<ncnn::Mat> states;
-  ncnn::Mat encoder_out;
-
+  size_t num_tokens = 0;
   while (!stop) {
-    while (feature_extractor.NumFramesReady() - num_processed >= segment) {
-      ncnn::Mat features = feature_extractor.GetFrames(num_processed, segment);
-      num_processed += offset;
-
-      std::tie(encoder_out, states) = model->RunEncoder(features, states);
-
-      GreedySearch(model.get(), encoder_out, &decoder_out, &hyp);
-    }
-
-    if (hyp.size() != num_tokens) {
-      num_tokens = hyp.size();
-      std::string text;
-      for (int32_t i = context_size; i != hyp.size(); ++i) {
-        text += sym[hyp[i]];
-      }
-      fprintf(stderr, "%s\n", text.c_str());
+    recognizer.Decode();
+    auto result = recognizer.GetResult();
+    if (result.text.size() != num_tokens) {
+      num_tokens = result.text.size();
+      fprintf(stderr, "%s\n", result.text.c_str());
     }
 
     Pa_Sleep(20);  // sleep for 20ms

+ 43 - 84
sherpa-ncnn/csrc/sherpa-ncnn.cc

@@ -1,5 +1,6 @@
 /**
  * Copyright (c)  2022  Xiaomi Corporation (authors: Fangjun Kuang)
+ * Copyright (c)  2022                     (Pingfeng Luo)
  *
  * See LICENSE for clarification regarding multiple authors
  *
@@ -19,16 +20,12 @@
 #include <algorithm>
 #include <iostream>
 
-#include "kaldi-native-fbank/csrc/online-feature.h"
 #include "net.h"  // NOLINT
-#include "sherpa-ncnn/csrc/decode.h"
-#include "sherpa-ncnn/csrc/features.h"
-#include "sherpa-ncnn/csrc/model.h"
-#include "sherpa-ncnn/csrc/symbol-table.h"
+#include "sherpa-ncnn/csrc/recognizer.h"
 #include "sherpa-ncnn/csrc/wave-reader.h"
 
 int main(int argc, char *argv[]) {
-  if (argc < 9 || argc > 10) {
+  if (argc < 9 || argc > 11) {
     const char *usage = R"usage(
 Usage:
   ./bin/sherpa-ncnn \
@@ -39,7 +36,7 @@ Usage:
     /path/to/decoder.ncnn.bin \
     /path/to/joiner.ncnn.param \
     /path/to/joiner.ncnn.bin \
-    /path/to/foo.wav [num_threads]
+    /path/to/foo.wav [num_threads] [decode_method, can be greedy_search/modified_beam_search]
 
 You can download pre-trained models from the following repository:
 https://huggingface.co/csukuangfj/sherpa-ncnn-2022-09-05
@@ -48,38 +45,42 @@ https://huggingface.co/csukuangfj/sherpa-ncnn-2022-09-05
 
     return 0;
   }
-  sherpa_ncnn::ModelConfig config;
-
-  config.tokens = argv[1];
-  config.encoder_param = argv[2];
-  config.encoder_bin = argv[3];
-  config.decoder_param = argv[4];
-  config.decoder_bin = argv[5];
-  config.joiner_param = argv[6];
-  config.joiner_bin = argv[7];
-
-  std::string wav_filename = argv[8];
-
-  int32_t num_threads = 4;
-  if (argc == 10) {
+  sherpa_ncnn::ModelConfig model_conf;
+  model_conf.tokens = argv[1];
+  model_conf.encoder_param = argv[2];
+  model_conf.encoder_bin = argv[3];
+  model_conf.decoder_param = argv[4];
+  model_conf.decoder_bin = argv[5];
+  model_conf.joiner_param = argv[6];
+  model_conf.joiner_bin = argv[7];
+  int num_threads = 4;
+  if (argc >= 10 && atoi(argv[9]) > 0) {
     num_threads = atoi(argv[9]);
   }
-  config.encoder_opt.num_threads = num_threads;
-  config.decoder_opt.num_threads = num_threads;
-  config.joiner_opt.num_threads = num_threads;
-
-  float expected_sampling_rate = 16000;
-
-  sherpa_ncnn::SymbolTable sym(config.tokens);
+  model_conf.encoder_opt.num_threads = num_threads;
+  model_conf.decoder_opt.num_threads = num_threads;
+  model_conf.joiner_opt.num_threads = num_threads;
+
+  const float expected_sampling_rate = 16000;
+  sherpa_ncnn::DecoderConfig decoder_conf;
+  if (argc == 11) {
+    std::string method = argv[10];
+    if (method == "greedy_search" ||
+        method == "modified_beam_search") {
+      decoder_conf.method = method;
+    }
+  }
+  knf::FbankOptions fbank_opts;
+  fbank_opts.frame_opts.dither = 0;
+  fbank_opts.frame_opts.snip_edges = false;
+  fbank_opts.frame_opts.samp_freq = expected_sampling_rate;
+  fbank_opts.mel_opts.num_bins = 80;
 
-  std::cout << config.ToString() << "\n";
+  sherpa_ncnn::Recognizer recognizer(decoder_conf, model_conf, fbank_opts);
 
-  auto model = sherpa_ncnn::Model::Create(config);
-  if (!model) {
-    std::cout << "Failed to create a model\n";
-    exit(EXIT_FAILURE);
-  }
+  std::string wav_filename = argv[8];
 
+  std::cout << model_conf.ToString() << "\n";
   bool is_ok = false;
   std::vector<float> samples =
       sherpa_ncnn::ReadWave(wav_filename, expected_sampling_rate, &is_ok);
@@ -88,63 +89,21 @@ https://huggingface.co/csukuangfj/sherpa-ncnn-2022-09-05
     exit(-1);
   }
 
-  float duration = samples.size() / expected_sampling_rate;
-
+  const float duration = samples.size() / expected_sampling_rate;
   std::cout << "wav filename: " << wav_filename << "\n";
   std::cout << "wav duration (s): " << duration << "\n";
 
-  knf::FbankOptions fbank_opts;
-  fbank_opts.frame_opts.dither = 0;
-  fbank_opts.frame_opts.snip_edges = false;
-  fbank_opts.frame_opts.samp_freq = expected_sampling_rate;
-  fbank_opts.mel_opts.num_bins = 80;
-
-  sherpa_ncnn::FeatureExtractor feature_extractor(fbank_opts);
-  feature_extractor.AcceptWaveform(expected_sampling_rate, samples.data(),
-                                   samples.size());
-
+  recognizer.AcceptWaveform(expected_sampling_rate, samples.data(),
+                            samples.size());
   std::vector<float> tail_paddings(
       static_cast<int>(0.3 * expected_sampling_rate));
-  feature_extractor.AcceptWaveform(expected_sampling_rate, tail_paddings.data(),
-                                   tail_paddings.size());
-
-  feature_extractor.InputFinished();
-
-  int32_t segment = model->Segment();
-  int32_t offset = model->Offset();
-
-  int32_t context_size = model->ContextSize();
-  int32_t blank_id = model->BlankId();
-
-  std::vector<int32_t> hyp(context_size, blank_id);
-
-  ncnn::Mat decoder_input(context_size);
-  for (int32_t i = 0; i != context_size; ++i) {
-    static_cast<int32_t *>(decoder_input)[i] = blank_id;
-  }
-
-  ncnn::Mat decoder_out = model->RunDecoder(decoder_input);
-
-  std::vector<ncnn::Mat> states;
-  ncnn::Mat encoder_out;
-
-  int32_t num_processed = 0;
-  while (feature_extractor.NumFramesReady() - num_processed >= segment) {
-    ncnn::Mat features = feature_extractor.GetFrames(num_processed, segment);
-    num_processed += offset;
-
-    std::tie(encoder_out, states) = model->RunEncoder(features, states);
-
-    GreedySearch(model.get(), encoder_out, &decoder_out, &hyp);
-  }
-
-  std::string text;
-  for (int32_t i = context_size; i != hyp.size(); ++i) {
-    text += sym[hyp[i]];
-  }
+  recognizer.AcceptWaveform(expected_sampling_rate, tail_paddings.data(),
+                            tail_paddings.size());
 
+  recognizer.Decode();
+  auto result = recognizer.GetResult();
   std::cout << "Recognition result for " << wav_filename << "\n"
-            << text << "\n";
+            << result.text << "\n";
 
   return 0;
 }

+ 24 - 73
sherpa-ncnn/jni/jni.cc

@@ -1,5 +1,6 @@
 /**
  * Copyright (c)  2022  Xiaomi Corporation (authors: Fangjun Kuang)
+ * Copyright (c)  2022                     (Pingfeng Luo)
  *
  * See LICENSE for clarification regarding multiple authors
  *
@@ -27,10 +28,7 @@
 
 #include "android/asset_manager.h"
 #include "android/asset_manager_jni.h"
-#include "sherpa-ncnn/csrc/decode.h"
-#include "sherpa-ncnn/csrc/features.h"
-#include "sherpa-ncnn/csrc/model.h"
-#include "sherpa-ncnn/csrc/symbol-table.h"
+#include "sherpa-ncnn/csrc/recognizer.h"
 #include "sherpa-ncnn/csrc/wave-reader.h"
 
 #define SHERPA_EXTERN_C extern "C"
@@ -39,81 +37,33 @@ namespace sherpa_ncnn {
 
 class SherpaNcnn {
  public:
-  SherpaNcnn(AAssetManager *mgr, const ModelConfig &model_config,
-             const knf::FbankOptions &fbank_config)
-      : model_(Model::Create(mgr, model_config)),
-        feature_extractor_(std::make_unique<FeatureExtractor>(fbank_config)),
-        sym_(mgr, model_config.tokens) {
-    Reset();
-  }
+  SherpaNcnn(AAssetManager *mgr,
+             const sherpa_ncnn::DecoderConfig &decoder_config,
+             const ModelConfig &model_config,
+             const knf::FbankOptions &fbank_opts)
+      : recognizer_(mgr, decoder_config, model_config, fbank_opts),
+        tail_padding_(16000 * 0.32, 0) {}
 
   void DecodeSamples(float sample_rate, const float *samples, int32_t n) {
-    feature_extractor_->AcceptWaveform(sample_rate, samples, n);
-    Decode();
+    recognizer_.AcceptWaveform(sample_rate, samples, n);
+    recognizer_.Decode();
   }
 
   void InputFinished() {
-    feature_extractor_->InputFinished();
-    Decode();
+    recognizer_.AcceptWaveform(16000, tail_padding_.data(),
+                               tail_padding_.size());
+    recognizer_.InputFinished();
+    recognizer_.Decode();
   }
 
-  std::string GetText() const {
-    int32_t context_size = model_->ContextSize();
-
-    std::string text;
-    for (int32_t i = context_size; i != static_cast<int32_t>(hyp_.size());
-         ++i) {
-      text += sym_[hyp_[i]];
-    }
-    return text;
-  }
-
-  void Reset() {
-    feature_extractor_->Reset();
-    num_processed_ = 0;
-    states_.clear();
-
-    int32_t context_size = model_->ContextSize();
-    int32_t blank_id = 0;
-
-    ncnn::Mat decoder_input(context_size);
-    for (int32_t i = 0; i != context_size; ++i) {
-      static_cast<int32_t *>(decoder_input)[i] = blank_id;
-    }
-
-    decoder_out_ = model_->RunDecoder(decoder_input);
-
-    hyp_.resize(context_size, 0);
+  const std::string GetText() {
+    auto result = recognizer_.GetResult();
+    return result.text;
   }
 
  private:
-  void Decode() {
-    int32_t segment = model_->Segment();
-    int32_t offset = model_->Offset();
-
-    ncnn::Mat encoder_out;
-    while (feature_extractor_->NumFramesReady() - num_processed_ >= segment) {
-      ncnn::Mat features =
-          feature_extractor_->GetFrames(num_processed_, segment);
-      num_processed_ += offset;
-
-      std::tie(encoder_out, states_) = model_->RunEncoder(features, states_);
-
-      GreedySearch(model_.get(), encoder_out, &decoder_out_, &hyp_);
-    }
-  }
-
- private:
-  std::unique_ptr<Model> model_;
-  std::unique_ptr<FeatureExtractor> feature_extractor_;
-  sherpa_ncnn::SymbolTable sym_;
-
-  std::vector<int32_t> hyp_;
-  ncnn::Mat decoder_out_;
-  std::vector<ncnn::Mat> states_;
-
-  // number of processed frames
-  int32_t num_processed_ = 0;
+  sherpa_ncnn::Recognizer recognizer_;
+  std::vector<float> tail_padding_;
 };
 
 static ModelConfig GetModelConfig(JNIEnv *env, jobject config) {
@@ -290,10 +240,13 @@ JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_ncnn_SherpaNcnn_new(
   sherpa_ncnn::ModelConfig model_config =
       sherpa_ncnn::GetModelConfig(env, _model_config);
 
+  sherpa_ncnn::DecoderConfig decoder_config;
+
   knf::FbankOptions fbank_opts =
       sherpa_ncnn::GetFbankOptions(env, _fbank_config);
 
-  auto model = new sherpa_ncnn::SherpaNcnn(mgr, model_config, fbank_opts);
+  auto model = new sherpa_ncnn::SherpaNcnn(mgr, decoder_config, model_config,
+                                           fbank_opts);
 
   return (jlong)model;
 }
@@ -306,9 +259,7 @@ JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_ncnn_SherpaNcnn_delete(
 
 SHERPA_EXTERN_C
 JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_ncnn_SherpaNcnn_reset(
-    JNIEnv *env, jobject /*obj*/, jlong ptr) {
-  reinterpret_cast<sherpa_ncnn::SherpaNcnn *>(ptr)->Reset();
-}
+    JNIEnv *env, jobject /*obj*/, jlong ptr) {}
 
 SHERPA_EXTERN_C
 JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_ncnn_SherpaNcnn_decodeSamples(