Quellcode durchsuchen

Treat <unk> as blank (#315)

Fangjun Kuang vor 1 Jahr
Ursprung
Commit
55159e098c

+ 1 - 1
sherpa-ncnn/csrc/greedy-search-decoder.cc

@@ -71,7 +71,7 @@ void GreedySearchDecoder::Decode(ncnn::Mat encoder_out, DecoderResult *result) {
         std::max_element(joiner_out_ptr, joiner_out_ptr + joiner_out.w)));
 
     // the blank ID is fixed to 0
-    if (new_token != 0) {
+    if (new_token != 0 && new_token != 2) {
       result->tokens.push_back(new_token);
       ncnn::Mat decoder_input = BuildDecoderInput(*result);
       decoder_out = model_->RunDecoder(decoder_input);

+ 2 - 2
sherpa-ncnn/csrc/modified-beam-search-decoder.cc

@@ -169,7 +169,7 @@ void ModifiedBeamSearchDecoder::Decode(ncnn::Mat encoder_out,
       Hypothesis new_hyp = prev[hyp_index];
 
       // blank id is fixed to 0
-      if (new_token != 0) {
+      if (new_token != 0 && new_token != 2) {
         new_hyp.ys.push_back(new_token);
         new_hyp.num_trailing_blanks = 0;
         new_hyp.timestamps.push_back(t + frame_offset);
@@ -247,7 +247,7 @@ void ModifiedBeamSearchDecoder::Decode(ncnn::Mat encoder_out, Stream *s,
       float context_score = 0;
       auto context_state = new_hyp.context_state;
       // blank id is fixed to 0
-      if (new_token != 0) {
+      if (new_token != 0 && new_token != 2) {
         new_hyp.ys.push_back(new_token);
         new_hyp.num_trailing_blanks = 0;
         new_hyp.timestamps.push_back(t + frame_offset);