#!/usr/bin/env python
"""An AI voice-to-text transcription HTTP server.

Serves a small aiohttp app: GET / returns the bundled API docs page,
POST /transcribe accepts a WAV body and returns the transcription as
JSON, produced by a faster-whisper model loaded at startup.
"""

from io import BytesIO
from typing import Optional
import argparse
import functools
import logging
import os
import pathlib

from aiohttp import web
from aiohttp_cors import setup as cors_setup, ResourceOptions

from faster_whisper import WhisperModel


logger = logging.getLogger(__name__)


async def index(request: web.Request) -> web.Response:
    """Serve the API documentation page shipped next to this module."""
    return web.FileResponse(
        pathlib.Path(__file__).parent.resolve() / 'index.html')


async def transcribe_post(model: WhisperModel,
                          request: web.Request) -> web.StreamResponse:
    """Transcribe the request's WAV body and return the segments as JSON.

    model: the loaded WhisperModel (bound via functools.partial at setup).
    Returns 415 unless the request's media type is audio/wav; otherwise a
    JSON object with 'transcribed_segments' (list of {text, start, end})
    and the detected 'language'.
    """
    # request.content_type is the base media type with any parameters
    # (e.g. '; charset=...') stripped, and never raises if the header
    # is missing — unlike indexing request.headers directly.
    if request.content_type != 'audio/wav':
        return web.Response(status=415, text="Unsupported Input Media Type")

    wav_data = await request.read()

    # NOTE(review): transcribe() is CPU-bound and runs on the event loop
    # thread, blocking other requests; consider an executor for
    # concurrent clients.
    segments, info = model.transcribe(audio=BytesIO(wav_data),
                                      vad_filter=True)

    logger.debug("Detected language '%s' with probability %s",
                 info.language, info.language_probability)

    segments_result = [
        {
            'text': segment.text,
            'start': segment.start,
            'end': segment.end,
        }
        for segment in segments
    ]

    return web.json_response({'transcribed_segments': segments_result,
                              'language': info.language, })


def setup_cors(app: web.Application) -> None:
    """Allow cross-origin requests from any origin on every route."""
    cors = cors_setup(app, defaults={
        "*": ResourceOptions(
            allow_credentials=True,
            expose_headers="*",
            allow_headers="*",
        )
    })

    # Configure CORS on all routes registered so far.
    for route in list(app.router.routes()):
        cors.add(route)


async def start_server(model: str, compute_type: str = 'default',
                       cache_dir: Optional[str] = None,
                       device: str = 'cpu') -> web.Application:
    """Load the Whisper model and build the configured aiohttp app.

    model: local model directory name under cache_dir, or a model
        identifier faster-whisper can download.
    compute_type: faster-whisper compute type (default/float16/int8).
    cache_dir: parent directory where models are stored/downloaded;
        when None the model identifier is used as the path directly
        (the original formatted a literal 'None/<model>' path here).
    device: device to load the model onto (e.g. 'cpu', 'cuda').
    """
    model_path = os.path.join(cache_dir, model) if cache_dir else model

    logger.info('Loading AI model %s to %s...', model_path, device)

    # Keep the loaded model in its own name instead of rebinding the
    # `model` parameter from str to WhisperModel.
    if os.path.isdir(model_path):
        # A local copy already exists — load it straight from disk.
        whisper = WhisperModel(model_path, device=device,
                               compute_type=compute_type,
                               download_root=model_path)
    else:
        # Not cached yet — let faster-whisper download it into model_path.
        whisper = WhisperModel(model, device=device,
                               compute_type=compute_type,
                               download_root=model_path)

    app = web.Application()

    # Bind the loaded model as the handler's first positional argument.
    app.add_routes([
        web.get('/', handler=index),
        web.post('/transcribe',
                 handler=functools.partial(transcribe_post, whisper)),
    ])

    # Set up CORS on everything registered above.
    setup_cors(app)

    return app


if __name__ == '__main__':
    # Without a configured handler, the logger.info calls below would be
    # silently dropped.
    logging.basicConfig(level=logging.INFO)

    parser = argparse.ArgumentParser(
        description='An AI voice to text transcription server')

    parser.add_argument('-p', '--port', type=int,
                        default=3157, help='Port to listen on')
    parser.add_argument('-m', '--model',
                        help='Model name, see https://github.com/openai/whisper#available-models-and-languages',
                        default='medium')
    parser.add_argument('-t', '--compute-type', type=str,
                        help='default, float16, int8', default='default')
    parser.add_argument('-d', '--model-dir',
                        type=str,
                        help='Path to model directory',
                        default='models')
    parser.add_argument('-c', '--device', type=str, default='cpu',
                        help='torch device to use')
    args = parser.parse_args()

    logger.info('Starting server at http://localhost:%s/', args.port)

    web.run_app(start_server(args.model, args.compute_type,
                             args.model_dir, args.device), port=args.port)