// hypothesis.h
  1. /**
  2. * Copyright 2022 Xiaomi Corporation (authors: Fangjun Kuang)
  3. *
  4. * See LICENSE for clarification regarding multiple authors
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #ifndef SHERPA_NCNN_CSRC_HYPOTHESIS_H_
  19. #define SHERPA_NCNN_CSRC_HYPOTHESIS_H_
  20. #include <sstream>
  21. #include <string>
  22. #include <unordered_map>
  23. #include <utility>
  24. #include <vector>
  25. #include "sherpa-ncnn/csrc/context-graph.h"
  26. namespace sherpa_ncnn {
  27. struct Hypothesis {
  28. // The predicted tokens so far. Newly predicated tokens are appended.
  29. std::vector<int32_t> ys;
  30. // timestamps[i] contains the frame number after subsampling
  31. // on which ys[i] is decoded.
  32. std::vector<int32_t> timestamps;
  33. // The total score of ys in log space.
  34. double log_prob = 0;
  35. const ContextState *context_state;
  36. int32_t num_trailing_blanks = 0;
  37. Hypothesis() = default;
  38. Hypothesis(const std::vector<int32_t> &ys, double log_prob,
  39. const ContextState *context_state = nullptr)
  40. : ys(ys), log_prob(log_prob), context_state(context_state) {}
  41. // If two Hypotheses have the same `Key`, then they contain
  42. // the same token sequence.
  43. std::string Key() const {
  44. // TODO(fangjun): Use a hash function?
  45. std::ostringstream os;
  46. std::string sep = "-";
  47. for (auto i : ys) {
  48. os << i << sep;
  49. sep = "-";
  50. }
  51. return os.str();
  52. }
  53. // For debugging
  54. std::string ToString() const {
  55. std::ostringstream os;
  56. os << "(" << Key() << ", " << log_prob << ")";
  57. return os.str();
  58. }
  59. };
  60. class Hypotheses {
  61. public:
  62. Hypotheses() = default;
  63. explicit Hypotheses(std::vector<Hypothesis> hyps) {
  64. for (auto &h : hyps) {
  65. hyps_dict_[h.Key()] = std::move(h);
  66. }
  67. }
  68. explicit Hypotheses(std::unordered_map<std::string, Hypothesis> hyps_dict)
  69. : hyps_dict_(std::move(hyps_dict)) {}
  70. // Add hyp to this object. If it already exists, its log_prob
  71. // is updated with the given hyp using log-sum-exp.
  72. void Add(Hypothesis hyp);
  73. // Get the hyp that has the largest log_prob.
  74. // If length_norm is true, hyp's log_prob is divided by
  75. // len(hyp.ys) before comparison.
  76. Hypothesis GetMostProbable(bool length_norm) const;
  77. // Get the k hyps that have the largest log_prob.
  78. // If length_norm is true, hyp's log_prob is divided by
  79. // len(hyp.ys) before comparison.
  80. std::vector<Hypothesis> GetTopK(int32_t k, bool length_norm) const;
  81. int32_t Size() const { return hyps_dict_.size(); }
  82. std::string ToString() const {
  83. std::ostringstream os;
  84. for (const auto &p : hyps_dict_) {
  85. os << p.second.ToString() << "\n";
  86. }
  87. return os.str();
  88. }
  89. const auto begin() const { return hyps_dict_.begin(); }
  90. const auto end() const { return hyps_dict_.end(); }
  91. auto begin() { return hyps_dict_.begin(); }
  92. auto end() { return hyps_dict_.end(); }
  93. void Clear() { hyps_dict_.clear(); }
  94. private:
  95. // Return a list of hyps contained in this object.
  96. std::vector<Hypothesis> Vec() const {
  97. std::vector<Hypothesis> ans;
  98. ans.reserve(hyps_dict_.size());
  99. for (const auto &p : hyps_dict_) {
  100. ans.push_back(p.second);
  101. }
  102. return ans;
  103. }
  104. private:
  105. using Map = std ::unordered_map<std::string, Hypothesis>;
  106. Map hyps_dict_;
  107. };
  108. } // namespace sherpa_ncnn
  109. #endif // SHERPA_NCNN_CSRC_HYPOTHESIS_H_