|
@@ -33,7 +33,6 @@ void ModifiedBeamSearchDecoder::AcceptWaveform(const float sample_rate,
|
|
|
}
|
|
|
|
|
|
void ModifiedBeamSearchDecoder::BuildDecoderInput(Hypothesis hyp) {
|
|
|
- decoder_input_.reshape(context_size_);
|
|
|
for (int32_t i = 0; i != context_size_; ++i) {
|
|
|
static_cast<int32_t *>(decoder_input_)[i] =
|
|
|
*(hyp.ys.end() - context_size_ + i);
|
|
@@ -58,11 +57,7 @@ void ModifiedBeamSearchDecoder::Decode() {
|
|
|
std::vector<Hypothesis> prev;
|
|
|
/* encoder_out_.w == encoder_out_dim, encoder_out_.h == num_frames. */
|
|
|
for (int32_t t = 0; t != encoder_out_.h; ++t) {
|
|
|
- prev.clear();
|
|
|
- prev.reserve(config_.num_active_paths);
|
|
|
- for (auto &h : cur) {
|
|
|
- prev.push_back(std::move(h.second));
|
|
|
- }
|
|
|
+ prev = std::move(cur.Vec());
|
|
|
cur.clear();
|
|
|
|
|
|
for (const auto &h : prev) {
|
|
@@ -76,19 +71,20 @@ void ModifiedBeamSearchDecoder::Decode() {
|
|
|
// update active_paths
|
|
|
auto topk =
|
|
|
topk_index(joiner_out_ptr, joiner_out.w, config_.num_active_paths);
|
|
|
- for (int i = 0; i != topk.size(); i++) {
|
|
|
+ for (int i = 0; i != topk.size(); ++i) {
|
|
|
Hypothesis new_hyp = h;
|
|
|
int32_t new_token = topk[i];
|
|
|
if (new_token != blank_id_) {
|
|
|
new_hyp.ys.push_back(new_token);
|
|
|
new_hyp.num_trailing_blanks = 0;
|
|
|
} else {
|
|
|
- new_hyp.num_trailing_blanks += 1;
|
|
|
+ ++new_hyp.num_trailing_blanks;
|
|
|
}
|
|
|
new_hyp.log_prob += joiner_out_ptr[new_token];
|
|
|
cur.Add(std::move(new_hyp));
|
|
|
}
|
|
|
}
|
|
|
+ // prune active_paths
|
|
|
while (cur.Size() > config_.num_active_paths) {
|
|
|
auto least_hyp = cur.GetLeastProbable(true);
|
|
|
cur.Remove(least_hyp);
|