main.go 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. package main
  import (
  	"bytes"
  	"encoding/binary"
  	"errors"
  	"io"
  	"log"
  	"os"
  	"strings"

  	sherpa "github.com/k2-fsa/sherpa-ncnn-go/sherpa_ncnn"
  	flag "github.com/spf13/pflag"
  	"github.com/youpy/go-wav"
  )
  // main decodes the single wave file named on the command line with a
  // sherpa-ncnn recognizer whose model files are supplied via flags, then logs
  // the lower-cased transcript and the audio duration.
  func main() {
  	log.SetFlags(log.LstdFlags | log.Lmicroseconds)

  	// Feature extraction settings are fixed; everything else comes from flags.
  	cfg := sherpa.RecognizerConfig{
  		Feat: sherpa.FeatureConfig{SampleRate: 16000, FeatureDim: 80},
  	}

  	// Model files and threading.
  	flag.StringVar(&cfg.Model.EncoderParam, "encoder-param", "", "Path to the encoder.ncnn.param")
  	flag.StringVar(&cfg.Model.EncoderBin, "encoder-bin", "", "Path to the encoder.ncnn.bin")
  	flag.StringVar(&cfg.Model.DecoderParam, "decoder-param", "", "Path to the decoder.ncnn.param")
  	flag.StringVar(&cfg.Model.DecoderBin, "decoder-bin", "", "Path to the decoder.ncnn.bin")
  	flag.StringVar(&cfg.Model.JoinerParam, "joiner-param", "", "Path to the joiner.ncnn.param")
  	flag.StringVar(&cfg.Model.JoinerBin, "joiner-bin", "", "Path to the joiner.ncnn.bin")
  	flag.StringVar(&cfg.Model.Tokens, "tokens", "", "Path to the tokens file")
  	flag.IntVar(&cfg.Model.NumThreads, "num-threads", 1, "Number of threads for computing")

  	// Search strategy.
  	flag.StringVar(&cfg.Decoder.DecodingMethod, "decoding-method", "greedy_search", "Decoding method. Possible values: greedy_search, modified_beam_search")
  	flag.IntVar(&cfg.Decoder.NumActivePaths, "num-active-paths", 4, "Used only when --decoding-method is modified_beam_search")
  	flag.Parse()

  	if flag.NArg() != 1 {
  		log.Fatalf("Please provide one wave file")
  	}
  	checkConfig(&cfg)

  	wavePath := flag.Arg(0)
  	log.Println("Reading", wavePath)
  	pcm, rate := readWave(wavePath)

  	log.Println("Initializing recognizer")
  	rec := sherpa.NewRecognizer(&cfg)
  	log.Println("Recognizer created!")
  	defer sherpa.DeleteRecognizer(rec)

  	log.Println("Start decoding!")
  	s := sherpa.NewStream(rec)
  	defer sherpa.DeleteStream(s)

  	s.AcceptWaveform(rate, pcm)
  	// Feed 0.3 s of trailing silence after the real audio; presumably this
  	// lets the model process the final frames of speech (TODO confirm).
  	padLen := int(float32(rate) * 0.3)
  	s.AcceptWaveform(rate, make([]float32, padLen))

  	for rec.IsReady(s) {
  		rec.Decode(s)
  	}
  	log.Println("Decoding done!")

  	text := rec.GetResult(s).Text
  	log.Println(strings.ToLower(text))

  	seconds := float32(len(pcm)) / float32(rate)
  	log.Printf("Wave duration: %v seconds", seconds)
  }
  // readWave loads a mono, 16-bit, uncompressed-PCM wave file and returns its
  // samples (decoded by samplesInt16ToFloat) together with the sample rate
  // declared in the file header.
  //
  // Any failure (file cannot be opened, header unreadable, unsupported format,
  // data shorter than declared) terminates the program via log.Fatalf.
  func readWave(filename string) (samples []float32, sampleRate int) {
  	file, err := os.Open(filename)
  	if err != nil {
  		// Previously ignored: a nil *os.File would have been handed to wav.NewReader.
  		log.Fatalf("Failed to open %v: %v\n", filename, err)
  	}
  	defer file.Close()

  	reader := wav.NewReader(file)
  	format, err := reader.Format()
  	if err != nil {
  		log.Fatalf("Failed to read wave format: %v\n", err)
  	}
  	if format.AudioFormat != 1 {
  		log.Fatalf("Support only PCM format. Given: %v\n", format.AudioFormat)
  	}
  	if format.NumChannels != 1 {
  		log.Fatalf("Support only 1 channel wave file. Given: %v\n", format.NumChannels)
  	}
  	if format.BitsPerSample != 16 {
  		log.Fatalf("Support only 16-bit per sample. Given: %v\n", format.BitsPerSample)
  	}

  	// Duration() is called only for its side effect of initializing
  	// reader.Size. If it fails, reader.Size cannot be trusted, so the error
  	// must be checked rather than dropped.
  	if _, err := reader.Duration(); err != nil {
  		log.Fatalf("Failed to read wave data: %v\n", err)
  	}

  	buf := make([]byte, reader.Size)
  	// io.Reader permits a single Read to return fewer bytes than requested.
  	// ReadFull keeps reading until buf is full and reports why it stopped
  	// otherwise.
  	n, err := io.ReadFull(reader, buf)
  	if err != nil {
  		log.Fatalf("Failed to read %v bytes. Returned %v bytes: %v\n", reader.Size, n, err)
  	}

  	return samplesInt16ToFloat(buf), int(format.SampleRate)
  }
  79. func samplesInt16ToFloat(inSamples []byte) []float32 {
  80. numSamples := len(inSamples) / 2
  81. outSamples := make([]float32, numSamples)
  82. for i := 0; i != numSamples; i++ {
  83. s := inSamples[i*2 : (i+1)*2]
  84. var s16 int16
  85. buf := bytes.NewReader(s)
  86. err := binary.Read(buf, binary.LittleEndian, &s16)
  87. if err != nil {
  88. log.Fatal("Failed to parse 16-bit sample")
  89. }
  90. outSamples[i] = float32(s16) / 32768
  91. }
  92. return outSamples
  93. }
  // checkConfig verifies, in a fixed order (encoder, decoder, joiner, tokens),
  // that each model-file flag was given a non-empty path and that the path can
  // be stat'ed. The first problem found terminates the program via log.Fatal.
  func checkConfig(config *sherpa.RecognizerConfig) {
  	// Order matters: it determines which error the user sees first.
  	required := []struct {
  		name string
  		path string
  	}{
  		{"encoder-param", config.Model.EncoderParam},
  		{"encoder-bin", config.Model.EncoderBin},
  		{"decoder-param", config.Model.DecoderParam},
  		{"decoder-bin", config.Model.DecoderBin},
  		{"joiner-param", config.Model.JoinerParam},
  		{"joiner-bin", config.Model.JoinerBin},
  		{"tokens", config.Model.Tokens},
  	}

  	for _, r := range required {
  		if r.path == "" {
  			log.Fatal("Please provide --" + r.name)
  		}
  		_, err := os.Stat(r.path)
  		switch {
  		case errors.Is(err, os.ErrNotExist):
  			log.Fatalf("--%s %v does not exist", r.name, r.path)
  		case err != nil:
  			// Previously ignored. A path that exists but cannot be accessed
  			// (e.g. permission denied) is reported here instead of being
  			// accepted silently.
  			log.Fatalf("--%s %v cannot be accessed: %v", r.name, r.path, err)
  		}
  	}
  }