# stt/main.py
  1  #!/usr/bin/env python
  2  
  3  from io import BytesIO
  4  import os
  5  import logging
  6  import functools
  7  import argparse
  8  
  9  from aiohttp import web
 10  from aiohttp_cors import setup, ResourceOptions
 11  
 12  from faster_whisper import WhisperModel
 13  
 14  
 15  logger = logging.getLogger(__name__)
 16  
 17  
 18  async def index(request: web.Request) -> web.Response:
 19      import pathlib
 20  
 21      # show api documentation
 22      return web.FileResponse(pathlib.Path(__file__).parent.resolve().joinpath('index.html'))
 23  
 24  
 25  async def transcribe_post(model: WhisperModel, request: web.Request) -> web.StreamResponse:
 26      if request.headers["Content-Type"] != "audio/wav":
 27          return web.Response(status=415, text="Unsupported Input Media Type")
 28  
 29      wav_data = await request.read()
 30  
 31      segments, info = model.transcribe(audio=BytesIO(wav_data), vad_filter=True)
 32  
 33      logger.debug(
 34          f"Detected language '{info.language}' with probability {info.language_probability}")
 35  
 36      segments_result = list()
 37  
 38      for segment in segments:
 39          segments_result.append({
 40              'text': segment.text,
 41              'start': segment.start,
 42              'end': segment.end,
 43          })
 44  
 45  ### OUTPUT can go anywhere besides back to sender...
 46  ### ->
 47      # return transcripted_text and correct_text as json
 48      return web.json_response({'transcribed_segments': segments_result,
 49                                'language': info.language, })
 50  
 51  
 52  def setup_cors(app):
 53      cors = setup(app, defaults={
 54          "*": ResourceOptions(
 55              allow_credentials=True,
 56              expose_headers="*",
 57              allow_headers="*",
 58          )
 59      })
 60  
 61      # Configure CORS on all routes
 62      for route in list(app.router.routes()):
 63          cors.add(route)
 64  
 65  
 66  async def start_server(model: str, compute_type: str = 'default', cache_dir: str = None, device: str = 'cpu') -> web.Application:
 67      model_path = f'{cache_dir}/{model}'
 68  
 69      logger.info(f'Loading AI model {model_path} to {device}...')
 70  
 71      if os.path.isdir(model_path):
 72          model = WhisperModel(model_path, device=device,
 73                               compute_type=compute_type, download_root=model_path)
 74      else:
 75          model = WhisperModel(model, device=device,
 76                               compute_type=compute_type, download_root=model_path)
 77  
 78      app = web.Application()
 79  
 80      # call handle_request with tts as first argument
 81      app.add_routes([
 82          web.get('/', handler=index),
 83          web.post('/transcribe', handler=functools.partial(transcribe_post, model))
 84      ])
 85      
 86      # Set up CORS
 87      setup_cors(app)
 88      
 89      return app
 90  
 91  
 92  if __name__ == '__main__':
 93      parser = argparse.ArgumentParser(
 94          description='An AI voice to text transcription server')
 95  
 96      parser.add_argument('-p', '--port', type=int,
 97                          default=3157, help='Port to listen on')
 98      parser.add_argument('-m', '--model',
 99                          help='Model name, see https://github.com/openai/whisper#available-models-and-languages',
100                          default='medium')
101      parser.add_argument('-t', '--compute-type', type=str,
102                          help='default, float16, int8', default='default')
103      parser.add_argument('-d', '--model-dir',
104                          type=str,
105                          help='Path to model directory',
106                          default='models')
107      parser.add_argument('-c', '--device', type=str, default='cpu', help='torch device to use')
108      args = parser.parse_args()
109  
110      logger.info(f'Starting server at http://localhost:{args.port}/')
111  
112      web.run_app(start_server(args.model, args.compute_type,
113                  args.model_dir, args.device), port=args.port)