test-recognizer.py 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647
  1. #!/usr/bin/env python3
  2. import wave
  3. import numpy as np
  4. import sherpa_ncnn
  5. def main():
  6. recognizer = sherpa_ncnn.Recognizer(
  7. tokens="./sherpa-ncnn-conv-emformer-transducer-2022-12-06/tokens.txt",
  8. encoder_param="./sherpa-ncnn-conv-emformer-transducer-2022-12-06/encoder_jit_trace-pnnx.ncnn.param",
  9. encoder_bin="./sherpa-ncnn-conv-emformer-transducer-2022-12-06/encoder_jit_trace-pnnx.ncnn.bin",
  10. decoder_param="./sherpa-ncnn-conv-emformer-transducer-2022-12-06/decoder_jit_trace-pnnx.ncnn.param",
  11. decoder_bin="./sherpa-ncnn-conv-emformer-transducer-2022-12-06/decoder_jit_trace-pnnx.ncnn.bin",
  12. joiner_param="./sherpa-ncnn-conv-emformer-transducer-2022-12-06/joiner_jit_trace-pnnx.ncnn.param",
  13. joiner_bin="./sherpa-ncnn-conv-emformer-transducer-2022-12-06/joiner_jit_trace-pnnx.ncnn.bin",
  14. num_threads=4,
  15. )
  16. filename = "./sherpa-ncnn-conv-emformer-transducer-2022-12-06/test_wavs/1.wav"
  17. with wave.open(filename) as f:
  18. assert f.getframerate() == recognizer.sample_rate, (
  19. f.getframerate(),
  20. recognizer.sample_rate,
  21. )
  22. assert f.getnchannels() == 1, f.getnchannels()
  23. assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes
  24. num_samples = f.getnframes()
  25. samples = f.readframes(num_samples)
  26. samples_int16 = np.frombuffer(samples, dtype=np.int16)
  27. samples_float32 = samples_int16.astype(np.float32)
  28. samples_float32 = samples_float32 / 32768
  29. recognizer.accept_waveform(recognizer.sample_rate, samples_float32)
  30. tail_paddings = np.zeros(int(recognizer.sample_rate * 0.5), dtype=np.float32)
  31. recognizer.accept_waveform(recognizer.sample_rate, tail_paddings)
  32. recognizer.input_finished()
  33. print(recognizer.text)
  34. if __name__ == "__main__":
  35. main()