@@ -17,67 +17,52 @@
  */
 #include "sherpa-ncnn/csrc/lstm-model.h"

-#include <iostream>
 #include <utility>
 #include <vector>

 namespace sherpa_ncnn {

-static void InitNet(ncnn::Net &net, const std::string &param,
-                    const std::string &bin) {
-  if (net.load_param(param.c_str())) {
-    std::cerr << "failed to load " << param << "\n";
-    exit(-1);
-  }
-
-  if (net.load_model(bin.c_str())) {
-    std::cerr << "failed to load " << bin << "\n";
-    exit(-1);
-  }
-}
-
 LstmModel::LstmModel(const ModelConfig &config)
     : num_threads_(config.num_threads) {
   InitEncoder(config.encoder_param, config.encoder_bin);
   InitDecoder(config.decoder_param, config.decoder_bin);
   InitJoiner(config.joiner_param, config.joiner_bin);
+
+  InitEncoderInputOutputIndexes();
+  InitDecoderInputOutputIndexes();
+  InitJoinerInputOutputIndexes();
 }

 std::pair<ncnn::Mat, std::vector<ncnn::Mat>> LstmModel::RunEncoder(
     ncnn::Mat &features, const std::vector<ncnn::Mat> &states) {
-  int32_t num_encoder_layers = 12;
-  int32_t d_model = 512;
-  int32_t rnn_hidden_size = 1024;
   ncnn::Mat hx;
   ncnn::Mat cx;

   if (states.empty()) {
-    hx.create(d_model, num_encoder_layers);
-    cx.create(rnn_hidden_size, num_encoder_layers);
-
-    hx.fill(0);
-    cx.fill(0);
+    auto s = GetEncoderInitStates();
+    hx = s[0];
+    cx = s[1];
   } else {
     hx = states[0];
     cx = states[1];
   }

-  ncnn::Mat feature_lengths(1);
-  feature_lengths[0] = features.h;
+  ncnn::Mat feature_length(1);
+  feature_length[0] = features.h;

   ncnn::Extractor encoder_ex = encoder_.create_extractor();
   encoder_ex.set_num_threads(num_threads_);

-  encoder_ex.input("in0", features);
-  encoder_ex.input("in1", feature_lengths);
-  encoder_ex.input("in2", hx);
-  encoder_ex.input("in3", cx);
+  encoder_ex.input(encoder_input_indexes_[0], features);
+  encoder_ex.input(encoder_input_indexes_[1], feature_length);
+  encoder_ex.input(encoder_input_indexes_[2], hx);
+  encoder_ex.input(encoder_input_indexes_[3], cx);

   ncnn::Mat encoder_out;
-  encoder_ex.extract("out0", encoder_out);
+  encoder_ex.extract(encoder_output_indexes_[0], encoder_out);

-  encoder_ex.extract("out2", hx);
-  encoder_ex.extract("out3", cx);
+  encoder_ex.extract(encoder_output_indexes_[1], hx);
+  encoder_ex.extract(encoder_output_indexes_[2], cx);

   std::vector<ncnn::Mat> next_states = {hx, cx};

@@ -89,8 +74,8 @@ ncnn::Mat LstmModel::RunDecoder(ncnn::Mat &decoder_input) {
   decoder_ex.set_num_threads(num_threads_);

   ncnn::Mat decoder_out;
-  decoder_ex.input("in0", decoder_input);
-  decoder_ex.extract("out0", decoder_out);
+  decoder_ex.input(decoder_input_indexes_[0], decoder_input);
+  decoder_ex.extract(decoder_output_indexes_[0], decoder_out);
   decoder_out = decoder_out.reshape(decoder_out.w);

   return decoder_out;
@@ -99,11 +84,11 @@ ncnn::Mat LstmModel::RunDecoder(ncnn::Mat &decoder_input) {
 ncnn::Mat LstmModel::RunJoiner(ncnn::Mat &encoder_out, ncnn::Mat &decoder_out) {
   auto joiner_ex = joiner_.create_extractor();
   joiner_ex.set_num_threads(num_threads_);
-  joiner_ex.input("in0", encoder_out);
-  joiner_ex.input("in1", decoder_out);
+  joiner_ex.input(joiner_input_indexes_[0], encoder_out);
+  joiner_ex.input(joiner_input_indexes_[1], decoder_out);

   ncnn::Mat joiner_out;
-  joiner_ex.extract("out0", joiner_out);
+  joiner_ex.extract(joiner_output_indexes_[0], joiner_out);
   return joiner_out;
 }

@@ -124,4 +109,80 @@ void LstmModel::InitJoiner(const std::string &joiner_param,
   InitNet(joiner_, joiner_param, joiner_bin);
 }

+std::vector<ncnn::Mat> LstmModel::GetEncoderInitStates() const {
+  int32_t num_encoder_layers = 12;
+  int32_t d_model = 512;
+  int32_t rnn_hidden_size = 1024;
+
+  auto hx = ncnn::Mat(d_model, num_encoder_layers);
+  auto cx = ncnn::Mat(rnn_hidden_size, num_encoder_layers);
+
+  hx.fill(0);
+  cx.fill(0);
+
+  return {hx, cx};
+}
+
+void LstmModel::InitEncoderInputOutputIndexes() {
+  // input indexes map
+  // [0] -> in0, features,
+  // [1] -> in1, features_length
+  // [2] -> in2, hx
+  // [3] -> in3, cx
+  encoder_input_indexes_.resize(4);
+
+  // output indexes map
+  // [0] -> out0, encoder_out
+  // [1] -> out2, hx
+  // [2] -> out3, cx
+  encoder_output_indexes_.resize(3);
+  const auto &blobs = encoder_.blobs();
+  for (int32_t i = 0; i != blobs.size(); ++i) {
+    const auto &b = blobs[i];
+    if (b.name == "in0") encoder_input_indexes_[0] = i;
+    if (b.name == "in1") encoder_input_indexes_[1] = i;
+    if (b.name == "in2") encoder_input_indexes_[2] = i;
+    if (b.name == "in3") encoder_input_indexes_[3] = i;
+    if (b.name == "out0") encoder_output_indexes_[0] = i;
+    if (b.name == "out2") encoder_output_indexes_[1] = i;
+    if (b.name == "out3") encoder_output_indexes_[2] = i;
+  }
+}
+
+void LstmModel::InitDecoderInputOutputIndexes() {
+  // input indexes map
+  // [0] -> in0, decoder_input,
+  decoder_input_indexes_.resize(1);
+
+  // output indexes map
+  // [0] -> out0, decoder_out,
+  decoder_output_indexes_.resize(1);
+
+  const auto &blobs = decoder_.blobs();
+  for (int32_t i = 0; i != blobs.size(); ++i) {
+    const auto &b = blobs[i];
+    if (b.name == "in0") decoder_input_indexes_[0] = i;
+    if (b.name == "out0") decoder_output_indexes_[0] = i;
+  }
+}
+
+void LstmModel::InitJoinerInputOutputIndexes() {
+  // input indexes map
+  // [0] -> in0, encoder_input,
+  // [1] -> in1, decoder_input,
+  joiner_input_indexes_.resize(2);
+
+  // output indexes map
+  // [0] -> out0, joiner_out,
+  joiner_output_indexes_.resize(1);
+
+  const auto &blobs = joiner_.blobs();
+  for (int32_t i = 0; i != blobs.size(); ++i) {
+    const auto &b = blobs[i];
+    if (b.name == "in0") joiner_input_indexes_[0] = i;
+    if (b.name == "in1") joiner_input_indexes_[1] = i;
+    if (b.name == "out0") joiner_output_indexes_[0] = i;
+  }
+}
+
 }  // namespace sherpa_ncnn
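
The three Init*InputOutputIndexes() methods added above all follow one pattern: scan Net::blobs() once, match each blob by name, and cache its integer index so the Extractor calls can use the int-based input()/extract() overloads instead of string lookups. The sketch below shows that lookup written once as a standalone helper; it is an editor's illustration only, FindBlobIndex is a hypothetical name, and it assumes the same Net::blobs() accessor and Blob::name field the patch itself relies on.

```cpp
#include <cstdint>
#include <string>

#include "net.h"  // ncnn::Net, ncnn::Blob

// Hypothetical helper (not part of the patch): map a blob name to the blob
// index expected by the int overloads of ncnn::Extractor::input()/extract().
static int32_t FindBlobIndex(const ncnn::Net &net, const std::string &name) {
  const auto &blobs = net.blobs();
  for (int32_t i = 0; i != static_cast<int32_t>(blobs.size()); ++i) {
    if (blobs[i].name == name) return i;
  }
  return -1;  // blob not found in the loaded .param file
}
```

With such a helper, InitEncoderInputOutputIndexes() would reduce to calls like encoder_input_indexes_[0] = FindBlobIndex(encoder_, "in0").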
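
For the encoder state handling: when RunEncoder() receives an empty state vector it now builds zero-initialized hx/cx through GetEncoderInitStates(), and every later call is expected to pass back the {hx, cx} pair returned with the previous chunk. Below is a minimal usage sketch of that loop, assuming the RunEncoder() signature shown in this file; DecodeStream and feature_chunks are hypothetical names, and the decoder/joiner steps are elided.

```cpp
#include <utility>
#include <vector>

#include "sherpa-ncnn/csrc/lstm-model.h"

// Editor's sketch, not code from this patch: drive the streaming encoder
// chunk by chunk, feeding the returned LSTM states back in.
void DecodeStream(sherpa_ncnn::LstmModel &model,
                  const std::vector<ncnn::Mat> &feature_chunks) {
  // Empty on the first call, so RunEncoder() falls back to
  // GetEncoderInitStates() (zero-filled hx and cx).
  std::vector<ncnn::Mat> states;

  for (ncnn::Mat features : feature_chunks) {  // ncnn::Mat copies are shallow
    auto result = model.RunEncoder(features, states);
    ncnn::Mat encoder_out = result.first;
    states = std::move(result.second);  // carry hx/cx into the next chunk

    // ... run RunDecoder()/RunJoiner() on encoder_out here ...
    (void)encoder_out;
  }
}
```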