SherpaNcnn.swift 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  1. /// Copyright (c) 2023 Xiaomi Corporation (authors: Fangjun Kuang)
  2. ///
  3. /// See LICENSE for clarification regarding multiple authors
  4. ///
  5. /// Licensed under the Apache License, Version 2.0 (the "License");
  6. /// you may not use this file except in compliance with the License.
  7. /// You may obtain a copy of the License at
  8. ///
  9. /// http://www.apache.org/licenses/LICENSE-2.0
  10. ///
  11. /// Unless required by applicable law or agreed to in writing, software
  12. /// distributed under the License is distributed on an "AS IS" BASIS,
  13. /// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. /// See the License for the specific language governing permissions and
  15. /// limitations under the License.
  16. import Foundation // For NSString
  /// Converts a Swift `String` into a C string (`const char*`) so it can be
  /// handed to the sherpa-ncnn C API.
  ///
  /// The bytes are obtained through `NSString.utf8String`.
  /// NOTE(review): the pointer's storage is managed by Foundation, not by the
  /// caller, and may not outlive the temporary bridged `NSString` for long —
  /// confirm the C side copies the string before it is needed again.
  ///
  /// - Parameter s: The string to convert.
  /// - Returns: A pointer usable as `const char*`, or `nil` when Foundation
  ///   cannot produce a UTF-8 representation.
  func toCPointer(_ s: String) -> UnsafePointer<Int8>! {
    let bridged = s as NSString
    guard let utf8Bytes = bridged.utf8String else {
      return nil
    }
    return UnsafePointer<Int8>(utf8Bytes)
  }
  /// Builds a `SherpaNcnnModelConfig` that points at the transducer model files.
  ///
  /// Pre-trained models that provide the required `.ncnn.param` and
  /// `.ncnn.bin` files can be downloaded from
  /// https://k2-fsa.github.io/sherpa/ncnn/pretrained_models/index.html
  ///
  /// Every path is converted with `toCPointer(_:)`, so the returned config
  /// stores C string pointers rather than Swift strings.
  ///
  /// - Parameters:
  ///   - encoderParam: Path to `encoder.ncnn.param`.
  ///   - encoderBin: Path to `encoder.ncnn.bin`.
  ///   - decoderParam: Path to `decoder.ncnn.param`.
  ///   - decoderBin: Path to `decoder.ncnn.bin`.
  ///   - joinerParam: Path to `joiner.ncnn.param`.
  ///   - joinerBin: Path to `joiner.ncnn.bin`.
  ///   - tokens: Path to `tokens.txt`.
  ///   - numThreads: Number of threads to use for neural network computation.
  ///     Defaults to 4. Converted with `Int32(_:)`, which traps if the value
  ///     does not fit in 32 bits.
  ///   - useVulkanCompute: If `true`, and sherpa-ncnn is compiled with Vulkan
  ///     support, and a GPU is available, the GPU is used for neural network
  ///     computation; otherwise the CPU is used. Defaults to `true`.
  /// - Returns: The populated `SherpaNcnnModelConfig`.
  func sherpaNcnnModelConfig(
    encoderParam: String,
    encoderBin: String,
    decoderParam: String,
    decoderBin: String,
    joinerParam: String,
    joinerBin: String,
    tokens: String,
    numThreads: Int = 4,
    useVulkanCompute: Bool = true
  ) -> SherpaNcnnModelConfig {
    SherpaNcnnModelConfig(
      encoder_param: toCPointer(encoderParam),
      encoder_bin: toCPointer(encoderBin),
      decoder_param: toCPointer(decoderParam),
      decoder_bin: toCPointer(decoderBin),
      joiner_param: toCPointer(joinerParam),
      joiner_bin: toCPointer(joinerBin),
      tokens: toCPointer(tokens),
      use_vulkan_compute: useVulkanCompute ? 1 : 0,
      num_threads: Int32(numThreads)
    )
  }
  /// Builds a `SherpaNcnnDecoderConfig` that controls decoding and endpointing.
  ///
  /// - Parameters:
  ///   - decodingMethod: Valid values are "greedy_search" (the default) and
  ///     "modified_beam_search". Converted with `toCPointer(_:)`.
  ///   - numActivePaths: Beam size for beam search; used only when
  ///     `decodingMethod` is "modified_beam_search". Defaults to 4.
  ///   - enableEndpoint: `true` to enable endpoint detection, `false` to
  ///     disable it. Defaults to `false`.
  ///   - rule1MinTrailingSilence: An endpoint is detected if the trailing
  ///     silence, in seconds, is larger than this value even if nothing has
  ///     been decoded. Used only when `enableEndpoint` is `true`.
  ///     Defaults to 2.4.
  ///   - rule2MinTrailingSilence: An endpoint is detected if the trailing
  ///     silence, in seconds, is larger than this value after something that
  ///     is not blank has been decoded. Used only when `enableEndpoint` is
  ///     `true`. Defaults to 1.2.
  ///   - rule3MinUtteranceLength: An endpoint is detected if the utterance,
  ///     in seconds, is larger than this value. Used only when
  ///     `enableEndpoint` is `true`. Defaults to 30.
  /// - Returns: The populated `SherpaNcnnDecoderConfig`.
  func sherpaNcnnDecoderConfig(
    decodingMethod: String = "greedy_search",
    numActivePaths: Int = 4,
    enableEndpoint: Bool = false,
    rule1MinTrailingSilence: Float = 2.4,
    rule2MinTrailingSilence: Float = 1.2,
    rule3MinUtteranceLength: Float = 30
  ) -> SherpaNcnnDecoderConfig {
    SherpaNcnnDecoderConfig(
      decoding_method: toCPointer(decodingMethod),
      num_active_paths: Int32(numActivePaths),
      enable_endpoint: enableEndpoint ? 1 : 0,
      rule1_min_trailing_silence: rule1MinTrailingSilence,
      rule2_min_trailing_silence: rule2MinTrailingSilence,
      rule3_min_utterance_length: rule3MinUtteranceLength
    )
  }
  /// Owns one recognition result obtained from the sherpa-ncnn C API.
  ///
  /// The wrapped C result is released with `DestroyResult` when this object is
  /// deinitialized, so keep the wrapper alive while reading `text`.
  ///
  /// Usage:
  ///
  ///     let result = recognizer.getResult()
  ///     print("text: \(result.text)")
  ///
  class SherpaNcnnRecongitionResult {
    /// A pointer to the underlying counterpart in C. May be `nil`.
    let result: UnsafePointer<SherpaNcnnResult>!

    /// The actual recognition result.
    ///
    /// For English models, it contains words separated by spaces.
    /// For Chinese models, it contains Chinese words.
    ///
    /// Returns an empty string when there is no underlying result or its text
    /// pointer is null. This matches `deinit`, which already treats a `nil`
    /// result as a valid state, instead of crashing on an implicit unwrap.
    var text: String {
      guard let result, let cText = result.pointee.text else {
        return ""
      }
      return String(cString: cText)
    }

    /// Wraps `result` and takes ownership of it; it is freed with
    /// `DestroyResult` in `deinit`.
    init(result: UnsafePointer<SherpaNcnnResult>!) {
      self.result = result
    }

    deinit {
      if let result {
        DestroyResult(result)
      }
    }
  }
  /// Swift wrapper around a sherpa-ncnn recognizer from the C API.
  ///
  /// The C recognizer is created in `init` and destroyed in `deinit`.
  class SherpaNcnnRecognizer {
    /// A pointer to the underlying counterpart in C.
    /// NOTE(review): presumably `nil` if `CreateRecognizer` fails; the methods
    /// below pass it to the C API unchecked — confirm against the C API.
    let recognizer: OpaquePointer!

    /// Creates a recognizer from a model config and a decoder config.
    ///
    /// - Parameters:
    ///   - modelConfig: Model files and compute settings.
    ///   - decoderConfig: Decoding method and endpointing settings.
    init(
      modelConfig: UnsafePointer<SherpaNcnnModelConfig>!,
      decoderConfig: UnsafePointer<SherpaNcnnDecoderConfig>!
    ) {
      recognizer = CreateRecognizer(modelConfig, decoderConfig)
    }

    deinit {
      if let recognizer {
        DestroyRecognizer(recognizer)
      }
    }

    /// Feeds audio samples to the recognizer.
    ///
    /// - Parameters:
    ///   - samples: Audio samples normalized to the range [-1, 1].
    ///   - sampleRate: Sample rate of `samples`. It must match the one expected
    ///     by the model; it must be 16000 for models from icefall.
    func acceptWaveform(samples: [Float], sampleRate: Float = 16000) {
      AcceptWaveform(recognizer, sampleRate, samples, Int32(samples.count))
    }

    /// Runs neural network computation and decoding if there are enough
    /// feature frames; otherwise it is a no-op.
    func decode() {
      Decode(recognizer)
    }

    /// Returns the decoding results so far, wrapped so the C result is freed
    /// when the wrapper is deinitialized.
    func getResult() -> SherpaNcnnRecongitionResult {
      SherpaNcnnRecongitionResult(result: GetResult(recognizer))
    }

    /// Resets the recognizer, clearing the neural network model state and the
    /// decoding state.
    func reset() {
      Reset(recognizer)
    }

    /// Signals that no more audio samples will be provided.
    /// After this call, `acceptWaveform(samples:sampleRate:)` must not be
    /// called any more.
    func inputFinished() {
      InputFinished(recognizer)
    }

    /// Returns `true` if an endpoint has been detected.
    ///
    /// Any non-zero value from the C function `IsEndpoint` counts as `true`,
    /// following the usual C convention for integer booleans.
    func isEndpoint() -> Bool {
      IsEndpoint(recognizer) != 0
    }
  }