decode-file-c-api.c 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. /**
  2. * Copyright (c) 2023 Xiaomi Corporation (authors: Fangjun Kuang)
  3. *
  4. * See LICENSE for clarification regarding multiple authors
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
#include <errno.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "sherpa-ncnn/c-api/c-api.h"
  /* Command-line help printed by main() when the arguments are invalid.
   * The optional trailing arguments are positional: each one may only be
   * given together with all of the optional arguments listed before it. */
  const char *kUsage =
      "\n"
      "Usage:\n"
      " ./bin/decode-file-c-api \\\n"
      " /path/to/tokens.txt \\\n"
      " /path/to/encoder.ncnn.param \\\n"
      " /path/to/encoder.ncnn.bin \\\n"
      " /path/to/decoder.ncnn.param \\\n"
      " /path/to/decoder.ncnn.bin \\\n"
      " /path/to/joiner.ncnn.param \\\n"
      " /path/to/joiner.ncnn.bin \\\n"
      " /path/to/foo.wav [<num_threads> [<decode_method> [<hotwords_file> "
      "[<hotwords_score>]]]]\n"
      "\n"
      "  num_threads: number of threads (default: 4)\n"
      "  decode_method: greedy_search (default) or modified_beam_search\n"
      "\n"
      "Please refer to \n"
      "https://k2-fsa.github.io/sherpa/ncnn/pretrained_models/index.html\n"
      "for a list of pre-trained models to download.\n";
  39. int32_t main(int32_t argc, char *argv[]) {
  40. if (argc < 9 || argc > 13) {
  41. fprintf(stderr, "%s\n", kUsage);
  42. return -1;
  43. }
  44. SherpaNcnnRecognizerConfig config;
  45. memset(&config, 0, sizeof(config));
  46. config.model_config.tokens = argv[1];
  47. config.model_config.encoder_param = argv[2];
  48. config.model_config.encoder_bin = argv[3];
  49. config.model_config.decoder_param = argv[4];
  50. config.model_config.decoder_bin = argv[5];
  51. config.model_config.joiner_param = argv[6];
  52. config.model_config.joiner_bin = argv[7];
  53. int32_t num_threads = 4;
  54. if (argc >= 10 && atoi(argv[9]) > 0) {
  55. num_threads = atoi(argv[9]);
  56. }
  57. config.model_config.num_threads = num_threads;
  58. config.model_config.use_vulkan_compute = 0;
  59. config.decoder_config.decoding_method = "greedy_search";
  60. if (argc >= 11) {
  61. config.decoder_config.decoding_method = argv[10];
  62. }
  63. config.decoder_config.num_active_paths = 4;
  64. config.enable_endpoint = 0;
  65. config.rule1_min_trailing_silence = 2.4;
  66. config.rule2_min_trailing_silence = 1.2;
  67. config.rule3_min_utterance_length = 300;
  68. config.feat_config.sampling_rate = 16000;
  69. config.feat_config.feature_dim = 80;
  70. if (argc >= 12) {
  71. config.hotwords_file = argv[11];
  72. }
  73. if (argc == 13) {
  74. config.hotwords_score = atof(argv[12]);
  75. }
  76. SherpaNcnnRecognizer *recognizer = CreateRecognizer(&config);
  77. const char *wav_filename = argv[8];
  78. FILE *fp = fopen(wav_filename, "rb");
  79. if (!fp) {
  80. fprintf(stderr, "Failed to open %s\n", wav_filename);
  81. return -1;
  82. }
  83. // Assume the wave header occupies 44 bytes.
  84. fseek(fp, 44, SEEK_SET);
  85. // simulate streaming
  86. #define N 3200 // 0.2 s. Sample rate is fixed to 16 kHz
  87. int16_t buffer[N];
  88. float samples[N];
  89. SherpaNcnnStream *s = CreateStream(recognizer);
  90. SherpaNcnnDisplay *display = CreateDisplay(50);
  91. int32_t segment_id = -1;
  92. while (!feof(fp)) {
  93. size_t n = fread((void *)buffer, sizeof(int16_t), N, fp);
  94. if (n > 0) {
  95. for (size_t i = 0; i != n; ++i) {
  96. samples[i] = buffer[i] / 32768.;
  97. }
  98. AcceptWaveform(s, 16000, samples, n);
  99. while (IsReady(recognizer, s)) {
  100. Decode(recognizer, s);
  101. }
  102. SherpaNcnnResult *r = GetResult(recognizer, s);
  103. if (strlen(r->text)) {
  104. SherpaNcnnPrint(display, segment_id, r->text);
  105. }
  106. DestroyResult(r);
  107. }
  108. }
  109. fclose(fp);
  110. // add some tail padding
  111. float tail_paddings[4800] = {0}; // 0.3 seconds at 16 kHz sample rate
  112. AcceptWaveform(s, 16000, tail_paddings, 4800);
  113. InputFinished(s);
  114. while (IsReady(recognizer, s)) {
  115. Decode(recognizer, s);
  116. }
  117. SherpaNcnnResult *r = GetResult(recognizer, s);
  118. if (strlen(r->text)) {
  119. SherpaNcnnPrint(display, segment_id, r->text);
  120. }
  121. DestroyResult(r);
  122. DestroyDisplay(display);
  123. DestroyStream(s);
  124. DestroyRecognizer(recognizer);
  125. fprintf(stderr, "\n");
  126. return 0;
  127. }