# AudioSer.py (4.1 KB)
  import logging
  import mimetypes
  import os
  import shlex
  import shutil
  import subprocess
  import threading
  import uuid
  import wave

  import numpy as np
  import sherpa_ncnn
  from flask import Flask, jsonify, render_template, request
  from flask_caching import Cache

  from config import *
  Server = Flask(__name__)
  # Upload directory plus Flask-Caching settings; these must be in place
  # before Cache(Server) is constructed on the next statement.
  Server.config.update(
      VO_UPLOAD_FOLDER=VO_UPLOAD_FOLDER,
      CACHE_TYPE='simple',
      # Flask-Caching timeouts are expressed in seconds.
      CACHE_DEFAULT_TIMEOUT=1,
  )
  cache = Cache(Server)

  # One recognizer instance shared by every request. The model paths and
  # thread count are presumably supplied by `from config import *`; note the
  # name ENCODER_PARMA (sic) must match what that module defines.
  recognizer = sherpa_ncnn.Recognizer(
      tokens=TOKENS,
      encoder_param=ENCODER_PARMA,
      encoder_bin=ENCODER_BIN,
      decoder_param=DECODER_PARAM,
      decoder_bin=DECODER_BIN,
      joiner_param=JOINER_PARAM,
      joiner_bin=JOINER_BIN,
      num_threads=NUM_THREADS,
  )
  def rewrite(input_file, output_file):
      """Write a copy of ``input_file`` to ``output_file``.

      ``shutil.copy`` copies the file data and its permission bits. Always
      returns ``None``.

      NOTE(review): ``shlex``/``subprocess`` are imported but unused in the
      visible code; this may be a stand-in for an external conversion step —
      confirm.
      """
      src, dst = input_file, output_file
      shutil.copy(src, dst)
      return None
  # Guards the shared ``recognizer``: it decodes through a single stream
  # object, and Flask's built-in server handles requests on several threads
  # by default, so overlapping requests would otherwise feed audio into (and
  # reset) the same stream.
  _recognizer_lock = threading.Lock()


  @cache.memoize()
  def Voice_recognition(filename):
      """Transcribe a 16-bit PCM WAV file with the shared recognizer.

      Only the first channel is decoded, and half a second of silence is fed
      after the audio before the result is read.

      Args:
          filename: Path of the WAV file to read.

      Returns:
          The recognized text, lower-cased.

      Raises:
          ValueError: If the file's sample width is not 2 bytes.
          wave.Error: presumably, if ``filename`` is not a valid WAV file.
      """
      with wave.open(filename, 'rb') as f:
          sample_width = f.getsampwidth()
          if sample_width != 2:
              raise ValueError(
                  f"Invalid sample width: {sample_width}, expected 2. File: {filename}")
          sample_rate = f.getframerate()
          num_channels = f.getnchannels()
          frames = f.readframes(f.getnframes())

      # Interleaved int16 frames -> keep channel 0 -> float32 in [-1, 1).
      samples_int16 = np.frombuffer(frames, dtype=np.int16)
      samples_int16 = samples_int16.reshape(-1, num_channels)[:, 0]
      samples_float32 = samples_int16.astype(np.float32)
      samples_float32 /= 32768
      # Trailing silence, presumably so the model's last frames get decoded.
      tail_paddings = np.zeros(int(sample_rate * 0.5), dtype=np.float32)

      with _recognizer_lock:
          try:
              recognizer.accept_waveform(sample_rate, samples_float32)
              recognizer.accept_waveform(sample_rate, tail_paddings)
              return recognizer.text.lower()
          finally:
              # Reset even when decoding fails part-way, so leftover audio in
              # the stream cannot leak into the next request's transcription.
              recognizer.stream = recognizer.recognizer.create_stream()
  def configure_app():
      """Prepare runtime state needed before the server starts.

      Ensures the upload directory ``VO_UPLOAD_FOLDER`` exists and calls
      ``cache.init_app(Server)`` (the module also constructs ``Cache(Server)``
      at import time, so this re-binds the same cache to the same app).
      """
      # exist_ok replaces the exists()-then-makedirs() pair, which could race
      # with another process creating the directory in between.
      os.makedirs(VO_UPLOAD_FOLDER, exist_ok=True)
      cache.init_app(Server)
  def configure_log(log_file='./cache/log/server.log'):
      """Send INFO-and-above log records from the root logger to ``log_file``.

      Args:
          log_file: Destination path (default ``./cache/log/server.log``).
              Its parent directory is created when missing: the file handler
              does not create directories and would otherwise raise
              ``FileNotFoundError`` at startup.

      ``logging.basicConfig`` is a no-op if the root logger already has
      handlers.
      """
      log_dir = os.path.dirname(log_file)
      if log_dir:  # '' for a bare filename in the current directory
          os.makedirs(log_dir, exist_ok=True)
      logging.basicConfig(level=logging.INFO, filename=log_file,
                          format='%(levelname)s:%(asctime)s %(message)s')
  # MIME spellings used for WAV audio. Which one mimetypes.guess_type reports
  # for ".wav" depends on the host's mime.types files (or the Windows
  # registry), so accept all common variants rather than only two of them.
  _WAV_MIME_TYPES = frozenset({
      'audio/wav',
      'audio/x-wav',
      'audio/wave',
      'audio/vnd.wave',
  })


  def allowed_file(filename):
      """Return True if ``filename`` names an acceptable WAV upload.

      The name must have an extension listed in ``ALLOWED_EXTENSIONS``
      (compared case-insensitively), and ``mimetypes.guess_type`` must map it
      to a WAV MIME type. Only the name is inspected, not the file contents.
      """
      _, dot, ext = filename.rpartition('.')
      if not dot:
          return False
      if ext.lower() not in ALLOWED_EXTENSIONS:
          return False
      mime_type, _ = mimetypes.guess_type(filename)
      # None (unknown type) is never in the set, so it is rejected here too.
      return mime_type in _WAV_MIME_TYPES
  def check_type(mode):
      """Validate the request's ``file`` upload and save it under a random name.

      Args:
          mode: Key into ``Server.config`` naming the destination directory
              (e.g. ``'VO_UPLOAD_FOLDER'``).

      Returns:
          ``(filepath, output_filepath)``: the saved upload, and a sibling path
          with an ``output_`` prefix (not created here).

      Raises:
          ValueError: If no ``file`` part was sent, or its name is missing or
              not an allowed WAV file name.
      """
      if 'file' not in request.files:
          raise ValueError('No file part.')
      upload = request.files['file']
      # FileStorage.filename may be None as well as ''; reject both before
      # allowed_file() tries string operations on it.
      if not upload.filename or not allowed_file(upload.filename):
          raise ValueError('Please upload a .wav file.')
      upload_dir = Server.config[mode]
      # Random server-side name: the client-supplied name is never used as a path.
      filename = str(uuid.uuid4()) + '.wav'
      filepath = os.path.join(upload_dir, filename)
      upload.save(filepath)
      output_filepath = os.path.join(upload_dir, 'output_' + filename)
      return filepath, output_filepath
  def _discard_upload(path):
      """Remove a temporary upload file if one was created.

      A missing file is ignored; any other ``OSError`` is logged rather than
      raised, so cleanup never replaces the response being returned.
      """
      if path is None:
          return
      try:
          os.remove(path)
      except FileNotFoundError:
          pass
      except OSError as err:
          logging.warning(f"Could not remove temporary file {path}: {err}")


  @Server.route('/voice', methods=['POST'])
  def upload_file():
      """Transcribe the WAV file posted in the ``file`` form field.

      Always responds with JSON ``{"status": ..., "message": ...}``. The
      outcome is carried in the JSON ``status`` field (200 success, 400
      rejected input, 500 unexpected failure), not in the HTTP status code.
      """
      # Pre-bound so the cleanup below is safe even when check_type fails
      # before returning (e.g. file.save raising).
      filepath = output_filepath = None
      try:
          filepath, output_filepath = check_type('VO_UPLOAD_FOLDER')
          rewrite(filepath, output_filepath)
          result = Voice_recognition(output_filepath)
          return jsonify({
              'status': 200,
              'message': result
          })
      except ValueError as e:
          return jsonify({
              'status': 400,
              'message': str(e)
          })
      except Exception as e:
          # logging.exception also records the traceback.
          logging.exception(f"Recognition error: {e}")
          return jsonify({
              'status': 500,
              'message': 'Error, Please try again later.'
          })
      finally:
          # Runs on every path, including a ValueError raised by
          # Voice_recognition, so uploads never accumulate on disk.
          _discard_upload(filepath)
          _discard_upload(output_filepath)
  @Server.route('/', methods=['GET'])
  def index():
      """Render and return the ``index.html`` template for GET /."""
      page = render_template('index.html')
      return page
  if __name__ == '__main__':
      # Script entry point: prepare directories/cache and logging, then serve.
      configure_app()
      configure_log()
      # HOST is presumably a (host, port) pair from config — confirm.
      bind_host = HOST[0]
      bind_port = HOST[1]
      print(f" * Running on http://{bind_host}:{bind_port}")
      Server.run(host=bind_host, port=bind_port, debug=False)