SherpaNcnn.swift 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234
  1. /// Copyright (c) 2023 Xiaomi Corporation (authors: Fangjun Kuang)
  2. ///
  3. /// See LICENSE for clarification regarding multiple authors
  4. ///
  5. /// Licensed under the Apache License, Version 2.0 (the "License");
  6. /// you may not use this file except in compliance with the License.
  7. /// You may obtain a copy of the License at
  8. ///
  9. /// http://www.apache.org/licenses/LICENSE-2.0
  10. ///
  11. /// Unless required by applicable law or agreed to in writing, software
  12. /// distributed under the License is distributed on an "AS IS" BASIS,
  13. /// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. /// See the License for the specific language governing permissions and
  15. /// limitations under the License.
  16. import Foundation // For NSString
  /// Convert a String from Swift to a `const char*` so that we can pass it to
  /// the C language.
  ///
  /// The UTF-8 bytes of `s` are copied into a fresh heap buffer with `strdup`,
  /// so the returned pointer stays valid after this function returns and can
  /// be stored in the C config structs built in this file. We deliberately do
  /// not return `(s as NSString).utf8String`: Foundation documents that buffer
  /// as living no longer than the (here temporary) NSString it came from.
  ///
  /// Nothing in this file frees the copy; config strings are small and are
  /// typically created once, so the tiny leak is accepted in exchange for a
  /// pointer that cannot dangle.
  ///
  /// - Parameters:
  ///   - s: The String to convert.
  /// - Returns: A NUL-terminated UTF-8 copy of `s` that can be passed to C as
  ///            `const char*`, or nil if `strdup` fails to allocate.
  func toCPointer(_ s: String) -> UnsafePointer<Int8>! {
    // Swift bridges `s` to a temporary `const char*` valid only for this call;
    // strdup copies it into memory whose lifetime we control.
    guard let copy = strdup(s) else {
      return nil
    }
    return UnsafePointer<Int8>(copy)
  }
  /// Return an instance of SherpaNcnnModelConfig.
  ///
  /// Please refer to
  /// https://k2-fsa.github.io/sherpa/ncnn/pretrained_models/index.html
  /// to download the required `.ncnn.param` and `.ncnn.bin` files.
  ///
  /// Every path argument is converted to a C string with `toCPointer` before
  /// being stored in the returned struct.
  ///
  /// - Parameters:
  ///   - encoderParam: Path to encoder.ncnn.param
  ///   - encoderBin: Path to encoder.ncnn.bin
  ///   - decoderParam: Path to decoder.ncnn.param
  ///   - decoderBin: Path to decoder.ncnn.bin
  ///   - joinerParam: Path to joiner.ncnn.param
  ///   - joinerBin: Path to joiner.ncnn.bin
  ///   - tokens: Path to tokens.txt
  ///   - numThreads: Number of threads to use for neural network computation.
  ///                 Defaults to 4. Converted to Int32, which traps if the
  ///                 value does not fit.
  ///   - useVulkanCompute: If it is true, and if sherpa-ncnn is compiled with
  ///                       vulkan support, and if there are GPUs available,
  ///                       then it will use GPU for neural network computation.
  ///                       Otherwise, it uses CPU for computation.
  ///                       Defaults to true; passed to C as 1 (true) or 0 (false).
  ///
  /// - Returns: Return an instance of SherpaNcnnModelConfig
  func sherpaNcnnModelConfig(
    encoderParam: String,
    encoderBin: String,
    decoderParam: String,
    decoderBin: String,
    joinerParam: String,
    joinerBin: String,
    tokens: String,
    numThreads: Int = 4,
    useVulkanCompute: Bool = true
  ) -> SherpaNcnnModelConfig {
    return SherpaNcnnModelConfig(
      encoder_param: toCPointer(encoderParam),
      encoder_bin: toCPointer(encoderBin),
      decoder_param: toCPointer(decoderParam),
      decoder_bin: toCPointer(decoderBin),
      joiner_param: toCPointer(joinerParam),
      joiner_bin: toCPointer(joinerBin),
      tokens: toCPointer(tokens),
      use_vulkan_compute: useVulkanCompute ? 1 : 0,
      num_threads: Int32(numThreads))
  }
  /// Build a SherpaNcnnFeatureExtractorConfig.
  ///
  /// - Parameters:
  ///   - sampleRate: Stored unchanged as `sampling_rate`; presumably the sample
  ///                 rate (Hz) features are computed at — confirm against the
  ///                 model being used.
  ///   - featureDim: Stored as `feature_dim` after conversion to Int32, which
  ///                 traps if the value does not fit.
  /// - Returns: The populated SherpaNcnnFeatureExtractorConfig.
  func sherpaNcnnFeatureExtractorConfig(
    sampleRate: Float,
    featureDim: Int
  ) -> SherpaNcnnFeatureExtractorConfig {
    let dim = Int32(featureDim)
    return SherpaNcnnFeatureExtractorConfig(sampling_rate: sampleRate, feature_dim: dim)
  }
  /// Create an instance of SherpaNcnnDecoderConfig.
  ///
  /// Endpoint-detection options are not part of this config; they are
  /// parameters of sherpaNcnnRecognizerConfig().
  ///
  /// - Parameters:
  ///   - decodingMethod: Valid decoding methods are "greedy_search" (the
  ///                     default) and "modified_beam_search". Converted to a
  ///                     C string with `toCPointer`.
  ///   - numActivePaths: Used only when decodingMethod is "modified_beam_search".
  ///                     It specifies the beam size for beam search. Defaults
  ///                     to 4; converted to Int32, which traps if it does not fit.
  /// - Returns: The populated SherpaNcnnDecoderConfig.
  func sherpaNcnnDecoderConfig(
    decodingMethod: String = "greedy_search",
    numActivePaths: Int = 4
  ) -> SherpaNcnnDecoderConfig {
    return SherpaNcnnDecoderConfig(
      decoding_method: toCPointer(decodingMethod),
      num_active_paths: Int32(numActivePaths))
  }
  /// Assemble a SherpaNcnnRecognizerConfig from its sub-configs plus the
  /// endpoint-detection and hotword options.
  ///
  /// - Parameters:
  ///   - featConfig: Feature extractor config, e.g. from
  ///                 sherpaNcnnFeatureExtractorConfig().
  ///   - modelConfig: Model config, e.g. from sherpaNcnnModelConfig().
  ///   - decoderConfig: Decoder config, e.g. from sherpaNcnnDecoderConfig().
  ///   - enableEndpoint: true to enable endpoint detection, false (the default)
  ///                     to disable it. Passed to C as 1 / 0.
  ///   - rule1MinTrailingSilence: An endpoint is detected if trailing silence in
  ///                              seconds is larger than this value even if
  ///                              nothing has been decoded. Used only when
  ///                              enableEndpoint is true. Defaults to 2.4.
  ///   - rule2MinTrailingSilence: An endpoint is detected if trailing silence in
  ///                              seconds is larger than this value even after
  ///                              something that is not blank has been decoded.
  ///                              Used only when enableEndpoint is true.
  ///                              Defaults to 1.2.
  ///   - rule3MinUtteranceLength: An endpoint is detected if the utterance in
  ///                              seconds is larger than this value. Used only
  ///                              when enableEndpoint is true. Defaults to 30.
  ///   - hotwordsFile: Path to a hotwords file, converted with `toCPointer`.
  ///                   The default is the empty string — presumably meaning
  ///                   "no hotwords"; confirm against the C API.
  ///   - hotwordsScore: Stored unchanged as `hotwords_score` (default 1.5).
  /// - Returns: The populated SherpaNcnnRecognizerConfig.
  func sherpaNcnnRecognizerConfig(
    featConfig: SherpaNcnnFeatureExtractorConfig,
    modelConfig: SherpaNcnnModelConfig,
    decoderConfig: SherpaNcnnDecoderConfig,
    enableEndpoint: Bool = false,
    rule1MinTrailingSilence: Float = 2.4,
    rule2MinTrailingSilence: Float = 1.2,
    rule3MinUtteranceLength: Float = 30,
    hotwordsFile: String = "",
    hotwordsScore: Float = 1.5
  ) -> SherpaNcnnRecognizerConfig {
    let config = SherpaNcnnRecognizerConfig(
      feat_config: featConfig,
      model_config: modelConfig,
      decoder_config: decoderConfig,
      enable_endpoint: enableEndpoint ? 1 : 0,
      rule1_min_trailing_silence: rule1MinTrailingSilence,
      rule2_min_trailing_silence: rule2MinTrailingSilence,
      rule3_min_utterance_length: rule3MinUtteranceLength,
      hotwords_file: toCPointer(hotwordsFile),
      hotwords_score: hotwordsScore
    )
    return config
  }
  /// Wrapper for recognition result.
  ///
  /// Usage:
  ///
  ///     let result = recognizer.getResult()
  ///     print("text: \(result.text)")
  ///
  /// The wrapper owns the C result it is given and releases it with
  /// DestroyResult when the wrapper is deinitialized. (The "Recongition"
  /// spelling of the class name is kept for source compatibility.)
  class SherpaNcnnRecongitionResult {
    /// A pointer to the underlying counterpart in C; may be nil.
    let result: UnsafePointer<SherpaNcnnResult>!

    /// Return the actual recognition result.
    /// For English models, it contains words separated by spaces.
    /// For Chinese models, it contains Chinese words.
    ///
    /// Returns an empty string when the C result, or its `text` field, is
    /// NULL — consistent with `deinit`, which also tolerates a nil result —
    /// instead of trapping on the dereference.
    var text: String {
      guard let result, let cText = result.pointee.text else {
        return ""
      }
      return String(cString: cText)
    }

    /// Take ownership of `result`; it is destroyed in `deinit`.
    init(result: UnsafePointer<SherpaNcnnResult>!) {
      self.result = result
    }

    deinit {
      if let result {
        DestroyResult(result)
      }
    }
  }
  /// Streaming speech recognizer wrapping the sherpa-ncnn C API.
  ///
  /// Each instance owns one C recognizer and one C stream created from it;
  /// both are released in `deinit` (stream first, then recognizer).
  class SherpaNcnnRecognizer {
    /// A pointer to the underlying C recognizer; nil if CreateRecognizer
    /// returned NULL.
    let recognizer: OpaquePointer!
    /// A pointer to the underlying C stream; nil if no recognizer was created.
    let stream: OpaquePointer!

    /// Create a recognizer from `config` and a stream for it.
    ///
    /// - Parameters:
    ///   - config: Pointer to the recognizer config, e.g. one built with
    ///             sherpaNcnnRecognizerConfig().
    init(
      config: UnsafePointer<SherpaNcnnRecognizerConfig>!
    ) {
      let createdRecognizer: OpaquePointer? = CreateRecognizer(config)
      recognizer = createdRecognizer
      // Only create a stream when a recognizer exists; handing NULL to
      // CreateStream would give the C side an invalid recognizer. deinit
      // already treats both pointers as possibly nil.
      if let createdRecognizer {
        stream = CreateStream(createdRecognizer)
      } else {
        stream = nil
      }
    }

    deinit {
      // Destroy the stream before the recognizer it was created from.
      if let stream {
        DestroyStream(stream)
      }
      if let recognizer {
        DestroyRecognizer(recognizer)
      }
    }

    /// Feed wave samples to the stream.
    ///
    /// - Parameters:
    ///   - samples: Audio samples normalized to the range [-1, 1].
    ///   - sampleRate: Sample rate of the input audio samples (default 16000).
    ///                 If it is different from featConfig.sampleRate, we will
    ///                 do resample. Caution: You cannot use a different
    ///                 sampleRate across different calls to AcceptWaveform().
    func acceptWaveform(samples: [Float], sampleRate: Float = 16000) {
      AcceptWaveform(stream, sampleRate, samples, Int32(samples.count))
    }

    /// Return true if the C side reports the stream as ready (IsReady returns
    /// non-zero). Presumably this means decode() has enough feature frames to
    /// do work — confirm against the C API.
    func isReady() -> Bool {
      // Treat any non-zero C int as true rather than only exactly 1.
      return IsReady(recognizer, stream) != 0
    }

    /// If there are enough number of feature frames, it invokes the neural
    /// network computation and decoding. Otherwise, it is a no-op.
    func decode() {
      Decode(recognizer, stream)
    }

    /// Get the decoding results so far.
    func getResult() -> SherpaNcnnRecongitionResult {
      let result: UnsafeMutablePointer<SherpaNcnnResult>? = GetResult(recognizer, stream)
      return SherpaNcnnRecongitionResult(result: result)
    }

    /// Reset the recognizer, which clears the neural network model state
    /// and the state for decoding.
    func reset() {
      Reset(recognizer, stream)
    }

    /// Signal that no more audio samples would be available.
    /// After this call, you cannot call acceptWaveform() any more.
    func inputFinished() {
      InputFinished(stream)
    }

    /// Return true if an endpoint has been detected (IsEndpoint returns
    /// non-zero).
    func isEndpoint() -> Bool {
      return IsEndpoint(recognizer, stream) != 0
    }
  }