Procházet zdrojové kódy

Fix modified beam search (#186)

Fangjun Kuang před 2 roky
rodič
revize
c4dcf09141
1 změnil soubory, kde provedl 13 přidání a 1 odebrání
  1. 13 1
      sherpa-ncnn/csrc/modified-beam-search-decoder.cc

+ 13 - 1
sherpa-ncnn/csrc/modified-beam-search-decoder.cc

@@ -166,6 +166,16 @@ void ModifiedBeamSearchDecoder::Decode(ncnn::Mat encoder_out,
     // joiner_out.w == vocab_size
     // joiner_out.h == num_active_paths
     LogSoftmax(&joiner_out);
+
+    float *p_joiner_out = joiner_out;
+
+    for (int32_t i = 0; i != joiner_out.h; ++i) {
+      float prev_log_prob = prev[i].log_prob;
+      for (int32_t k = 0; k != joiner_out.w; ++k, ++p_joiner_out) {
+        *p_joiner_out += prev_log_prob;
+      }
+    }
+
     auto topk = TopkIndex(static_cast<float *>(joiner_out),
                           joiner_out.w * joiner_out.h, num_active_paths_);
 
@@ -186,7 +196,9 @@ void ModifiedBeamSearchDecoder::Decode(ncnn::Mat encoder_out,
       } else {
         ++new_hyp.num_trailing_blanks;
       }
-      new_hyp.log_prob += p[new_token];
+      // We have already added prev[hyp_index].log_prob to p[new_token]
+      new_hyp.log_prob = p[new_token];
+
       cur.Add(std::move(new_hyp));
     }
   }