فهرست منبع

reduce code (#62)

* reduce code

* minor fix
PF Luo 2 سال پیش
والد
کامیت
2cd24efa58
3فایلهای تغییر یافته به همراه6 افزوده شده و 42 حذف شده
  1. 0 20
      sherpa-ncnn/csrc/hypothesis.cc
  2. 0 15
      sherpa-ncnn/csrc/hypothesis.h
  3. 6 7
      sherpa-ncnn/csrc/modified-beam-search-decoder.cc

+ 0 - 20
sherpa-ncnn/csrc/hypothesis.cc

@@ -55,24 +55,4 @@ Hypothesis Hypotheses::GetMostProbable(bool length_norm) const {
   }
 }
 
-Hypothesis Hypotheses::GetLeastProbable(bool length_norm) const {
-  if (length_norm == false) {
-    return std::max_element(hyps_dict_.begin(), hyps_dict_.end(),
-                            [](const auto &left, auto &right) -> bool {
-                              return left.second.log_prob >
-                                     right.second.log_prob;
-                            })
-        ->second;
-  } else {
-    // for length_norm is true
-    return std::max_element(
-               hyps_dict_.begin(), hyps_dict_.end(),
-               [](const auto &left, const auto &right) -> bool {
-                 return left.second.log_prob / left.second.ys.size() >
-                        right.second.log_prob / right.second.ys.size();
-               })
-        ->second;
-  }
-}
-
 }  // namespace sherpa_ncnn

+ 0 - 15
sherpa-ncnn/csrc/hypothesis.h

@@ -87,25 +87,10 @@ class Hypotheses {
   // len(hyp.ys) before comparison.
   Hypothesis GetMostProbable(bool length_norm) const;
 
-  // Get the hyp that has the least log_prob.
-  // If length_norm is true, hyp's log_prob are divided by
-  // len(hyp.ys) before comparison.
-  Hypothesis GetLeastProbable(bool length_norm) const;
-
   // Remove the given hyp from this object.
   // It is *NOT* an error if hyp does not exist in this object.
   void Remove(const Hypothesis &hyp) { hyps_dict_.erase(hyp.Key()); }
 
-  // Return a list of hyps contained in this object.
-  std::vector<Hypothesis> Vec() const {
-    std::vector<Hypothesis> ans;
-    ans.reserve(hyps_dict_.size());
-    for (const auto &p : hyps_dict_) {
-      ans.push_back(p.second);
-    }
-    return ans;
-  }
-
   int32_t Size() const { return hyps_dict_.size(); }
 
   std::string ToString() const {

+ 6 - 7
sherpa-ncnn/csrc/modified-beam-search-decoder.cc

@@ -54,10 +54,14 @@ void ModifiedBeamSearchDecoder::Decode() {
         model_->RunEncoder(features, encoder_state_);
 
     Hypotheses cur = std::move(result_.hyps);
-    std::vector<Hypothesis> prev;
     /* encoder_out_.w == encoder_out_dim, encoder_out_.h == num_frames. */
     for (int32_t t = 0; t != encoder_out_.h; ++t) {
-      prev = std::move(cur.Vec());
+      std::vector<Hypothesis> prev;
+      for (int32_t i = 0; i != config_.num_active_paths && cur.Size(); ++i) {
+        auto cur_best_hyp = cur.GetMostProbable(true);
+        cur.Remove(cur_best_hyp);
+        prev.push_back(std::move(cur_best_hyp));
+      }
       cur.clear();
 
       for (const auto &h : prev) {
@@ -84,11 +88,6 @@ void ModifiedBeamSearchDecoder::Decode() {
           cur.Add(std::move(new_hyp));
         }
       }
-      // prune active_paths
-      while (cur.Size() > config_.num_active_paths) {
-        auto least_hyp = cur.GetLeastProbable(true);
-        cur.Remove(least_hyp);
-      }
     }
 
     num_processed_ += offset_;