SherpaNcnn.swift 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230
  1. /// Copyright (c) 2023 Xiaomi Corporation (authors: Fangjun Kuang)
  2. ///
  3. /// See LICENSE for clarification regarding multiple authors
  4. ///
  5. /// Licensed under the Apache License, Version 2.0 (the "License");
  6. /// you may not use this file except in compliance with the License.
  7. /// You may obtain a copy of the License at
  8. ///
  9. /// http://www.apache.org/licenses/LICENSE-2.0
  10. ///
  11. /// Unless required by applicable law or agreed to in writing, software
  12. /// distributed under the License is distributed on an "AS IS" BASIS,
  13. /// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. /// See the License for the specific language governing permissions and
  15. /// limitations under the License.
  16. import Foundation // For NSString
  /// Bridges a Swift `String` to a NUL-terminated UTF-8 `const char *` that can
  /// be handed to the sherpa-ncnn C API.
  ///
  /// The bytes come from `NSString.utf8String` on a temporary `NSString`
  /// created here, so the caller does not own (and must not free) the buffer.
  /// NOTE(review): per the `NSString.utf8String` documentation, the buffer may
  /// not outlive that string object. Consume the pointer promptly (e.g. pass it
  /// straight into a C call), and confirm that the C side copies any string it
  /// retains.
  ///
  /// - Parameter s: The text to convert.
  /// - Returns: A pointer to the UTF-8 bytes of `s`, usable as `const char*`,
  ///   or `nil` if Foundation could not produce a UTF-8 representation.
  func toCPointer(_ s: String) -> UnsafePointer<Int8>! {
      let bridged: NSString = s as NSString
      guard let utf8 = bridged.utf8String else {
          return nil
      }
      return UnsafePointer<Int8>(utf8)
  }
  /// Builds a `SherpaNcnnModelConfig` describing the transducer model files.
  ///
  /// Pretrained models (the `.ncnn.param` / `.ncnn.bin` pairs plus
  /// `tokens.txt`) can be downloaded from
  /// https://k2-fsa.github.io/sherpa/ncnn/pretrained_models/index.html
  ///
  /// - Parameters:
  ///   - encoderParam: Path to `encoder.ncnn.param`.
  ///   - encoderBin: Path to `encoder.ncnn.bin`.
  ///   - decoderParam: Path to `decoder.ncnn.param`.
  ///   - decoderBin: Path to `decoder.ncnn.bin`.
  ///   - joinerParam: Path to `joiner.ncnn.param`.
  ///   - joinerBin: Path to `joiner.ncnn.bin`.
  ///   - tokens: Path to `tokens.txt`.
  ///   - numThreads: Number of threads to use for neural network computation.
  ///     It is converted with `Int32(_:)`, so it must fit in an `Int32`.
  ///   - useVulkanCompute: When `true`, the GPU is used for neural network
  ///     computation, provided sherpa-ncnn was built with Vulkan support and a
  ///     GPU is available. Otherwise the CPU is used.
  /// - Returns: The populated model configuration. Its string fields are
  ///   produced by `toCPointer`.
  func sherpaNcnnModelConfig(
      encoderParam: String,
      encoderBin: String,
      decoderParam: String,
      decoderBin: String,
      joinerParam: String,
      joinerBin: String,
      tokens: String,
      numThreads: Int = 4,
      useVulkanCompute: Bool = true
  ) -> SherpaNcnnModelConfig {
      let threadCount = Int32(numThreads)
      return SherpaNcnnModelConfig(
          encoder_param: toCPointer(encoderParam),
          encoder_bin: toCPointer(encoderBin),
          decoder_param: toCPointer(decoderParam),
          decoder_bin: toCPointer(decoderBin),
          joiner_param: toCPointer(joinerParam),
          joiner_bin: toCPointer(joinerBin),
          tokens: toCPointer(tokens),
          use_vulkan_compute: useVulkanCompute ? 1 : 0,
          num_threads: threadCount
      )
  }
  /// Builds a `SherpaNcnnFeatureExtractorConfig`.
  ///
  /// - Parameters:
  ///   - sampleRate: Stored in the config's `sampling_rate` field. Presumably
  ///     this is the sample rate (Hz) the model's features are computed at, and
  ///     input audio at another rate is resampled to it. Confirm against the
  ///     C API.
  ///   - featureDim: Stored in the config's `feature_dim` field. It is
  ///     converted with `Int32(_:)`, so it must fit in an `Int32`.
  /// - Returns: The populated feature-extractor configuration.
  func sherpaNcnnFeatureExtractorConfig(
      sampleRate: Float,
      featureDim: Int
  ) -> SherpaNcnnFeatureExtractorConfig {
      let dimension = Int32(featureDim)
      return SherpaNcnnFeatureExtractorConfig(sampling_rate: sampleRate, feature_dim: dimension)
  }
  /// Builds a `SherpaNcnnDecoderConfig`.
  ///
  /// - Parameters:
  ///   - decodingMethod: The decoding method. Valid values are
  ///     `"greedy_search"` and `"modified_beam_search"`.
  ///   - numActivePaths: Beam size for beam search. It is used only when
  ///     `decodingMethod` is `"modified_beam_search"`. It is converted with
  ///     `Int32(_:)`, so it must fit in an `Int32`.
  /// - Returns: The populated decoder configuration. Its `decoding_method` is
  ///   produced by `toCPointer`.
  func sherpaNcnnDecoderConfig(
      decodingMethod: String = "greedy_search",
      numActivePaths: Int = 4
  ) -> SherpaNcnnDecoderConfig {
      let method = toCPointer(decodingMethod)
      let beamSize = Int32(numActivePaths)
      return SherpaNcnnDecoderConfig(decoding_method: method, num_active_paths: beamSize)
  }
  /// Builds a `SherpaNcnnRecognizerConfig` from its component configurations
  /// plus the endpoint-detection settings.
  ///
  /// - Parameters:
  ///   - featConfig: Feature-extractor settings.
  ///   - modelConfig: Model file paths and compute settings.
  ///   - decoderConfig: Decoding-method settings.
  ///   - enableEndpoint: `true` to enable endpoint detection, `false` to
  ///     disable it.
  ///   - rule1MinTrailingSilence: An endpoint is detected if the trailing
  ///     silence, in seconds, is larger than this value, even if nothing has
  ///     been decoded. Used only when `enableEndpoint` is `true`.
  ///   - rule2MinTrailingSilence: An endpoint is detected if the trailing
  ///     silence, in seconds, is larger than this value after something that
  ///     is not blank has been decoded. Used only when `enableEndpoint` is
  ///     `true`.
  ///   - rule3MinUtteranceLength: An endpoint is detected if the utterance, in
  ///     seconds, is longer than this value. Used only when `enableEndpoint`
  ///     is `true`.
  /// - Returns: The populated recognizer configuration.
  func sherpaNcnnRecognizerConfig(
      featConfig: SherpaNcnnFeatureExtractorConfig,
      modelConfig: SherpaNcnnModelConfig,
      decoderConfig: SherpaNcnnDecoderConfig,
      enableEndpoint: Bool = false,
      rule1MinTrailingSilence: Float = 2.4,
      rule2MinTrailingSilence: Float = 1.2,
      rule3MinUtteranceLength: Float = 30
  ) -> SherpaNcnnRecognizerConfig {
      SherpaNcnnRecognizerConfig(
          feat_config: featConfig,
          model_config: modelConfig,
          decoder_config: decoderConfig,
          enable_endpoint: enableEndpoint ? 1 : 0,
          rule1_min_trailing_silence: rule1MinTrailingSilence,
          rule2_min_trailing_silence: rule2MinTrailingSilence,
          rule3_min_utterance_length: rule3MinUtteranceLength
      )
  }
  /// Wrapper for a recognition result returned by the C API.
  ///
  /// Usage:
  ///
  ///     let result = recognizer.getResult()
  ///     print("text: \(result.text)")
  ///
  /// The wrapper owns the underlying C result and releases it with
  /// `DestroyResult` when the wrapper is deallocated.
  ///
  /// Note: the type name keeps its historical spelling ("Recongition") so
  /// existing callers continue to compile.
  class SherpaNcnnRecongitionResult {
      /// A pointer to the underlying counterpart in C. It may be `nil`.
      let result: UnsafePointer<SherpaNcnnResult>!

      /// The actual recognition result.
      ///
      /// For English models, it contains words separated by spaces. For
      /// Chinese models, it contains Chinese words. Returns an empty string,
      /// instead of crashing, when there is no underlying result or its `text`
      /// pointer is `NULL`.
      var text: String {
          // `deinit` already treats a nil `result` as legitimate; handle it the
          // same way here rather than force-unwrapping through the IUO.
          guard let result, let cText = result.pointee.text else {
              return ""
          }
          return String(cString: cText)
      }

      /// Wraps `result` and takes ownership of it; it is destroyed in `deinit`.
      init(result: UnsafePointer<SherpaNcnnResult>!) {
          self.result = result
      }

      deinit {
          if let result {
              DestroyResult(result)
          }
      }
  }
  /// Streaming speech recognizer backed by the sherpa-ncnn C API.
  ///
  /// Each instance owns one C recognizer and one stream created from it, and
  /// destroys both (the stream first) when it is deallocated.
  class SherpaNcnnRecognizer {
      /// Handle to the underlying C recognizer.
      let recognizer: OpaquePointer!
      /// Handle to the C stream created from `recognizer`.
      let stream: OpaquePointer!

      // MARK: - Lifecycle

      /// Creates the C recognizer from `config` and opens a stream on it.
      ///
      /// - Parameter config: Pointer to a populated recognizer configuration
      ///   (feature, model, decoder and endpoint settings).
      init(
          config: UnsafePointer<SherpaNcnnRecognizerConfig>!
      ) {
          let handle = CreateRecognizer(config)
          recognizer = handle
          stream = CreateStream(handle)
      }

      deinit {
          // The stream was created from the recognizer, so release it first.
          if let stream { DestroyStream(stream) }
          if let recognizer { DestroyRecognizer(recognizer) }
      }

      // MARK: - Feeding audio

      /// Feeds audio samples to the stream.
      ///
      /// - Parameters:
      ///   - samples: Audio samples normalized to the range [-1, 1].
      ///   - sampleRate: Sample rate of `samples`. If it differs from the
      ///     sample rate in the feature config, the input is resampled.
      ///     Caution: you cannot use a different `sampleRate` across
      ///     different calls to `acceptWaveform`.
      func acceptWaveform(samples: [Float], sampleRate: Float = 16000) {
          let count = Int32(samples.count)
          AcceptWaveform(stream, sampleRate, samples, count)
      }

      /// Signals that no more audio samples will be available.
      /// After this call, you cannot call `acceptWaveform()` any more.
      func inputFinished() {
          InputFinished(stream)
      }

      // MARK: - Decoding

      /// Returns `true` when the C `IsReady` call reports exactly 1.
      func isReady() -> Bool {
          let ready = IsReady(recognizer, stream)
          return ready == 1
      }

      /// If there are enough feature frames, runs the neural network
      /// computation and decoding. Otherwise, it is a no-op.
      func decode() {
          Decode(recognizer, stream)
      }

      /// Returns the decoding results so far.
      ///
      /// The returned wrapper owns the C result and destroys it when released.
      func getResult() -> SherpaNcnnRecongitionResult {
          let raw: UnsafeMutablePointer<SherpaNcnnResult>? = GetResult(recognizer, stream)
          return SherpaNcnnRecongitionResult(result: raw)
      }

      /// Returns `true` if an endpoint has been detected.
      func isEndpoint() -> Bool {
          let detected = IsEndpoint(recognizer, stream)
          return detected == 1
      }

      /// Resets the recognizer, which clears the neural network model state
      /// and the state for decoding.
      func reset() {
          Reset(recognizer, stream)
      }
  }