decode-file.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. #!/usr/bin/env python3
  2. """
  3. This file demonstrates how to use sherpa-ncnn Python API to recognize
  4. a single file.
  5. Please refer to
  6. https://k2-fsa.github.io/sherpa/ncnn/index.html
  7. to install sherpa-ncnn and to download the pre-trained models
  8. used in this file.
  9. """
  10. import wave
  11. import numpy as np
  12. import sherpa_ncnn
  13. def main():
  14. recognizer = sherpa_ncnn.Recognizer(
  15. tokens="./sherpa-ncnn-conv-emformer-transducer-2022-12-06/tokens.txt",
  16. encoder_param="./sherpa-ncnn-conv-emformer-transducer-2022-12-06/encoder_jit_trace-pnnx.ncnn.param",
  17. encoder_bin="./sherpa-ncnn-conv-emformer-transducer-2022-12-06/encoder_jit_trace-pnnx.ncnn.bin",
  18. decoder_param="./sherpa-ncnn-conv-emformer-transducer-2022-12-06/decoder_jit_trace-pnnx.ncnn.param",
  19. decoder_bin="./sherpa-ncnn-conv-emformer-transducer-2022-12-06/decoder_jit_trace-pnnx.ncnn.bin",
  20. joiner_param="./sherpa-ncnn-conv-emformer-transducer-2022-12-06/joiner_jit_trace-pnnx.ncnn.param",
  21. joiner_bin="./sherpa-ncnn-conv-emformer-transducer-2022-12-06/joiner_jit_trace-pnnx.ncnn.bin",
  22. num_threads=4,
  23. )
  24. filename = "./sherpa-ncnn-conv-emformer-transducer-2022-12-06/test_wavs/1.wav"
  25. with wave.open(filename) as f:
  26. assert f.getframerate() == recognizer.sample_rate, (
  27. f.getframerate(),
  28. recognizer.sample_rate,
  29. )
  30. assert f.getnchannels() == 1, f.getnchannels()
  31. assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes
  32. num_samples = f.getnframes()
  33. samples = f.readframes(num_samples)
  34. samples_int16 = np.frombuffer(samples, dtype=np.int16)
  35. samples_float32 = samples_int16.astype(np.float32)
  36. samples_float32 = samples_float32 / 32768
  37. recognizer.accept_waveform(recognizer.sample_rate, samples_float32)
  38. tail_paddings = np.zeros(int(recognizer.sample_rate * 0.5), dtype=np.float32)
  39. recognizer.accept_waveform(recognizer.sample_rate, tail_paddings)
  40. recognizer.input_finished()
  41. print(recognizer.text)
  42. if __name__ == "__main__":
  43. main()