소스 검색

First working version of offline greedy search (#2)

Fangjun Kuang 2 년 전
부모
커밋
6b7180a845

+ 9 - 1
sherpa-ncnn/csrc/CMakeLists.txt

@@ -1,7 +1,15 @@
+include_directories(${CMAKE_SOURCE_DIR})
+
 add_executable(online-fbank-test online-fbank-test.cc)
 target_link_libraries(online-fbank-test kaldi-native-fbank-core)
 
 add_executable(sherpa-ncnn
   sherpa-ncnn.cc
+  symbol-table.cc
+  wave-reader.cc
+)
+
+target_link_libraries(sherpa-ncnn
+  ncnn
+  kaldi-native-fbank-core
 )
-target_link_libraries(sherpa-ncnn ncnn)

+ 112 - 0
sherpa-ncnn/csrc/sherpa-ncnn.cc

@@ -16,7 +16,11 @@
  * limitations under the License.
  */
 
+#include "kaldi-native-fbank/csrc/online-feature.h"
 #include "net.h"
+#include "sherpa-ncnn/csrc/symbol-table.h"
+#include "sherpa-ncnn/csrc/wave-reader.h"
+#include <algorithm>
 #include <iostream>
 
 static void InitNet(ncnn::Net &net, const std::string &param,
@@ -52,11 +56,119 @@ int main() {
   std::string joiner_model =
       "bar/joiner_jit_trace-iter-468000-avg-16-pnnx.ncnn.bin";
 
+  std::string wav1 = "./test_wavs/1089-134686-0001.wav";
+  // wav1 = "./test_wavs/1221-135766-0001.wav";
+  wav1 = "./test_wavs/1221-135766-0002.wav";
+
   ncnn::Net encoder_net;
+  encoder_net.opt.use_packing_layout = false;
+  encoder_net.opt.use_fp16_storage = false;
+
   ncnn::Net decoder_net;
+  decoder_net.opt.use_packing_layout = false;
+
   ncnn::Net joiner_net;
+  joiner_net.opt.use_packing_layout = false;
 
   InitNet(encoder_net, encoder_param, encoder_model);
   InitNet(decoder_net, decoder_param, decoder_model);
   InitNet(joiner_net, joiner_param, joiner_model);
+
+  std::vector<float> samples = sherpa_ncnn::ReadWave(wav1, 16000);
+
+  knf::FbankOptions opts;
+  opts.frame_opts.dither = 0;
+  opts.frame_opts.snip_edges = false;
+  opts.frame_opts.samp_freq = 16000;
+
+  opts.mel_opts.num_bins = 80;
+
+  knf::OnlineFbank fbank(opts);
+  fbank.AcceptWaveform(16000, samples.data(), samples.size());
+  fbank.InputFinished();
+
+  int32_t num_encoder_layers = 12;
+  int32_t batch_size = 1;
+  int32_t d_model = 512;
+  int32_t rnn_hidden_size = 1024;
+
+  ncnn::Mat h0;
+  h0.create(d_model, num_encoder_layers);
+  ncnn::Mat c0;
+  c0.create(rnn_hidden_size, num_encoder_layers);
+  h0.fill(0);
+  c0.fill(0);
+
+  int32_t feature_dim = 80;
+  ncnn::Mat features;
+  features.create(feature_dim, fbank.NumFramesReady());
+
+  for (int32_t i = 0; i != fbank.NumFramesReady(); ++i) {
+    const float *f = fbank.GetFrame(i);
+    std::copy(f, f + feature_dim, features.row(i));
+  }
+
+  ncnn::Mat feature_lengths(1);
+  feature_lengths[0] = features.h;
+
+  ncnn::Extractor encoder_ex = encoder_net.create_extractor();
+
+  encoder_ex.input("in0", features);
+  encoder_ex.input("in1", feature_lengths);
+  encoder_ex.input("in2", h0);
+  encoder_ex.input("in3", c0);
+
+  ncnn::Mat encoder_out;
+  encoder_ex.extract("out0", encoder_out);
+
+  int32_t context_size = 2;
+  int32_t blank_id = 0;
+
+  std::vector<int32_t> hyp(context_size, blank_id);
+  ncnn::Mat decoder_input(context_size);
+  static_cast<int32_t *>(decoder_input)[0] = blank_id + 1;
+  static_cast<int32_t *>(decoder_input)[1] = blank_id + 2;
+  decoder_input.fill(blank_id);
+
+  ncnn::Extractor decoder_ex = decoder_net.create_extractor();
+  ncnn::Mat decoder_out;
+  decoder_ex.input("in0", decoder_input);
+  decoder_ex.extract("out0", decoder_out);
+  decoder_out = decoder_out.reshape(decoder_out.w);
+
+  ncnn::Mat joiner_out;
+  for (int32_t t = 0; t != encoder_out.h; ++t) {
+    ncnn::Mat encoder_out_t(512, encoder_out.row(t));
+
+    auto joiner_ex = joiner_net.create_extractor();
+    joiner_ex.input("in0", encoder_out_t);
+    joiner_ex.input("in1", decoder_out);
+
+    joiner_ex.extract("out0", joiner_out);
+
+    auto y = static_cast<int32_t>(
+        std::distance(static_cast<const float *>(joiner_out),
+                      std::max_element(static_cast<const float *>(joiner_out),
+                                       static_cast<const float *>(joiner_out) +
+                                           joiner_out.w)));
+
+    if (y != blank_id) {
+      static_cast<int32_t *>(decoder_input)[0] = hyp.back();
+      static_cast<int32_t *>(decoder_input)[1] = y;
+      hyp.push_back(y);
+
+      decoder_ex = decoder_net.create_extractor();
+      decoder_ex.input("in0", decoder_input);
+      decoder_ex.extract("out0", decoder_out);
+      decoder_out = decoder_out.reshape(decoder_out.w);
+    }
+  }
+  std::string text;
+  sherpa_ncnn::SymbolTable sym("./tokens.txt");
+  for (int32_t i = context_size; i != hyp.size(); ++i) {
+    text += sym[hyp[i]];
+  }
+
+  fprintf(stderr, "%s\n", text.c_str());
+  return 0;
 }

+ 78 - 0
sherpa-ncnn/csrc/symbol-table.cc

@@ -0,0 +1,78 @@
+/**
+ * Copyright      2022  Xiaomi Corporation (authors: Fangjun Kuang)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "sherpa-ncnn/csrc/symbol-table.h"
+
+#include <cassert>
+#include <fstream>
+#include <sstream>
+
+namespace sherpa_ncnn {
+
+SymbolTable::SymbolTable(const std::string &filename) {
+  std::ifstream is(filename);
+  std::string sym;
+  int32_t id;
+  while (is >> sym >> id) {
+    if (sym.size() >= 3) {
+      // For BPE-based models, we replace ▁ with a space
+      // Unicode 9601, hex 0x2581, utf8 0xe29681
+      const uint8_t *p = reinterpret_cast<const uint8_t *>(sym.c_str());
+      if (p[0] == 0xe2 && p[1] == 0x96 && p[2] == 0x81) {
+        sym = sym.replace(0, 3, " ");
+      }
+    }
+
+    assert(!sym.empty());
+    assert(sym2id_.count(sym) == 0);
+    assert(id2sym_.count(id) == 0);
+
+    sym2id_.insert({sym, id});
+    id2sym_.insert({id, sym});
+  }
+  assert(is.eof());
+}
+
+std::string SymbolTable::ToString() const {
+  std::ostringstream os;
+  char sep = ' ';
+  for (const auto &p : sym2id_) {
+    os << p.first << sep << p.second << "\n";
+  }
+  return os.str();
+}
+
+const std::string &SymbolTable::operator[](int32_t id) const {
+  return id2sym_.at(id);
+}
+
+int32_t SymbolTable::operator[](const std::string &sym) const {
+  return sym2id_.at(sym);
+}
+
+bool SymbolTable::contains(int32_t id) const { return id2sym_.count(id) != 0; }
+
+bool SymbolTable::contains(const std::string &sym) const {
+  return sym2id_.count(sym) != 0;
+}
+
+std::ostream &operator<<(std::ostream &os, const SymbolTable &symbol_table) {
+  return os << symbol_table.ToString();
+}
+
+} // namespace sherpa_ncnn

+ 62 - 0
sherpa-ncnn/csrc/symbol-table.h

@@ -0,0 +1,62 @@
+/**
+ * Copyright      2022  Xiaomi Corporation (authors: Fangjun Kuang)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef SHERPA_NCNN_CSRC_SYMBOL_TABLE_H_
+#define SHERPA_NCNN_CSRC_SYMBOL_TABLE_H_
+
+#include <string>
+#include <unordered_map>
+
+namespace sherpa_ncnn {
+
+/// It manages mapping between symbols and integer IDs.
+class SymbolTable {
+public:
+  SymbolTable() = default;
+  /// Construct a symbol table from a file.
+  /// Each line in the file contains two fields:
+  ///
+  ///    sym ID
+  ///
+  /// Fields are separated by space(s).
+  explicit SymbolTable(const std::string &filename);
+
+  /// Return a string representation of this symbol table
+  std::string ToString() const;
+
+  /// Return the symbol corresponding to the given ID.
+  const std::string &operator[](int32_t id) const;
+  /// Return the ID corresponding to the given symbol.
+  int32_t operator[](const std::string &sym) const;
+
+  /// Return true if there is a symbol with the given ID.
+  bool contains(int32_t id) const;
+
+  /// Return true if there is a given symbol in the symbol table.
+  bool contains(const std::string &sym) const;
+
+private:
+  std::unordered_map<std::string, int32_t> sym2id_;
+  std::unordered_map<int32_t, std::string> id2sym_;
+};
+
+std::ostream &operator<<(std::ostream &os, const SymbolTable &symbol_table);
+
+} // namespace sherpa_ncnn
+
+#endif // SHERPA_NCNN_CSRC_SYMBOL_TABLE_H_

+ 107 - 0
sherpa-ncnn/csrc/wave-reader.cc

@@ -0,0 +1,107 @@
+/**
+ * Copyright      2021  Xiaomi Corporation (authors: Fangjun Kuang)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <cassert>
+#include <fstream>
+#include <iostream>
+#include <utility>
+#include <vector>
+
+#include "sherpa-ncnn/csrc/wave-reader.h"
+
+namespace sherpa_ncnn {
+namespace {
+// see http://soundfile.sapp.org/doc/WaveFormat/
+//
+// Note: We assume little endian here
+// TODO(fangjun): Support big endian
+struct WaveHeader {
+  void Validate() const {
+    //                       F F I R
+    assert(chunk_id == 0x46464952);
+    assert(chunk_size == 36 + subchunk2_size);
+    //                     E V A W
+    assert(format == 0x45564157);
+    assert(subchunk1_id == 0x20746d66);
+    assert(subchunk1_size == 16); // 16 for PCM
+    assert(audio_format == 1);    // 1 for PCM
+    assert(num_channels == 1);    // we support only single channel for now
+    assert(byte_rate == sample_rate * num_channels * bits_per_sample / 8);
+    assert(block_align == num_channels * bits_per_sample / 8);
+    assert(bits_per_sample == 16); // we support only 16 bits per sample
+  }
+
+  int32_t chunk_id;
+  int32_t chunk_size;
+  int32_t format;
+  int32_t subchunk1_id;
+  int32_t subchunk1_size;
+  int16_t audio_format;
+  int16_t num_channels;
+  int32_t sample_rate;
+  int32_t byte_rate;
+  int16_t block_align;
+  int16_t bits_per_sample;
+  int32_t subchunk2_id;
+  int32_t subchunk2_size;
+};
+static_assert(sizeof(WaveHeader) == 44, "");
+
+// Read a wave file of mono-channel.
+// Return its samples normalized to the range [-1, 1).
+std::vector<float> ReadWaveImpl(std::istream &is, float *sample_rate) {
+  WaveHeader header;
+  is.read(reinterpret_cast<char *>(&header), sizeof(header));
+  assert((bool)is);
+
+  header.Validate();
+
+  *sample_rate = header.sample_rate;
+
+  // header.subchunk2_size contains the number of bytes in the data.
+  // As we assume each sample contains two bytes, so it is divided by 2 here
+  std::vector<int16_t> samples(header.subchunk2_size / 2);
+
+  is.read(reinterpret_cast<char *>(samples.data()), header.subchunk2_size);
+
+  assert((bool)is);
+
+  std::vector<float> ans(samples.size());
+  for (int32_t i = 0; i != ans.size(); ++i) {
+    ans[i] = samples[i] / 32768.;
+  }
+
+  return ans;
+}
+
+} // namespace
+
+std::vector<float> ReadWave(const std::string &filename,
+                            float expected_sample_rate) {
+  std::ifstream is(filename, std::ifstream::binary);
+  float sample_rate;
+  auto samples = ReadWaveImpl(is, &sample_rate);
+  if (expected_sample_rate != sample_rate) {
+    std::cerr << "Expected sample rate: " << expected_sample_rate
+              << ". Given: " << sample_rate << ".\n";
+    exit(-1);
+  }
+  return samples;
+}
+
+} // namespace sherpa_ncnn

+ 41 - 0
sherpa-ncnn/csrc/wave-reader.h

@@ -0,0 +1,41 @@
+/**
+ * Copyright      2021  Xiaomi Corporation (authors: Fangjun Kuang)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef SHERPA_NCNN_CSRC_WAVE_READER_H_
+#define SHERPA_NCNN_CSRC_WAVE_READER_H_
+
+#include <istream>
+#include <string>
+#include <vector>
+
+namespace sherpa_ncnn {
+
+/** Read a wave file with expected sample rate.
+
+    @param filename Path to a wave file. It MUST be single channel, PCM encoded.
+    @param expected_sample_rate  Expected sample rate of the wave file. If the
+                               sample rate don't match, it throws an exception.
+
+    @return Return wave samples normalized to the range [-1, 1).
+ */
+std::vector<float> ReadWave(const std::string &filename,
+                            float expected_sample_rate);
+
+} // namespace sherpa_ncnn
+
+#endif // SHERPA_NCNN_CSRC_WAVE_READER_H_