ソースを参照

Add cumstomized score for hotwords & add Finalize to stream (#281)

Wei Kang 1 年間 前
コミット
3c7724c137

+ 16 - 1
.github/scripts/run-test.sh

@@ -52,6 +52,21 @@ for wave in ${waves[@]}; do
   done
 done
 
+log "Start testing ${repo_url} with hotwords"
+
+time $EXE \
+  $repo/tokens.txt \
+  $repo/encoder_jit_trace-pnnx.ncnn.param \
+  $repo/encoder_jit_trace-pnnx.ncnn.bin \
+  $repo/decoder_jit_trace-pnnx.ncnn.param \
+  $repo/decoder_jit_trace-pnnx.ncnn.bin \
+  $repo/joiner_jit_trace-pnnx.ncnn.param \
+  $repo/joiner_jit_trace-pnnx.ncnn.bin \
+  $repo/test_wavs/1.wav \
+  2 \
+  modified_beam_search \
+  $repo/test_wavs/hotwords.txt
+
 rm -rf $repo
 
 log "------------------------------------------------------------"
@@ -588,4 +603,4 @@ time $EXE \
   modified_beam_search \
   $repo/hotwords.txt 1.6
 
-rm -rf $repo
+rm -rf $repo

+ 2 - 0
sherpa-ncnn/csrc/CMakeLists.txt

@@ -77,4 +77,6 @@ endif()
 if(SHERPA_NCNN_ENABLE_TEST)
   add_executable(test-resample test-resample.cc)
   target_link_libraries(test-resample sherpa-ncnn-core)
+  add_executable(test-context-graph test-context-graph.cc)
+  target_link_libraries(test-context-graph sherpa-ncnn-core)
 endif()

+ 74 - 7
sherpa-ncnn/csrc/context-graph.cc

@@ -4,22 +4,57 @@
 
 #include "sherpa-ncnn/csrc/context-graph.h"
 
+#include <algorithm>
 #include <cassert>
 #include <queue>
+#include <string>
+#include <tuple>
 #include <utility>
 
 namespace sherpa_ncnn {
-void ContextGraph::Build(
-    const std::vector<std::vector<int32_t>> &token_ids) const {
+void ContextGraph::Build(const std::vector<std::vector<int32_t>> &token_ids,
+                         const std::vector<float> &scores,
+                         const std::vector<std::string> &phrases,
+                         const std::vector<float> &ac_thresholds) const {
+  if (!scores.empty()) {
+    assert(token_ids.size() == scores.size());
+  }
+  if (!phrases.empty()) {
+    assert(token_ids.size() == phrases.size());
+  }
+  if (!ac_thresholds.empty()) {
+    assert(token_ids.size() == ac_thresholds.size());
+  }
   for (int32_t i = 0; i < token_ids.size(); ++i) {
     auto node = root_.get();
+    float score = scores.empty() ? 0.0f : scores[i];
+    score = score == 0.0f ? context_score_ : score;
+    float ac_threshold = ac_thresholds.empty() ? 0.0f : ac_thresholds[i];
+    ac_threshold = ac_threshold == 0.0f ? ac_threshold_ : ac_threshold;
+    std::string phrase = phrases.empty() ? std::string() : phrases[i];
+
     for (int32_t j = 0; j < token_ids[i].size(); ++j) {
       int32_t token = token_ids[i][j];
       if (0 == node->next.count(token)) {
         bool is_end = j == token_ids[i].size() - 1;
         node->next[token] = std::make_unique<ContextState>(
-            token, context_score_, node->node_score + context_score_,
-            is_end ? node->node_score + context_score_ : 0, is_end);
+            token, score, node->node_score + score,
+            is_end ? node->node_score + score : 0, j + 1,
+            is_end ? ac_threshold : 0.0f, is_end,
+            is_end ? phrase : std::string());
+      } else {
+        float token_score = std::max(score, node->next[token]->token_score);
+        node->next[token]->token_score = token_score;
+        float node_score = node->node_score + token_score;
+        node->next[token]->node_score = node_score;
+        bool is_end =
+            (j == token_ids[i].size() - 1) || node->next[token]->is_end;
+        node->next[token]->output_score = is_end ? node_score : 0.0f;
+        node->next[token]->is_end = is_end;
+        if (j == token_ids[i].size() - 1) {
+          node->next[token]->phrase = phrase;
+          node->next[token]->ac_threshold = ac_threshold;
+        }
       }
       node = node->next[token].get();
     }
@@ -27,8 +62,9 @@ void ContextGraph::Build(
   FillFailOutput();
 }
 
-std::pair<float, const ContextState *> ContextGraph::ForwardOneStep(
-    const ContextState *state, int32_t token) const {
+std::tuple<float, const ContextState *, const ContextState *>
+ContextGraph::ForwardOneStep(const ContextState *state, int32_t token,
+                             bool strict_mode /*= true*/) const {
   const ContextState *node;
   float score;
   if (1 == state->next.count(token)) {
@@ -45,7 +81,22 @@ std::pair<float, const ContextState *> ContextGraph::ForwardOneStep(
     }
     score = node->node_score - state->node_score;
   }
-  return std::make_pair(score + node->output_score, node);
+
+  assert(nullptr != node);
+
+  const ContextState *matched_node =
+      node->is_end ? node : (node->output != nullptr ? node->output : nullptr);
+
+  if (!strict_mode && node->output_score != 0) {
+    assert(nullptr != matched_node);
+    float output_score =
+        node->is_end ? node->node_score
+                     : (node->output != nullptr ? node->output->node_score
+                                                : node->node_score);
+    return std::make_tuple(score + output_score - node->node_score, root_.get(),
+                           matched_node);
+  }
+  return std::make_tuple(score + node->output_score, node, matched_node);
 }
 
 std::pair<float, const ContextState *> ContextGraph::Finalize(
@@ -54,6 +105,22 @@ std::pair<float, const ContextState *> ContextGraph::Finalize(
   return std::make_pair(score, root_.get());
 }
 
+std::pair<bool, const ContextState *> ContextGraph::IsMatched(
+    const ContextState *state) const {
+  bool status = false;
+  const ContextState *node = nullptr;
+  if (state->is_end) {
+    status = true;
+    node = state;
+  } else {
+    if (state->output != nullptr) {
+      status = true;
+      node = state->output;
+    }
+  }
+  return std::make_pair(status, node);
+}
+
 void ContextGraph::FillFailOutput() const {
   std::queue<const ContextState *> node_queue;
   for (auto &kv : root_->next) {

+ 36 - 10
sherpa-ncnn/csrc/context-graph.h

@@ -6,11 +6,12 @@
 #define SHERPA_NCNN_CSRC_CONTEXT_GRAPH_H_
 
 #include <memory>
+#include <string>
+#include <tuple>
 #include <unordered_map>
 #include <utility>
 #include <vector>
 
-
 namespace sherpa_ncnn {
 
 class ContextGraph;
@@ -21,34 +22,55 @@ struct ContextState {
   float token_score;
   float node_score;
   float output_score;
+  int32_t level;
+  float ac_threshold;
   bool is_end;
+  std::string phrase;
   std::unordered_map<int32_t, std::unique_ptr<ContextState>> next;
   const ContextState *fail = nullptr;
   const ContextState *output = nullptr;
 
   ContextState() = default;
   ContextState(int32_t token, float token_score, float node_score,
-               float output_score, bool is_end)
+               float output_score, int32_t level = 0, float ac_threshold = 0.0f,
+               bool is_end = false, const std::string &phrase = {})
       : token(token),
         token_score(token_score),
         node_score(node_score),
         output_score(output_score),
-        is_end(is_end) {}
+        level(level),
+        ac_threshold(ac_threshold),
+        is_end(is_end),
+        phrase(phrase) {}
 };
 
 class ContextGraph {
  public:
   ContextGraph() = default;
   ContextGraph(const std::vector<std::vector<int32_t>> &token_ids,
-               float hotwords_score)
-      : context_score_(hotwords_score) {
-    root_ = std::make_unique<ContextState>(-1, 0, 0, 0, false);
+               float context_score, float ac_threshold,
+               const std::vector<float> &scores = {},
+               const std::vector<std::string> &phrases = {},
+               const std::vector<float> &ac_thresholds = {})
+      : context_score_(context_score), ac_threshold_(ac_threshold) {
+    root_ = std::make_unique<ContextState>(-1, 0, 0, 0);
     root_->fail = root_.get();
-    Build(token_ids);
+    Build(token_ids, scores, phrases, ac_thresholds);
   }
 
-  std::pair<float, const ContextState *> ForwardOneStep(
-      const ContextState *state, int32_t token_id) const;
+  ContextGraph(const std::vector<std::vector<int32_t>> &token_ids,
+               float context_score, const std::vector<float> &scores = {},
+               const std::vector<std::string> &phrases = {})
+      : ContextGraph(token_ids, context_score, 0.0f, scores, phrases,
+                     std::vector<float>()) {}
+
+  std::tuple<float, const ContextState *, const ContextState *> ForwardOneStep(
+      const ContextState *state, int32_t token_id,
+      bool strict_mode = true) const;
+
+  std::pair<bool, const ContextState *> IsMatched(
+      const ContextState *state) const;
+
   std::pair<float, const ContextState *> Finalize(
       const ContextState *state) const;
 
@@ -56,8 +78,12 @@ class ContextGraph {
 
  private:
   float context_score_;
+  float ac_threshold_;
   std::unique_ptr<ContextState> root_;
-  void Build(const std::vector<std::vector<int32_t>> &token_ids) const;
+  void Build(const std::vector<std::vector<int32_t>> &token_ids,
+             const std::vector<float> &scores,
+             const std::vector<std::string> &phrases,
+             const std::vector<float> &ac_thresholds) const;
   void FillFailOutput() const;
 };
 

+ 5 - 80
sherpa-ncnn/csrc/modified-beam-search-decoder.cc

@@ -117,82 +117,7 @@ ncnn::Mat ModifiedBeamSearchDecoder::BuildDecoderInput(
 
 void ModifiedBeamSearchDecoder::Decode(ncnn::Mat encoder_out,
                                        DecoderResult *result) {
-  int32_t context_size = model_->ContextSize();
-  Hypotheses cur = std::move(result->hyps);
-  /* encoder_out.w == encoder_out_dim, encoder_out.h == num_frames. */
-  for (int32_t t = 0; t != encoder_out.h; ++t) {
-    std::vector<Hypothesis> prev = cur.GetTopK(num_active_paths_, true);
-    cur.Clear();
-
-    ncnn::Mat decoder_input = BuildDecoderInput(prev);
-    ncnn::Mat decoder_out;
-    if (t == 0 && prev.size() == 1 && prev[0].ys.size() == context_size &&
-        !result->decoder_out.empty()) {
-      // When an endpoint is detected, we keep the decoder_out
-      decoder_out = result->decoder_out;
-    } else {
-      decoder_out = RunDecoder2D(model_, decoder_input);
-    }
-
-    // decoder_out.w == decoder_dim
-    // decoder_out.h == num_active_paths
-    ncnn::Mat encoder_out_t(encoder_out.w, 1, encoder_out.row(t));
-    // Note: encoder_out_t.h == 1, we rely on the binary op broadcasting
-    // in ncnn
-    // See https://github.com/Tencent/ncnn/wiki/binaryop-broadcasting
-    // broadcast B for outer axis, type 14
-    ncnn::Mat joiner_out = model_->RunJoiner(encoder_out_t, decoder_out);
-
-    // joiner_out.w == vocab_size
-    // joiner_out.h == num_active_paths
-    LogSoftmax(&joiner_out);
-
-    float *p_joiner_out = joiner_out;
-
-    for (int32_t i = 0; i != joiner_out.h; ++i) {
-      float prev_log_prob = prev[i].log_prob;
-      for (int32_t k = 0; k != joiner_out.w; ++k, ++p_joiner_out) {
-        *p_joiner_out += prev_log_prob;
-      }
-    }
-
-    auto topk = TopkIndex(static_cast<float *>(joiner_out),
-                          joiner_out.w * joiner_out.h, num_active_paths_);
-
-    int32_t frame_offset = result->frame_offset;
-    for (auto i : topk) {
-      int32_t hyp_index = i / joiner_out.w;
-      int32_t new_token = i % joiner_out.w;
-
-      const float *p = joiner_out.row(hyp_index);
-
-      Hypothesis new_hyp = prev[hyp_index];
-
-      // blank id is fixed to 0
-      if (new_token != 0 && new_token != 2) {
-        new_hyp.ys.push_back(new_token);
-        new_hyp.num_trailing_blanks = 0;
-        new_hyp.timestamps.push_back(t + frame_offset);
-      } else {
-        ++new_hyp.num_trailing_blanks;
-      }
-      // We have already added prev[hyp_index].log_prob to p[new_token]
-      new_hyp.log_prob = p[new_token];
-
-      cur.Add(std::move(new_hyp));
-    }
-  }
-
-  result->hyps = std::move(cur);
-  result->frame_offset += encoder_out.h;
-  auto hyp = result->hyps.GetMostProbable(true);
-
-  // set decoder_out in case of endpointing
-  ncnn::Mat decoder_input = BuildDecoderInput({hyp});
-  result->decoder_out = model_->RunDecoder(decoder_input);
-
-  result->tokens = std::move(hyp.ys);
-  result->num_trailing_blanks = hyp.num_trailing_blanks;
+  Decode(encoder_out, nullptr, result);
 }
 
 void ModifiedBeamSearchDecoder::Decode(ncnn::Mat encoder_out, Stream *s,
@@ -252,10 +177,10 @@ void ModifiedBeamSearchDecoder::Decode(ncnn::Mat encoder_out, Stream *s,
         new_hyp.num_trailing_blanks = 0;
         new_hyp.timestamps.push_back(t + frame_offset);
         if (s && s->GetContextGraph()) {
-          auto context_res =
-              s->GetContextGraph()->ForwardOneStep(context_state, new_token);
-          context_score = context_res.first;
-          new_hyp.context_state = context_res.second;
+          auto context_res = s->GetContextGraph()->ForwardOneStep(
+              context_state, new_token, false /*strict_mode*/);
+          context_score = std::get<0>(context_res);
+          new_hyp.context_state = std::get<1>(context_res);
         }
       } else {
         ++new_hyp.num_trailing_blanks;

+ 25 - 7
sherpa-ncnn/csrc/recognizer.cc

@@ -25,6 +25,7 @@
 #include <utility>
 #include <vector>
 
+#include "sherpa-ncnn/csrc/context-graph.h"
 #include "sherpa-ncnn/csrc/decoder.h"
 #include "sherpa-ncnn/csrc/greedy-search-decoder.h"
 #include "sherpa-ncnn/csrc/modified-beam-search-decoder.h"
@@ -225,7 +226,11 @@ class Recognizer::Impl {
   }
 
   RecognitionResult GetResult(Stream *s) const {
+    if (IsEndpoint(s)) {
+      s->Finalize();
+    }
     DecoderResult decoder_result = s->GetResult();
+
     decoder_->StripLeadingBlanks(&decoder_result);
 
     // Those 2 parameters are figured out from sherpa source code
@@ -272,23 +277,35 @@ class Recognizer::Impl {
     std::vector<int32_t> tmp;
     std::string line;
     std::string word;
-
+    // The format of each line in hotwords_file looks like:
+    // ▁HE LL O ▁WORLD :1.5
+    // the first several items are tokens of the hotword, the item starts with
+    // ":" is the customize boosting score for this hotword, if there is no
+    // customize score it will use the score from configuration (i.e.
+    // config_.hotwords_score).
     while (std::getline(is, line)) {
       std::istringstream iss(line);
+      float tmp_score = 0.0;  // MUST be 0.0, meaning if no customize score use
+                              // the global one.
       while (iss >> word) {
         if (sym_.contains(word)) {
           int32_t number = sym_[word];
           tmp.push_back(number);
         } else {
-          NCNN_LOGE(
-              "Cannot find ID for hotword %s at line: %s. (Hint: words on the "
-              "same line are separated by spaces)",
-              word.c_str(), line.c_str());
-          exit(-1);
+          if (word[0] == ':') {
+            tmp_score = std::stof(word.substr(1));
+          } else {
+            NCNN_LOGE(
+                "Cannot find ID for hotword %s at line: %s. (Hint: words on "
+                "the "
+                "same line are separated by spaces)",
+                word.c_str(), line.c_str());
+            exit(-1);
+          }
         }
       }
-
       hotwords_.push_back(std::move(tmp));
+      boost_scores_.push_back(tmp_score);
     }
   }
 
@@ -299,6 +316,7 @@ class Recognizer::Impl {
   Endpoint endpoint_;
   SymbolTable sym_;
   std::vector<std::vector<int32_t>> hotwords_;
+  std::vector<float> boost_scores_;
 };
 
 Recognizer::Recognizer(const RecognizerConfig &config)

+ 13 - 1
sherpa-ncnn/csrc/sherpa-ncnn-alsa.cc

@@ -47,7 +47,7 @@ Usage:
     /path/to/joiner.ncnn.param \
     /path/to/joiner.ncnn.bin \
     device_name \
-    [num_threads] [decode_method, can be greedy_search/modified_beam_search]
+    [num_threads] [decode_method, can be greedy_search/modified_beam_search] [hotwords_file] [hotwords_score]
 
 Please refer to
 https://k2-fsa.github.io/sherpa/ncnn/pretrained_models/index.html
@@ -108,6 +108,14 @@ as the device_name.
     }
   }
 
+  if (argc >= 11) {
+    config.hotwords_file = argv[10];
+  }
+
+  if (argc == 12) {
+    config.hotwords_score = atof(argv[11]);
+  }
+
   int32_t expected_sampling_rate = 16000;
 
   config.enable_endpoint = true;
@@ -148,6 +156,10 @@ as the device_name.
     }
 
     bool is_endpoint = recognizer.IsEndpoint(s.get());
+
+    if (is_endpoint) {
+      s->Finalize();
+    }
     auto text = recognizer.GetResult(s.get()).text;
 
     if (!text.empty() && last_text != text) {

+ 13 - 1
sherpa-ncnn/csrc/sherpa-ncnn-microphone.cc

@@ -60,7 +60,7 @@ Usage:
     /path/to/decoder.ncnn.bin \
     /path/to/joiner.ncnn.param \
     /path/to/joiner.ncnn.bin \
-    [num_threads] [decode_method, can be greedy_search/modified_beam_search]
+    [num_threads] [decode_method, can be greedy_search/modified_beam_search] [hotwords_file] [hotwords_score]
 
 Please refer to
 https://k2-fsa.github.io/sherpa/ncnn/pretrained_models/index.html
@@ -97,6 +97,14 @@ for a list of pre-trained models to download.
     }
   }
 
+  if (argc >= 11) {
+    config.hotwords_file = argv[10];
+  }
+
+  if (argc == 12) {
+    config.hotwords_score = atof(argv[11]);
+  }
+
   config.enable_endpoint = true;
 
   config.endpoint_config.rule1.min_trailing_silence = 2.4;
@@ -166,6 +174,10 @@ for a list of pre-trained models to download.
     }
 
     bool is_endpoint = recognizer.IsEndpoint(s.get());
+
+    if (is_endpoint) {
+      s->Finalize();
+    }
     auto text = recognizer.GetResult(s.get()).text;
 
     if (!text.empty() && last_text != text) {

+ 2 - 2
sherpa-ncnn/csrc/sherpa-ncnn.cc

@@ -40,7 +40,7 @@ Usage:
     /path/to/decoder.ncnn.bin \
     /path/to/joiner.ncnn.param \
     /path/to/joiner.ncnn.bin \
-    /path/to/foo.wav [num_threads] [decode_method, can be greedy_search/modified_beam_search]
+    /path/to/foo.wav [num_threads] [decode_method, can be greedy_search/modified_beam_search] [hotwords_file] [hotwords_score]
 
 Please refer to
 https://k2-fsa.github.io/sherpa/ncnn/pretrained_models/index.html
@@ -112,10 +112,10 @@ for a list of pre-trained models to download.
       static_cast<int>(0.3 * expected_sampling_rate));
   stream->AcceptWaveform(expected_sampling_rate, tail_paddings.data(),
                          tail_paddings.size());
-
   while (recognizer.IsReady(stream.get())) {
     recognizer.DecodeStream(stream.get());
   }
+  stream->Finalize();
   auto result = recognizer.GetResult(stream.get());
   std::cout << "Done!\n";
 

+ 16 - 0
sherpa-ncnn/csrc/stream.cc

@@ -18,6 +18,8 @@
 
 #include "sherpa-ncnn/csrc/stream.h"
 
+#include <iostream>
+
 namespace sherpa_ncnn {
 
 class Stream::Impl {
@@ -49,6 +51,18 @@ class Stream::Impl {
     num_processed_frames_ = 0;
   }
 
+  void Finalize() {
+    if (!context_graph_) return;
+    auto &cur = result_.hyps;
+    for (auto iter = cur.begin(); iter != cur.end(); ++iter) {
+      auto context_res = context_graph_->Finalize(iter->second.context_state);
+      iter->second.log_prob += context_res.first;
+      iter->second.context_state = context_res.second;
+    }
+    auto hyp = result_.hyps.GetMostProbable(true);
+    result_.tokens = std::move(hyp.ys);
+  }
+
   int32_t &GetNumProcessedFrames() { return num_processed_frames_; }
 
   void SetResult(const DecoderResult &r) {
@@ -99,6 +113,8 @@ ncnn::Mat Stream::GetFrames(int32_t frame_index, int32_t n) const {
 
 void Stream::Reset() { impl_->Reset(); }
 
+void Stream::Finalize() { impl_->Finalize(); }
+
 int32_t &Stream::GetNumProcessedFrames() {
   return impl_->GetNumProcessedFrames();
 }

+ 9 - 0
sherpa-ncnn/csrc/stream.h

@@ -70,6 +70,15 @@ class Stream {
 
   void Reset();
 
+  /**
+   * Finalize the decoding result. This is mainly for decoding with hotwords
+   * (i.e. providing context_graph). It will cancel the boosting score of the
+   * partial matching paths. For example, the hotword is "BCD", the path "ABC"
+   * gets boosting score of "BC" but it fails to match the whole hotword "BCD",
+   * so we have to cancel the scores of "BC" at the end.
+   */
+  void Finalize();
+
   // Return a reference to the number of processed frames so far
   // before subsampling..
   // Initially, it is 0. It is always less than NumFramesReady().

+ 105 - 0
sherpa-ncnn/csrc/test-context-graph.cc

@@ -0,0 +1,105 @@
+// sherpa-ncnn/csrc/test-context-graph.cc
+//
+// Copyright (c)  2023-2024  Xiaomi Corporation
+
+#include <cassert>
+#include <chrono>  // NOLINT
+#include <cmath>
+#include <map>
+#include <random>
+#include <string>
+#include <vector>
+
+#include "sherpa-ncnn/csrc/context-graph.h"
+
+static void TestHelper(const std::map<std::string, float> &queries, float score,
+                       bool strict_mode) {
+  std::vector<std::string> contexts_str(
+      {"S", "HE", "SHE", "SHELL", "HIS", "HERS", "HELLO", "THIS", "THEM"});
+  std::vector<std::vector<int32_t>> contexts;
+  std::vector<float> scores;
+  for (int32_t i = 0; i < contexts_str.size(); ++i) {
+    contexts.emplace_back(contexts_str[i].begin(), contexts_str[i].end());
+    scores.push_back(std::round(score / contexts_str[i].size() * 100) / 100);
+  }
+  auto context_graph = sherpa_ncnn::ContextGraph(contexts, 1, scores);
+
+  for (const auto &iter : queries) {
+    float total_scores = 0;
+    auto state = context_graph.Root();
+    for (auto q : iter.first) {
+      auto res = context_graph.ForwardOneStep(state, q, strict_mode);
+      total_scores += std::get<0>(res);
+      state = std::get<1>(res);
+    }
+    auto res = context_graph.Finalize(state);
+    assert(res.second->token == -1);
+    total_scores += res.first;
+    assert(total_scores == iter.second);
+  }
+}
+
+static void TestBasic() {
+  auto queries = std::map<std::string, float>{
+      {"HEHERSHE", 14}, {"HERSHE", 12}, {"HISHE", 9},
+      {"SHED", 6},      {"SHELF", 6},   {"HELL", 2},
+      {"HELLO", 7},     {"DHRHISQ", 4}, {"THEN", 2}};
+  TestHelper(queries, 0, true);
+}
+
+static void TestBasicNonStrict() {
+  auto queries = std::map<std::string, float>{
+      {"HEHERSHE", 7}, {"HERSHE", 5}, {"HISHE", 5},   {"SHED", 3}, {"SHELF", 3},
+      {"HELL", 2},     {"HELLO", 2},  {"DHRHISQ", 3}, {"THEN", 2}};
+  TestHelper(queries, 0, false);
+}
+
+static void TestCustomize() {
+  auto queries = std::map<std::string, float>{
+      {"HEHERSHE", 35.84}, {"HERSHE", 30.84},  {"HISHE", 24.18},
+      {"SHED", 18.34},     {"SHELF", 18.34},   {"HELL", 5},
+      {"HELLO", 13},       {"DHRHISQ", 10.84}, {"THEN", 5}};
+  TestHelper(queries, 5, true);
+}
+
+static void TestCustomizeNonStrict() {
+  auto queries = std::map<std::string, float>{
+      {"HEHERSHE", 20}, {"HERSHE", 15},    {"HISHE", 10.84},
+      {"SHED", 10},     {"SHELF", 10},     {"HELL", 5},
+      {"HELLO", 5},     {"DHRHISQ", 5.84}, {"THEN", 5}};
+  TestHelper(queries, 5, false);
+}
+
+static void Benchmark() {
+  std::random_device rd;
+  std::mt19937 mt(rd());
+  std::uniform_int_distribution<int32_t> char_dist(0, 25);
+  std::uniform_int_distribution<int32_t> len_dist(3, 8);
+  for (int32_t num = 10; num <= 10000; num *= 10) {
+    std::vector<std::vector<int32_t>> contexts;
+    for (int32_t i = 0; i < num; ++i) {
+      std::vector<int32_t> tmp;
+      int32_t word_len = len_dist(mt);
+      for (int32_t j = 0; j < word_len; ++j) {
+        tmp.push_back(char_dist(mt));
+      }
+      contexts.push_back(std::move(tmp));
+    }
+    auto start = std::chrono::high_resolution_clock::now();
+    auto context_graph = sherpa_ncnn::ContextGraph(contexts, 1);
+    auto stop = std::chrono::high_resolution_clock::now();
+    auto duration =
+        std::chrono::duration_cast<std::chrono::microseconds>(stop - start);
+    fprintf(stderr, "Construct context graph for %d item takes %d us.\n", num,
+            static_cast<int32_t>(duration.count()));
+  }
+}
+
+int32_t main() {
+  TestBasic();
+  TestBasicNonStrict();
+  TestCustomize();
+  TestCustomizeNonStrict();
+  Benchmark();
+  return 0;
+}