decode-file-c-api.c 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. /**
  2. * Copyright (c) 2023 Xiaomi Corporation (authors: Fangjun Kuang)
  3. *
  4. * See LICENSE for clarification regarding multiple authors
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #include <stdio.h>
  19. #include <stdlib.h>
  20. #include <string.h>
  21. #include "sherpa-ncnn/c-api/c-api.h"
  22. const char *kUsage =
  23. "\n"
  24. "Usage:\n"
  25. " ./bin/decode-file-c-api \\\n"
  26. " /path/to/tokens.txt \\\n"
  27. " /path/to/encoder.ncnn.param \\\n"
  28. " /path/to/encoder.ncnn.bin \\\n"
  29. " /path/to/decoder.ncnn.param \\\n"
  30. " /path/to/decoder.ncnn.bin \\\n"
  31. " /path/to/joiner.ncnn.param \\\n"
  32. " /path/to/joiner.ncnn.bin \\\n"
  33. " /path/to/foo.wav [<num_threads> [decode_method, can be "
  34. "greedy_search/modified_beam_search]]"
  35. "\n\n"
  36. "Please refer to \n"
  37. "https://k2-fsa.github.io/sherpa/ncnn/pretrained_models/index.html\n"
  38. "for a list of pre-trained models to download.\n";
  39. int32_t main(int32_t argc, char *argv[]) {
  40. if (argc < 9 || argc > 11) {
  41. fprintf(stderr, "%s\n", kUsage);
  42. return -1;
  43. }
  44. SherpaNcnnRecognizerConfig config;
  45. config.model_config.tokens = argv[1];
  46. config.model_config.encoder_param = argv[2];
  47. config.model_config.encoder_bin = argv[3];
  48. config.model_config.decoder_param = argv[4];
  49. config.model_config.decoder_bin = argv[5];
  50. config.model_config.joiner_param = argv[6];
  51. config.model_config.joiner_bin = argv[7];
  52. int32_t num_threads = 4;
  53. if (argc >= 10 && atoi(argv[9]) > 0) {
  54. num_threads = atoi(argv[9]);
  55. }
  56. config.model_config.num_threads = num_threads;
  57. config.model_config.use_vulkan_compute = 0;
  58. config.decoder_config.decoding_method = "greedy_search";
  59. if (argc == 11) {
  60. config.decoder_config.decoding_method = argv[10];
  61. }
  62. config.decoder_config.num_active_paths = 4;
  63. config.enable_endpoint = 0;
  64. config.rule1_min_trailing_silence = 2.4;
  65. config.rule2_min_trailing_silence = 1.2;
  66. config.rule3_min_utterance_length = 300;
  67. config.feat_config.sampling_rate = 16000;
  68. config.feat_config.feature_dim = 80;
  69. SherpaNcnnRecognizer *recognizer = CreateRecognizer(&config);
  70. const char *wav_filename = argv[8];
  71. FILE *fp = fopen(wav_filename, "rb");
  72. if (!fp) {
  73. fprintf(stderr, "Failed to open %s\n", wav_filename);
  74. return -1;
  75. }
  76. // Assume the wave header occupies 44 bytes.
  77. fseek(fp, 44, SEEK_SET);
  78. // simulate streaming
  79. #define N 3200 // 0.2 s. Sample rate is fixed to 16 kHz
  80. int16_t buffer[N];
  81. float samples[N];
  82. SherpaNcnnStream *s = CreateStream(recognizer);
  83. SherpaNcnnDisplay *display = CreateDisplay(50);
  84. int32_t segment_id = -1;
  85. while (!feof(fp)) {
  86. size_t n = fread((void *)buffer, sizeof(int16_t), N, fp);
  87. if (n > 0) {
  88. for (size_t i = 0; i != n; ++i) {
  89. samples[i] = buffer[i] / 32768.;
  90. }
  91. AcceptWaveform(s, 16000, samples, n);
  92. while (IsReady(recognizer, s)) {
  93. Decode(recognizer, s);
  94. }
  95. SherpaNcnnResult *r = GetResult(recognizer, s);
  96. if (strlen(r->text)) {
  97. SherpaNcnnPrint(display, segment_id, r->text);
  98. }
  99. DestroyResult(r);
  100. }
  101. }
  102. fclose(fp);
  103. // add some tail padding
  104. float tail_paddings[4800] = {0}; // 0.3 seconds at 16 kHz sample rate
  105. AcceptWaveform(s, 16000, tail_paddings, 4800);
  106. InputFinished(s);
  107. while (IsReady(recognizer, s)) {
  108. Decode(recognizer, s);
  109. }
  110. SherpaNcnnResult *r = GetResult(recognizer, s);
  111. if (strlen(r->text)) {
  112. SherpaNcnnPrint(display, segment_id, r->text);
  113. }
  114. DestroyResult(r);
  115. DestroyDisplay(display);
  116. DestroyStream(s);
  117. DestroyRecognizer(recognizer);
  118. fprintf(stderr, "\n");
  119. return 0;
  120. }