|
@@ -20,65 +20,50 @@
|
|
|
#define SHERPA_NCNN_CSRC_LSTM_MODEL_H_
|
|
|
|
|
|
#include <string>
|
|
|
+#include <utility>
|
|
|
+#include <vector>
|
|
|
|
|
|
#include "net.h" // NOLINT
|
|
|
+#include "sherpa-ncnn/csrc/model.h"
|
|
|
|
|
|
namespace sherpa_ncnn {
|
|
|
|
|
|
-class LstmModel {
|
|
|
+class LstmModel : public Model {
|
|
|
public:
|
|
|
- /**
|
|
|
- * @param encoder_param Path to encoder.ncnn.param
|
|
|
- * @param encoder_bin Path to encoder.ncnn.bin
|
|
|
- * @param decoder_param Path to decoder.ncnn.param
|
|
|
- * @param decoder_bin Path to decoder.ncnn.bin
|
|
|
- * @param joiner_param Path to joiner.ncnn.param
|
|
|
- * @param joiner_bin Path to joiner.ncnn.bin
|
|
|
- * @param num_threads Number of threads to use when running the network
|
|
|
- */
|
|
|
- LstmModel(const std::string &encoder_param, const std::string &encoder_bin,
|
|
|
- const std::string &decoder_param, const std::string &decoder_bin,
|
|
|
- const std::string &joiner_param, const std::string &joiner_bin,
|
|
|
- int32_t num_threads);
|
|
|
+ explicit LstmModel(const ModelConfig &config);
|
|
|
|
|
|
/** Run the encoder network.
|
|
|
*
|
|
|
* @param features A 2-d mat of shape (num_frames, feature_dim).
|
|
|
* Note: features.w = feature_dim.
|
|
|
* features.h = num_frames.
|
|
|
- * @param hx Hidden state of the LSTM model. You can leave it to empty
|
|
|
- * on the first invocation. It is changed in-place.
|
|
|
+ * @param states Contains two tensors:
|
|
|
+ * - hx Hidden state of the LSTM model. You can leave it empty
|
|
|
+ * on the first invocation. It is changed in-place.
|
|
|
*
|
|
|
- * @param cx Hidden cell state of the LSTM model. You can leave it to empty
|
|
|
- * on the first invocation. It is changed in-place.
|
|
|
+ * - cx Hidden cell state of the LSTM model. You can leave it
|
|
|
+ * empty on the first invocation. It is changed in-place.
|
|
|
*
|
|
|
- * @return Return the output of the encoder. Its shape is
|
|
|
- * (num_out_frames, encoder_dim).
|
|
|
- * Note: ans.w == encoder_dim; ans.h == num_out_frames
|
|
|
- */
|
|
|
- ncnn::Mat RunEncoder(ncnn::Mat &features, ncnn::Mat *hx, ncnn::Mat *cx);
|
|
|
-
|
|
|
- /** Run the decoder network.
|
|
|
+ * - Note: on the first invocation, you can pass an empty vector.
|
|
|
*
|
|
|
- * @param decoder_input A mat of shape (context_size,). Note: Its underlying
|
|
|
- * content consists of integers, though its type is
|
|
|
- * float.
|
|
|
+ * @return Return a pair containing:
|
|
|
+ * - the output of the encoder. Its shape is (num_out_frames, encoder_dim).
|
|
|
+ * Note: ans.w == encoder_dim; ans.h == num_out_frames
|
|
|
*
|
|
|
- * @return Return a mat of shape (decoder_dim,)
|
|
|
+ * - next_states, a vector containing hx and cx for the next invocation
|
|
|
*/
|
|
|
- ncnn::Mat RunDecoder(ncnn::Mat &decoder_input);
|
|
|
+ std::pair<ncnn::Mat, std::vector<ncnn::Mat>> RunEncoder(
|
|
|
+ ncnn::Mat &features, const std::vector<ncnn::Mat> &states) override;
|
|
|
|
|
|
- /** Run the joiner network.
|
|
|
- *
|
|
|
- * @param encoder_out A mat of shape (encoder_dim,)
|
|
|
- * @param decoder_out A mat of shape (decoder_dim,)
|
|
|
- *
|
|
|
- * @return Return the joiner output which is of shape (vocab_size,)
|
|
|
- */
|
|
|
- ncnn::Mat RunJoiner(ncnn::Mat &encoder_out, ncnn::Mat &decoder_out);
|
|
|
+ ncnn::Mat RunDecoder(ncnn::Mat &decoder_input) override;
|
|
|
+
|
|
|
+ ncnn::Mat RunJoiner(ncnn::Mat &encoder_out, ncnn::Mat &decoder_out) override;
|
|
|
+
|
|
|
+ int32_t Segment() const override { return 9; }
|
|
|
|
|
|
- int32_t ContextSize() const { return 2; }
|
|
|
- int32_t BlankId() const { return 0; }
|
|
|
+ // Advance the feature extractor by this number of frames after
|
|
|
+ // running the encoder network
|
|
|
+ int32_t Offset() const override { return 4; }
|
|
|
|
|
|
private:
|
|
|
void InitEncoder(const std::string &encoder_param,
|