  // decode-file-c-api.c
  /**
   * Copyright (c) 2023 Xiaomi Corporation (authors: Fangjun Kuang)
   *
   * See LICENSE for clarification regarding multiple authors
   *
   * Licensed under the Apache License, Version 2.0 (the "License");
   * you may not use this file except in compliance with the License.
   * You may obtain a copy of the License at
   *
   *     http://www.apache.org/licenses/LICENSE-2.0
   *
   * Unless required by applicable law or agreed to in writing, software
   * distributed under the License is distributed on an "AS IS" BASIS,
   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   * See the License for the specific language governing permissions and
   * limitations under the License.
   */
  #include <errno.h>
  #include <stdint.h>
  #include <stdio.h>
  #include <stdlib.h>
  #include <string.h>

  #include "sherpa-ncnn/c-api/c-api.h"
  // Command-line help, printed to stderr when the argument count is outside
  // the accepted range (8 required arguments plus up to 4 optional ones).
  const char *kUsage =
      "\n"
      "Usage:\n"
      " ./bin/decode-file-c-api \\\n"
      " /path/to/tokens.txt \\\n"
      " /path/to/encoder.ncnn.param \\\n"
      " /path/to/encoder.ncnn.bin \\\n"
      " /path/to/decoder.ncnn.param \\\n"
      " /path/to/decoder.ncnn.bin \\\n"
      " /path/to/joiner.ncnn.param \\\n"
      " /path/to/joiner.ncnn.bin \\\n"
      " /path/to/foo.wav \\\n"
      " [<num_threads> [<decode_method> [<hotwords_file> [<hotwords_score>]]]]\n"
      "\n"
      "Optional arguments (positional, in this order):\n"
      "  num_threads     Number of threads; invalid or non-positive values "
      "fall back to 4\n"
      "  decode_method   greedy_search or modified_beam_search "
      "(default: greedy_search)\n"
      "  hotwords_file   Path to a hotwords file (default: none)\n"
      "  hotwords_score  Hotwords score (default: 1.5)\n"
      "\n"
      "Please refer to \n"
      "https://k2-fsa.github.io/sherpa/ncnn/pretrained_models/index.html\n"
      "for a list of pre-trained models to download.\n";
  39. int32_t main(int32_t argc, char *argv[]) {
  40. if (argc < 9 || argc > 13) {
  41. fprintf(stderr, "%s\n", kUsage);
  42. return -1;
  43. }
  44. SherpaNcnnRecognizerConfig config;
  45. config.model_config.tokens = argv[1];
  46. config.model_config.encoder_param = argv[2];
  47. config.model_config.encoder_bin = argv[3];
  48. config.model_config.decoder_param = argv[4];
  49. config.model_config.decoder_bin = argv[5];
  50. config.model_config.joiner_param = argv[6];
  51. config.model_config.joiner_bin = argv[7];
  52. int32_t num_threads = 4;
  53. if (argc >= 10 && atoi(argv[9]) > 0) {
  54. num_threads = atoi(argv[9]);
  55. }
  56. config.model_config.num_threads = num_threads;
  57. config.model_config.use_vulkan_compute = 0;
  58. config.decoder_config.decoding_method = "greedy_search";
  59. if (argc >= 11) {
  60. config.decoder_config.decoding_method = argv[10];
  61. }
  62. config.decoder_config.num_active_paths = 4;
  63. config.enable_endpoint = 0;
  64. config.rule1_min_trailing_silence = 2.4;
  65. config.rule2_min_trailing_silence = 1.2;
  66. config.rule3_min_utterance_length = 300;
  67. config.feat_config.sampling_rate = 16000;
  68. config.feat_config.feature_dim = 80;
  69. if(argc >= 12) {
  70. config.hotwords_file = argv[11];
  71. } else {
  72. config.hotwords_file = "";
  73. }
  74. if(argc == 13) {
  75. config.hotwords_score = atof(argv[12]);
  76. } else {
  77. config.hotwords_score = 1.5;
  78. }
  79. SherpaNcnnRecognizer *recognizer = CreateRecognizer(&config);
  80. const char *wav_filename = argv[8];
  81. FILE *fp = fopen(wav_filename, "rb");
  82. if (!fp) {
  83. fprintf(stderr, "Failed to open %s\n", wav_filename);
  84. return -1;
  85. }
  86. // Assume the wave header occupies 44 bytes.
  87. fseek(fp, 44, SEEK_SET);
  88. // simulate streaming
  89. #define N 3200 // 0.2 s. Sample rate is fixed to 16 kHz
  90. int16_t buffer[N];
  91. float samples[N];
  92. SherpaNcnnStream *s = CreateStream(recognizer);
  93. SherpaNcnnDisplay *display = CreateDisplay(50);
  94. int32_t segment_id = -1;
  95. while (!feof(fp)) {
  96. size_t n = fread((void *)buffer, sizeof(int16_t), N, fp);
  97. if (n > 0) {
  98. for (size_t i = 0; i != n; ++i) {
  99. samples[i] = buffer[i] / 32768.;
  100. }
  101. AcceptWaveform(s, 16000, samples, n);
  102. while (IsReady(recognizer, s)) {
  103. Decode(recognizer, s);
  104. }
  105. SherpaNcnnResult *r = GetResult(recognizer, s);
  106. if (strlen(r->text)) {
  107. SherpaNcnnPrint(display, segment_id, r->text);
  108. }
  109. DestroyResult(r);
  110. }
  111. }
  112. fclose(fp);
  113. // add some tail padding
  114. float tail_paddings[4800] = {0}; // 0.3 seconds at 16 kHz sample rate
  115. AcceptWaveform(s, 16000, tail_paddings, 4800);
  116. InputFinished(s);
  117. while (IsReady(recognizer, s)) {
  118. Decode(recognizer, s);
  119. }
  120. SherpaNcnnResult *r = GetResult(recognizer, s);
  121. if (strlen(r->text)) {
  122. SherpaNcnnPrint(display, segment_id, r->text);
  123. }
  124. DestroyResult(r);
  125. DestroyDisplay(display);
  126. DestroyStream(s);
  127. DestroyRecognizer(recognizer);
  128. fprintf(stderr, "\n");
  129. return 0;
  130. }