Răsfoiți Sursa

Support config by environment variables. (#154)

Winlin 2 ani în urmă
părinte
comite
8560eb3f00

+ 198 - 66
ffmpeg-examples/sherpa-ncnn-ffmpeg.cc

@@ -20,8 +20,8 @@
 #include <stdlib.h>
 #include <string.h>
 
-#include <string>
 #include <cctype>  // std::tolower
+#include <string>
 
 #include "sherpa-ncnn/csrc/display.h"
 #include "sherpa-ncnn/csrc/recognizer.h"
@@ -83,11 +83,11 @@ static AVCodecContext *dec_ctx;
 AVFilterContext *buffersink_ctx;
 AVFilterContext *buffersrc_ctx;
 AVFilterGraph *filter_graph;
-static int audio_stream_index = -1;
+static int32_t audio_stream_index = -1;
 
-static int open_input_file(const char *filename) {
+static int32_t FFmpegOpenInputFile(const char *filename) {
   const AVCodec *dec;
-  int ret;
+  int32_t ret;
 
   if ((ret = avformat_open_input(&fmt_ctx, filename, NULL, NULL)) < 0) {
     av_log(NULL, AV_LOG_ERROR, "Cannot open input file %s\n", filename);
@@ -123,16 +123,16 @@ static int open_input_file(const char *filename) {
   return 0;
 }
 
-static int init_filters(const char *filters_descr) {
+static int32_t FFmpegInitFilters(const char *filters_descr) {
   char args[512];
-  int ret = 0;
+  int32_t ret = 0;
   const AVFilter *abuffersrc = avfilter_get_by_name("abuffer");
   const AVFilter *abuffersink = avfilter_get_by_name("abuffersink");
   AVFilterInOut *outputs = avfilter_inout_alloc();
   AVFilterInOut *inputs = avfilter_inout_alloc();
   static const enum AVSampleFormat out_sample_fmts[] = {AV_SAMPLE_FMT_S16,
                                                         AV_SAMPLE_FMT_NONE};
-  static const int out_sample_rates[] = {16000, -1};
+  static const int32_t out_sample_rates[] = {16000, -1};
   const AVFilterLink *outlink;
   AVRational time_base = fmt_ctx->streams[audio_stream_index]->time_base;
 
@@ -239,15 +239,14 @@ end:
   return ret;
 }
 
-static void sherpa_decode_frame(const AVFrame *frame,
-                                const sherpa_ncnn::Recognizer &recognizer,
-                                sherpa_ncnn::Stream *s,
-                                sherpa_ncnn::Display &display,
-                                std::string &last_text,
-                                int32_t &segment_index) {
+static void FFmpegDecodeFrame(const AVFrame *frame,
+                              const sherpa_ncnn::Recognizer &recognizer,
+                              sherpa_ncnn::Stream *s,
+                              sherpa_ncnn::Display &display,
+                              std::string &last_text, int32_t &segment_index) {
 #define N 3200  // 0.2 s. Sample rate is fixed to 16 kHz
   static float samples[N];
-  static int nb_samples = 0;
+  static int32_t nb_samples = 0;
   const int16_t *p = (int16_t *)frame->data[0];
 
   if (frame->nb_samples + nb_samples >= N) {
@@ -280,12 +279,12 @@ static void sherpa_decode_frame(const AVFrame *frame,
     nb_samples = 0;
   }
 
-  for (int i = 0; i < frame->nb_samples; i++) {
+  for (int32_t i = 0; i < frame->nb_samples; i++) {
     samples[nb_samples++] = p[i] / 32768.;
   }
 }
 
-static inline char *__av_err2str(int errnum) {
+static inline char *FFmpegAvError2String(int32_t errnum) {
   static char str[AV_ERROR_MAX_STRING_SIZE];
   memset(str, 0, sizeof(str));
   return av_make_error_string(str, AV_ERROR_MAX_STRING_SIZE, errnum);
@@ -297,11 +296,153 @@ static void Handler(int32_t sig) {
   raise(sig);
 };
 
-int main(int argc, char **argv) {
-  if (argc < 9 || argc > 11) {
+#define SET_CONFIG_BY_ENV(config, key, required) \
+  config = "";                                   \
+  if (getenv(key)) {                             \
+    config = getenv(key);                        \
+    if (required) {                              \
+      parsed_required_envs++;                    \
+    }                                            \
+  }
+
+// Parses |text| as a strictly positive base-10 integer that fits in int32_t.
+// Trailing whitespace is tolerated; an empty string, any other trailing
+// characters, or a non-positive / out-of-range value is rejected.
+// Returns true and stores the value in |*out| on success; |*out| is left
+// untouched on failure.
+static bool ParsePositiveInt32(const std::string &text, int32_t *out) {
+  char *end = NULL;
+  const long parsed = strtol(text.c_str(), &end, 10);
+  if (end == text.c_str()) return false;  // No digits were consumed.
+  while (std::isspace(static_cast<unsigned char>(*end))) ++end;
+  if (*end != '\0' || parsed <= 0) return false;
+
+  // Round-trip through int32_t to reject values that do not fit.
+  const int32_t narrowed = static_cast<int32_t>(parsed);
+  if (static_cast<long>(narrowed) != parsed) return false;
+
+  *out = narrowed;
+  return true;
+}
+
+// Parses |text| as a strictly positive floating point number. Trailing
+// whitespace is tolerated; any other trailing characters are rejected.
+// Returns true and stores the value in |*out| on success; |*out| is left
+// untouched on failure.
+static bool ParsePositiveDouble(const std::string &text, double *out) {
+  char *end = NULL;
+  const double parsed = strtod(text.c_str(), &end);
+  if (end == text.c_str()) return false;  // Nothing numeric was consumed.
+  while (std::isspace(static_cast<unsigned char>(*end))) ++end;
+  // !(parsed > 0) is used instead of (parsed <= 0) so that NaN is rejected.
+  if (*end != '\0' || !(parsed > 0)) return false;
+
+  *out = parsed;
+  return true;
+}
+
+// Loads the recognizer configuration from SHERPA_NCNN_* environment variables.
+//
+// The seven model file fields of |config->model_config| and |*input_url| are
+// always overwritten: with the variable's value when it is set, or with the
+// empty string when it is not. The optional variables (NUM_THREADS, METHOD,
+// ENABLE_ENDPOINT, RULE1/RULE2_MIN_TRAILING_SILENCE,
+// RULE3_MIN_UTTERANCE_LENGTH) only modify |config| when they are set to a
+// non-empty value.
+//
+// Returns the number of required variables that were set (0 to 8), or -1
+// after printing a message to stderr when an optional variable holds an
+// invalid value.
+static int32_t ParseConfigFromENV(sherpa_ncnn::RecognizerConfig *config,
+                                  std::string *input_url) {
+  // Incremented by SET_CONFIG_BY_ENV for every required variable found.
+  int32_t parsed_required_envs = 0;
+
+  sherpa_ncnn::ModelConfig &mc = config->model_config;
+  SET_CONFIG_BY_ENV(mc.tokens, "SHERPA_NCNN_TOKENS", true);
+  SET_CONFIG_BY_ENV(mc.encoder_param, "SHERPA_NCNN_ENCODER_PARAM", true);
+  SET_CONFIG_BY_ENV(mc.encoder_bin, "SHERPA_NCNN_ENCODER_BIN", true);
+  SET_CONFIG_BY_ENV(mc.decoder_param, "SHERPA_NCNN_DECODER_PARAM", true);
+  SET_CONFIG_BY_ENV(mc.decoder_bin, "SHERPA_NCNN_DECODER_BIN", true);
+  SET_CONFIG_BY_ENV(mc.joiner_param, "SHERPA_NCNN_JOINER_PARAM", true);
+  SET_CONFIG_BY_ENV(mc.joiner_bin, "SHERPA_NCNN_JOINER_BIN", true);
+  SET_CONFIG_BY_ENV(*input_url, "SHERPA_NCNN_INPUT_URL", true);
+
+  std::string val;
+  SET_CONFIG_BY_ENV(val, "SHERPA_NCNN_NUM_THREADS", false);
+  if (!val.empty()) {
+    int32_t num_threads = 0;
+    if (!ParsePositiveInt32(val, &num_threads)) {
+      fprintf(stderr, "Invalid SHERPA_NCNN_NUM_THREADS=%s\n", val.c_str());
+      return -1;
+    }
+    mc.encoder_opt.num_threads = num_threads;
+    mc.decoder_opt.num_threads = num_threads;
+    mc.joiner_opt.num_threads = num_threads;
+  }
+
+  SET_CONFIG_BY_ENV(val, "SHERPA_NCNN_METHOD", false);
+  if (!val.empty()) {
+    if (val != "greedy_search" && val != "modified_beam_search") {
+      fprintf(stderr, "Invalid SHERPA_NCNN_METHOD=%s\n", val.c_str());
+      return -1;
+    }
+    config->decoder_config.method = val;
+  }
+
+  SET_CONFIG_BY_ENV(val, "SHERPA_NCNN_ENABLE_ENDPOINT", false);
+  if (!val.empty()) {
+    // std::tolower requires a value representable as unsigned char, so cast
+    // before the call to avoid undefined behaviour on negative chars.
+    for (char &ch : val) {
+      ch = static_cast<char>(std::tolower(static_cast<unsigned char>(ch)));
+    }
+    // Any value other than "true"/"on" (case-insensitive) disables endpointing.
+    config->enable_endpoint = val == "true" || val == "on";
+  }
+
+  double threshold = 0;
+
+  SET_CONFIG_BY_ENV(val, "SHERPA_NCNN_RULE1_MIN_TRAILING_SILENCE", false);
+  if (!val.empty()) {
+    if (!ParsePositiveDouble(val, &threshold)) {
+      fprintf(stderr, "Invalid SHERPA_NCNN_RULE1_MIN_TRAILING_SILENCE=%s\n",
+              val.c_str());
+      return -1;
+    }
+    config->endpoint_config.rule1.min_trailing_silence = threshold;
+  }
+
+  SET_CONFIG_BY_ENV(val, "SHERPA_NCNN_RULE2_MIN_TRAILING_SILENCE", false);
+  if (!val.empty()) {
+    if (!ParsePositiveDouble(val, &threshold)) {
+      fprintf(stderr, "Invalid SHERPA_NCNN_RULE2_MIN_TRAILING_SILENCE=%s\n",
+              val.c_str());
+      return -1;
+    }
+    config->endpoint_config.rule2.min_trailing_silence = threshold;
+  }
+
+  SET_CONFIG_BY_ENV(val, "SHERPA_NCNN_RULE3_MIN_UTTERANCE_LENGTH", false);
+  if (!val.empty()) {
+    if (!ParsePositiveDouble(val, &threshold)) {
+      fprintf(stderr, "Invalid SHERPA_NCNN_RULE3_MIN_UTTERANCE_LENGTH=%s\n",
+              val.c_str());
+      return -1;
+    }
+    config->endpoint_config.rule3.min_utterance_length = threshold;
+  }
+
+  return parsed_required_envs;
+}
+
+// Fills |config| with the built-in defaults used before any environment
+// variable or command-line override is applied: a 16000 Hz / 80-dimensional
+// feature front end, endpointing enabled with its three rule thresholds, and
+// four threads for each of the encoder, decoder and joiner.
+static void SetDefaultConfigurations(sherpa_ncnn::RecognizerConfig *config) {
+  // Feature extraction.
+  const float default_sampling_rate = 16000;
+  config->feat_config.feature_dim = 80;
+  config->feat_config.sampling_rate = default_sampling_rate;
+
+  // Endpoint detection.
+  auto &endpoint = config->endpoint_config;
+  endpoint.rule1.min_trailing_silence = 2.4;
+  endpoint.rule2.min_trailing_silence = 1.2;
+  endpoint.rule3.min_utterance_length = 300;
+  config->enable_endpoint = true;
+
+  // Thread count shared by all three networks.
+  const int32_t default_threads = 4;
+  sherpa_ncnn::ModelConfig &model = config->model_config;
+  model.joiner_opt.num_threads = default_threads;
+  model.decoder_opt.num_threads = default_threads;
+  model.encoder_opt.num_threads = default_threads;
+}
+
+// Applies positional command-line arguments on top of |config| and
+// |*input_url|. Only the arguments that are actually present are applied, so
+// values loaded earlier (defaults or environment variables) survive when the
+// corresponding argument is absent.
+//
+//   argv[1]  tokens            argv[2]  encoder param   argv[3]  encoder bin
+//   argv[4]  decoder param     argv[5]  decoder bin     argv[6]  joiner param
+//   argv[7]  joiner bin        argv[8]  input url
+//   argv[9]  number of threads (optional)
+//   argv[10] decoding method   (optional)
+//
+// Returns 0 on success, or -1 after printing a message to stderr when the
+// decoding method is neither "greedy_search" nor "modified_beam_search".
+static int32_t OverwriteConfigByCLI(int32_t argc, char **argv,
+                                    sherpa_ncnn::RecognizerConfig *config,
+                                    std::string *input_url) {
+  sherpa_ncnn::ModelConfig &mc = config->model_config;
+  if (argc > 1) mc.tokens = argv[1];
+  if (argc > 2) mc.encoder_param = argv[2];
+  if (argc > 3) mc.encoder_bin = argv[3];
+  if (argc > 4) mc.decoder_param = argv[4];
+  if (argc > 5) mc.decoder_bin = argv[5];
+  if (argc > 6) mc.joiner_param = argv[6];
+  if (argc > 7) mc.joiner_bin = argv[7];
+  if (argc > 8) *input_url = argv[8];
+
+  if (argc > 9) {
+    // A non-numeric or non-positive thread count is ignored on purpose, so
+    // the previously configured value stays in effect.
+    const int32_t num_threads = atoi(argv[9]);
+    if (num_threads > 0) {
+      mc.encoder_opt.num_threads = num_threads;
+      mc.decoder_opt.num_threads = num_threads;
+      mc.joiner_opt.num_threads = num_threads;
+    }
+  }
+
+  // Use "> 10" like the checks above, so the method is still honoured when
+  // additional trailing arguments are present.
+  if (argc > 10) {
+    const std::string method = argv[10];
+    if (method != "greedy_search" && method != "modified_beam_search") {
+      // The value comes from the command line, so name the CLI argument
+      // rather than the environment variable.
+      fprintf(stderr, "Invalid decode_method=%s\n", method.c_str());
+      return -1;
+    }
+    config->decoder_config.method = method;
+  }
+
+  return 0;
+}
+
+int32_t main(int32_t argc, char **argv) {
+  // Set the default values for config.
+  sherpa_ncnn::RecognizerConfig config;
+  SetDefaultConfigurations(&config);
+
+  // Load and overwrite config from environment variables.
+  std::string input_url;
+  int32_t parsed_required_envs = ParseConfigFromENV(&config, &input_url);
+  if (parsed_required_envs < 0) {
+    exit(-1);
+  }
+
+  // Error if required options come from neither the environment nor the CLI.
+  if (parsed_required_envs < 8 && (argc < 9 || argc > 11)) {
     const char *usage = R"usage(
 Usage:
-  ./bin/sherpa-ncnn-microphone \
+  ./bin/sherpa-ncnn-ffmpeg \
     /path/to/tokens.txt \
     /path/to/encoder.ncnn.param \
     /path/to/encoder.ncnn.bin \
@@ -312,6 +453,23 @@ Usage:
     ffmpeg-input-url \
     [num_threads] [decode_method, can be greedy_search/modified_beam_search]
 
+Or configure by environment variables:
+  SHERPA_NCNN_TOKENS=/path/to/tokens.txt \
+  SHERPA_NCNN_ENCODER_PARAM=/path/to/encoder_jit_trace-pnnx.ncnn.param  \
+  SHERPA_NCNN_ENCODER_BIN=/path/to/encoder_jit_trace-pnnx.ncnn.bin \
+  SHERPA_NCNN_DECODER_PARAM=/path/to/decoder_jit_trace-pnnx.ncnn.param \
+  SHERPA_NCNN_DECODER_BIN=/path/to/decoder_jit_trace-pnnx.ncnn.bin \
+  SHERPA_NCNN_JOINER_PARAM=/path/to/joiner_jit_trace-pnnx.ncnn.param  \
+  SHERPA_NCNN_JOINER_BIN=/path/to/joiner_jit_trace-pnnx.ncnn.bin \
+  SHERPA_NCNN_INPUT_URL=ffmpeg-input-url \
+  SHERPA_NCNN_NUM_THREADS=4 \
+  SHERPA_NCNN_METHOD=greedy_search|modified_beam_search \
+  SHERPA_NCNN_ENABLE_ENDPOINT=on|off \
+  SHERPA_NCNN_RULE1_MIN_TRAILING_SILENCE=2.4 \
+  SHERPA_NCNN_RULE2_MIN_TRAILING_SILENCE=1.2 \
+  SHERPA_NCNN_RULE3_MIN_UTTERANCE_LENGTH=300 \
+  ./bin/sherpa-ncnn-ffmpeg
+
 Please refer to
 https://k2-fsa.github.io/sherpa/ncnn/pretrained_models/index.html
 for a list of pre-trained models to download.
@@ -323,6 +481,17 @@ for a list of pre-trained models to download.
   }
   signal(SIGINT, Handler);
 
+  // Overwrite the config by CLI.
+  if (OverwriteConfigByCLI(argc, argv, &config, &input_url)) {
+    exit(-1);
+  }
+
+  fprintf(stderr, "%s\n", config.ToString().c_str());
+
+  sherpa_ncnn::Recognizer recognizer(config);
+  auto s = recognizer.CreateStream();
+
+  // Initialize FFmpeg framework.
   AVPacket *packet = av_packet_alloc();
   AVFrame *frame = av_frame_alloc();
   AVFrame *filt_frame = av_frame_alloc();
@@ -331,55 +500,18 @@ for a list of pre-trained models to download.
     exit(1);
   }
 
-  sherpa_ncnn::RecognizerConfig config;
-  config.model_config.tokens = argv[1];
-  config.model_config.encoder_param = argv[2];
-  config.model_config.encoder_bin = argv[3];
-  config.model_config.decoder_param = argv[4];
-  config.model_config.decoder_bin = argv[5];
-  config.model_config.joiner_param = argv[6];
-  config.model_config.joiner_bin = argv[7];
-  int32_t num_threads = 4;
-  if (argc >= 9 && atoi(argv[8]) > 0) {
-    num_threads = atoi(argv[8]);
-  }
-  config.model_config.encoder_opt.num_threads = num_threads;
-  config.model_config.decoder_opt.num_threads = num_threads;
-  config.model_config.joiner_opt.num_threads = num_threads;
-
-  const float expected_sampling_rate = 16000;
-  if (argc == 11) {
-    std::string method = argv[10];
-    if (method.compare("greedy_search") ||
-        method.compare("modified_beam_search")) {
-      config.decoder_config.method = method;
-    }
-  }
-
-  config.enable_endpoint = true;
-
-  config.endpoint_config.rule1.min_trailing_silence = 2.4;
-  config.endpoint_config.rule2.min_trailing_silence = 1.2;
-  config.endpoint_config.rule3.min_utterance_length = 300;
-
-  config.feat_config.sampling_rate = expected_sampling_rate;
-  config.feat_config.feature_dim = 80;
-
-  fprintf(stderr, "%s\n", config.ToString().c_str());
-
-  sherpa_ncnn::Recognizer recognizer(config);
-  auto s = recognizer.CreateStream();
-
-  int ret;
-  if ((ret = open_input_file(argv[8])) < 0) {
-    fprintf(stderr, "Open input file %s failed, r0=%d\n", argv[8], ret);
+  int32_t ret;
+  if ((ret = FFmpegOpenInputFile(input_url.c_str())) < 0) {
+    fprintf(stderr, "Open input file %s failed, r0=%d\n", input_url.c_str(),
+            ret);
     exit(1);
   }
 
-  if ((ret = init_filters(filter_descr)) < 0) {
+  if ((ret = FFmpegInitFilters(filter_descr)) < 0) {
     fprintf(stderr, "Init filters %s failed, r0=%d\n", filter_descr, ret);
     exit(1);
   }
+  fprintf(stderr, "Started\n");
 
   std::string last_text;
   int32_t segment_index = 0;
@@ -425,8 +557,8 @@ for a list of pre-trained models to download.
             if (ret < 0) {
               exit(1);
             }
-            sherpa_decode_frame(filt_frame, recognizer, s.get(), display,
-                                last_text, segment_index);
+            FFmpegDecodeFrame(filt_frame, recognizer, s.get(), display,
+                              last_text, segment_index);
             av_frame_unref(filt_frame);
           }
           av_frame_unref(frame);
@@ -436,7 +568,7 @@ for a list of pre-trained models to download.
     av_packet_unref(packet);
   }
 
-  // add some tail padding
+  // Add some tail padding
   float tail_paddings[4800] = {0};  // 0.3 seconds at 16 kHz sample rate
   s->AcceptWaveform(16000, tail_paddings, 4800);
 
@@ -462,7 +594,7 @@ for a list of pre-trained models to download.
   av_frame_free(&filt_frame);
 
   if (ret < 0 && ret != AVERROR_EOF) {
-    fprintf(stderr, "Error occurred: %s\n", __av_err2str(ret));
+    fprintf(stderr, "Error occurred: %s\n", FFmpegAvError2String(ret));
     exit(1);
   }
 

+ 1 - 2
sherpa-ncnn/csrc/sherpa-ncnn-alsa.cc

@@ -103,8 +103,7 @@ as the device_name.
   sherpa_ncnn::DecoderConfig decoder_conf;
   if (argc == 11) {
     std::string method = argv[10];
-    if (method.compare("greedy_search") ||
-        method.compare("modified_beam_search")) {
+    if (method == "greedy_search" || method == "modified_beam_search") {
       decoder_conf.method = method;
     }
   }

+ 1 - 2
sherpa-ncnn/csrc/sherpa-ncnn-microphone.cc

@@ -92,8 +92,7 @@ for a list of pre-trained models to download.
   const float expected_sampling_rate = 16000;
   if (argc == 10) {
     std::string method = argv[9];
-    if (method.compare("greedy_search") ||
-        method.compare("modified_beam_search")) {
+    if (method == "greedy_search" || method == "modified_beam_search") {
       config.decoder_config.method = method;
     }
   }

+ 1 - 2
sherpa-ncnn/csrc/sherpa-ncnn.cc

@@ -68,8 +68,7 @@ for a list of pre-trained models to download.
   float expected_sampling_rate = 16000;
   if (argc == 11) {
     std::string method = argv[10];
-    if (method.compare("greedy_search") ||
-        method.compare("modified_beam_search")) {
+    if (method == "greedy_search" || method == "modified_beam_search") {
       config.decoder_config.method = method;
     }
   }