|
@@ -166,6 +166,16 @@ void ModifiedBeamSearchDecoder::Decode(ncnn::Mat encoder_out,
|
|
|
// joiner_out.w == vocab_size
|
|
|
// joiner_out.h == num_active_paths
|
|
|
LogSoftmax(&joiner_out);
|
|
|
+
|
|
|
+ float *p_joiner_out = joiner_out;
|
|
|
+
|
|
|
+ for (int32_t i = 0; i != joiner_out.h; ++i) {
|
|
|
+ float prev_log_prob = prev[i].log_prob;
|
|
|
+ for (int32_t k = 0; k != joiner_out.w; ++k, ++p_joiner_out) {
|
|
|
+ *p_joiner_out += prev_log_prob;
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
auto topk = TopkIndex(static_cast<float *>(joiner_out),
|
|
|
joiner_out.w * joiner_out.h, num_active_paths_);
|
|
|
|
|
@@ -186,7 +196,9 @@ void ModifiedBeamSearchDecoder::Decode(ncnn::Mat encoder_out,
|
|
|
} else {
|
|
|
++new_hyp.num_trailing_blanks;
|
|
|
}
|
|
|
- new_hyp.log_prob += p[new_token];
|
|
|
+ // We have already added prev[hyp_index].log_prob to p[new_token]
|
|
|
+ new_hyp.log_prob = p[new_token];
|
|
|
+
|
|
|
cur.Add(std::move(new_hyp));
|
|
|
}
|
|
|
}
|