Browse Source

Added Timestamp (#165)

Leo Huang 2 năm trước cách đây
mục cha
commit
3602f15556

+ 24 - 0
sherpa-ncnn/c-api/c-api.cc

@@ -117,17 +117,41 @@ void Decode(SherpaNcnnRecognizer *p, SherpaNcnnStream *s) {
 
 // Builds a heap-allocated SherpaNcnnResult for the given stream.
 //
 // The returned object owns three buffers:
 //   - text:       the recognized text, NUL-terminated
 //   - tokens:     `count` string tokens stored back to back, each one
 //                 terminated by '\0' (nullptr when count is 0)
 //   - timestamps: `count` floats, one per token (nullptr when count is 0)
 //
 // The caller must release the result with DestroyResult().
 SherpaNcnnResult *GetResult(SherpaNcnnRecognizer *p, SherpaNcnnStream *s) {
   // Query the recognizer once and reuse the result for every field.
   const auto res = p->recognizer->GetResult(s->stream.get());
   const std::string &text = res.text;
 
   auto r = new SherpaNcnnResult;
 
   char *text_buf = new char[text.size() + 1];
   std::copy(text.begin(), text.end(), text_buf);
   text_buf[text.size()] = '\0';
   r->text = text_buf;
 
   r->tokens = nullptr;
   r->timestamps = nullptr;
 
   // Only expose entries that have both a string token and a timestamp, so
   // that neither vector is ever indexed past its end.
   const size_t count = std::min(res.stokens.size(), res.timestamps.size());
   r->count = static_cast<int32_t>(count);
   if (count == 0) {
     return r;
   }
 
   // Size the token buffer from the tokens themselves (plus one terminator
   // per token) instead of assuming it matches the length of `text`.
   size_t tokens_bytes = 0;
   for (size_t i = 0; i != count; ++i) {
     tokens_bytes += res.stokens[i].size() + 1;
   }
 
   char *tokens_buf = new char[tokens_bytes];
   float *timestamps_buf = new float[count];
 
   char *dst = tokens_buf;
   for (size_t i = 0; i != count; ++i) {
     const std::string &token = res.stokens[i];
     dst = std::copy(token.begin(), token.end(), dst);
     *dst++ = '\0';  // each token ends with a NUL byte
     timestamps_buf[i] = res.timestamps[i];
   }
 
   r->tokens = tokens_buf;
   r->timestamps = timestamps_buf;
 
   return r;
 }
 
 void DestroyResult(const SherpaNcnnResult *r) {
   delete[] r->text;
+  if (r->timestamps != nullptr)
+      delete[] r->timestamps;
+  if (r->tokens != nullptr)
+      delete[] r->tokens;
   delete r;
 }
 

+ 12 - 1
sherpa-ncnn/c-api/c-api.h

@@ -113,8 +113,19 @@ typedef struct SherpaNcnnRecognizerConfig {
 } SherpaNcnnRecognizerConfig;
 
 // Recognition result returned by GetResult(). Free it with DestroyResult().
 typedef struct SherpaNcnnResult {
+  // Recognized text, NUL-terminated
   const char *text;
-  // TODO: Add more fields
+
+  // Pointer to contiguous memory which holds the string-based tokens.
+  // Each token is terminated by \0. It is NULL when count is 0.
+  const char *tokens;
+
+  // Pointer to a contiguous array of floats holding one timestamp
+  // (in seconds) per token. It is NULL when count is 0.
+  float* timestamps;
+
+  // The number of tokens/timestamps in the above pointers
+  int32_t count;
 } SherpaNcnnResult;
 
 typedef struct SherpaNcnnRecognizer SherpaNcnnRecognizer;

+ 3 - 0
sherpa-ncnn/csrc/decoder.h

@@ -42,6 +42,9 @@ struct DecoderConfig {
 };
 
 struct DecoderResult {
+  /// Number of frames we have decoded so far, counted after subsampling
+  int32_t frame_offset = 0;
+
   /// The decoded token IDs so far
   std::vector<int32_t> tokens;
 

+ 3 - 0
sherpa-ncnn/csrc/greedy-search-decoder.cc

@@ -59,6 +59,7 @@ void GreedySearchDecoder::Decode(ncnn::Mat encoder_out, DecoderResult *result) {
     decoder_out = model_->RunDecoder(decoder_input);
   }
 
+  int32_t frame_offset = result->frame_offset;
   for (int32_t t = 0; t != encoder_out.h; ++t) {
     ncnn::Mat encoder_out_t(encoder_out.w, encoder_out.row(t));
     ncnn::Mat joiner_out = model_->RunJoiner(encoder_out_t, decoder_out);
@@ -75,11 +76,13 @@ void GreedySearchDecoder::Decode(ncnn::Mat encoder_out, DecoderResult *result) {
       ncnn::Mat decoder_input = BuildDecoderInput(*result);
       decoder_out = model_->RunDecoder(decoder_input);
       result->num_trailing_blanks = 0;
+      result->timestamps.push_back(t + frame_offset);
     } else {
       ++result->num_trailing_blanks;
     }
   }
 
+  result->frame_offset += encoder_out.h;
   result->decoder_out = decoder_out;
 }
 

+ 3 - 0
sherpa-ncnn/csrc/modified-beam-search-decoder.cc

@@ -169,6 +169,7 @@ void ModifiedBeamSearchDecoder::Decode(ncnn::Mat encoder_out,
     auto topk = TopkIndex(static_cast<float *>(joiner_out),
                           joiner_out.w * joiner_out.h, num_active_paths_);
 
+    int32_t frame_offset = result->frame_offset;
     for (auto i : topk) {
       int32_t hyp_index = i / joiner_out.w;
       int32_t new_token = i % joiner_out.w;
@@ -181,6 +182,7 @@ void ModifiedBeamSearchDecoder::Decode(ncnn::Mat encoder_out,
       if (new_token != 0) {
         new_hyp.ys.push_back(new_token);
         new_hyp.num_trailing_blanks = 0;
+        new_hyp.timestamps.push_back(t + frame_offset);
       } else {
         ++new_hyp.num_trailing_blanks;
       }
@@ -190,6 +192,7 @@ void ModifiedBeamSearchDecoder::Decode(ncnn::Mat encoder_out,
   }
 
   result->hyps = std::move(cur);
+  result->frame_offset += encoder_out.h;
   auto hyp = result->hyps.GetMostProbable(true);
 
   // set decoder_out in case of endpointing

+ 21 - 5
sherpa-ncnn/csrc/recognizer.cc

@@ -31,14 +31,27 @@
 namespace sherpa_ncnn {
 
 // Converts a DecoderResult into a RecognitionResult.
 //
 // @param src                 Decoder output (token IDs and frame indexes).
 // @param sym_table           Maps a token ID to its string symbol.
 // @param frame_shift_ms      Frame shift of the feature extractor, in ms.
 // @param subsampling_factor  Factor applied to the frame shift to get the
 //                            duration of one decoded frame.
 // @return The text, token IDs, string tokens and timestamps (in seconds).
 static RecognitionResult Convert(const DecoderResult &src,
                                  const SymbolTable &sym_table,
                                  int32_t frame_shift_ms,
                                  int32_t subsampling_factor) {
   RecognitionResult ans;
   ans.tokens = src.tokens;
   ans.stokens.reserve(src.tokens.size());
   ans.timestamps.reserve(src.timestamps.size());
 
   // The text is the concatenation of the symbols of all tokens.
   for (auto id : src.tokens) {
     auto sym = sym_table[id];
     ans.text.append(sym);
     ans.stokens.push_back(sym);
   }
 
   // Duration, in seconds, covered by one decoded (subsampled) frame.
   const float seconds_per_frame = frame_shift_ms / 1000. * subsampling_factor;
   for (auto frame : src.timestamps) {
     const float seconds = seconds_per_frame * frame;
     ans.timestamps.push_back(seconds);
   }
 
   return ans;
 }
 
@@ -163,7 +176,10 @@ class Recognizer::Impl {
     DecoderResult decoder_result = s->GetResult();
     decoder_->StripLeadingBlanks(&decoder_result);
 
-    return Convert(decoder_result, sym_);
+    // Those 2 parameters are figured out from sherpa source code
+    int32_t frame_shift_ms = 10;
+    int32_t subsampling_factor = 4;
+    return Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor);
   }
 
  private:

+ 4 - 1
sherpa-ncnn/csrc/recognizer.h

@@ -34,9 +34,12 @@
 namespace sherpa_ncnn {
 
 // Recognition result of a stream, as produced from a DecoderResult.
 struct RecognitionResult {
-  std::vector<int32_t> tokens;
   // Recognized text: the concatenation of the symbols of all tokens
   std::string text;
   // Timestamps of the tokens, in seconds
   std::vector<float> timestamps;
   // Token IDs, copied from the decoder result
+  std::vector<int32_t> tokens;
+
+  // String-based tokens, i.e. the symbol of each token ID
+  std::vector<std::string> stokens;
 
   // Returns a string form of this result (defined outside this header)
   std::string ToString() const;
 };

+ 5 - 1
sherpa-ncnn/csrc/stream.cc

@@ -50,7 +50,11 @@ class Stream::Impl {
 
   // Gives mutable access to the counter of processed frames, so callers
   // can update it in place.
   int32_t &GetNumProcessedFrames() {
     return num_processed_frames_;
   }
 
-  void SetResult(const DecoderResult &r) { result_ = r; }
+  void SetResult(const DecoderResult &r) { 
+	  int32_t offset = result_.frame_offset;
+	  result_ = r; 
+	  result_.frame_offset = offset;
+  }
 
   // Gives mutable access to the decoding result stored in this stream.
   DecoderResult &GetResult() {
     return result_;
   }