Переглянути джерело

Support multi-channel waves in examples. (#174)

Fangjun Kuang 2 роки тому
батько
коміт
fa80f822d7

+ 22 - 24
python-api-examples/AudioSer/AudioSer.py

@@ -1,14 +1,18 @@
-from config import *
-import wave, os, logging
-import sherpa_ncnn, uuid
-import shlex, subprocess, mimetypes
+import logging
+import mimetypes
+import os
+import shlex
+import shutil
+import subprocess
+import uuid
+import wave
+
 import numpy as np
+import sherpa_ncnn
+from flask import Flask, jsonify, render_template, request
 from flask_caching import Cache
-from flask import (
-    Flask, request, jsonify,
-    render_template
-)
 
+from config import *
 
 Server = Flask(__name__)
 
@@ -18,43 +22,37 @@ Server.config['CACHE_DEFAULT_TIMEOUT'] = 1
 cache = Cache(Server)
 
 recognizer = sherpa_ncnn.Recognizer(
-    tokens=TOKENS, encoder_param=ENCODER_PARMA, 
-    encoder_bin=ENCODER_BIN,decoder_param=DECODER_PARAM, 
+    tokens=TOKENS, encoder_param=ENCODER_PARMA,
+    encoder_bin=ENCODER_BIN,decoder_param=DECODER_PARAM,
     decoder_bin=DECODER_BIN, joiner_param=JOINER_PARAM,
     joiner_bin=JOINER_BIN, num_threads=NUM_THREADS
 )
 
 def rewrite(input_file, output_file):
-    command = ["./sox/ffmpeg", "-i", shlex.quote(input_file),
-               "-ar", "16000", shlex.quote(output_file), "-y"]
-    subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+    shutil.copy(input_file, output_file)
 
 
 @cache.memoize()
 def Voice_recognition(filename):
     with wave.open(filename, 'rb') as f:
-        if f.getframerate() != recognizer.sample_rate:
-            raise ValueError(
-                f"Invalid sample rate: {f.getframerate()}, expected {recognizer.sample_rate}. File: {filename}")
-        if f.getnchannels() != 1:
-            raise ValueError(
-                f"Invalid number of channels: {f.getnchannels()}, expected 1. File: {filename}")
         if f.getsampwidth() != 2:
             raise ValueError(
                 f"Invalid sample width: {f.getsampwidth()}, expected 2. File: {filename}")
 
+        sample_rate = f.getframerate()
+        num_channels = f.getnchannels()
         num_samples = f.getnframes()
         samples = f.readframes(num_samples)
         samples_int16 = np.frombuffer(samples, dtype=np.int16)
+        samples_int16 = samples_int16.reshape(-1, num_channels)[:, 0]
         samples_float32 = samples_int16.astype(np.float32)
         samples_float32 /= 32768
 
-    recognizer.accept_waveform(recognizer.sample_rate, samples_float32)
-    tail_paddings = np.zeros(
-        int(recognizer.sample_rate * 0.5), dtype=np.float32)
-    recognizer.accept_waveform(recognizer.sample_rate, tail_paddings)
+    recognizer.accept_waveform(sample_rate, samples_float32)
+    tail_paddings = np.zeros(int(sample_rate * 0.5), dtype=np.float32)
+    recognizer.accept_waveform(sample_rate, tail_paddings)
     res1 = recognizer.text.lower()
-    recognizer.reset()
+    recognizer.stream = recognizer.recognizer.create_stream()
     return res1
 
 

+ 2 - 0
python-api-examples/AudioSer/cache/.gitignore

@@ -0,0 +1,2 @@
+server.log
+*.wav

+ 2 - 1
python-api-examples/decode-file.py

@@ -36,11 +36,12 @@ def main():
         # Note: If wave_file_sample_rate is different from
         # recognizer.sample_rate, we will do resampling inside sherpa-ncnn
         wave_file_sample_rate = f.getframerate()
-        assert f.getnchannels() == 1, f.getnchannels()
+        num_channels = f.getnchannels()
         assert f.getsampwidth() == 2, f.getsampwidth()  # it is in bytes
         num_samples = f.getnframes()
         samples = f.readframes(num_samples)
         samples_int16 = np.frombuffer(samples, dtype=np.int16)
+        samples_int16 = samples_int16.reshape(-1, num_channels)[:, 0]
         samples_float32 = samples_int16.astype(np.float32)
 
         samples_float32 = samples_float32 / 32768