# dispatch.py
"""Speech-to-text dispatch — look up a model by name in the registry and
route to its provider.

Callers:
- `restai/routers/audio.py` — `POST /audio/{generator}/transcript`.

The dispatch takes a path to an audio file on disk (the router writes the
upload to a tempfile, ffmpeg-converts to mp3 if needed, and hands us the
final path). External providers (OpenAI, Google, Deepgram, AssemblyAI)
read the bytes; the local branch opens the file and hands it, wrapped in an
upload-like object, to the existing `restai.audio.runner.generate` which
spawns a worker subprocess.
"""
from __future__ import annotations

import json
import logging
import os

from restai.models.databasemodels import SpeechToTextDatabase

logger = logging.getLogger(__name__)

# `class_name` values that are served by a remote provider module under
# `restai.speech_to_text.providers`.
_EXTERNAL_PROVIDERS = ("openai", "google", "deepgram", "assemblyai")


class UnknownModelError(Exception):
    """Raised when the requested name cannot be resolved to a usable STT
    model (no registry row, no matching local generator, or an
    unrecognised `class_name`)."""


class ModelDisabledError(Exception):
    """Raised when a matching model exists but `enabled=False`."""


class _DiskUpload:
    """Minimal stand-in for a FastAPI `UploadFile`, backed by an open file.

    Only exposes the two attributes this module provides to the legacy
    runner: `file` and `filename`. The caller owns the file's lifetime.
    """

    def __init__(self, file, filename: str):
        self.file = file
        self.filename = filename


def _load_options(row: SpeechToTextDatabase) -> dict:
    """Return the model's options as a dict, decrypting sensitive values.

    This is deliberately best-effort and never raises for bad data:

    - empty or unparseable `row.options` yields `{}`;
    - a parsed value that is not a dict yields `{}`;
    - if decryption fails, the parsed (still undecrypted) dict is returned.

    Failures are logged as warnings with the exception type only, so option
    values (which may contain secrets) never reach the log.
    """
    # Imported lazily, as in the rest of this module's provider lookups.
    from restai.utils.crypto import decrypt_sensitive_options, LLM_SENSITIVE_KEYS

    try:
        raw = json.loads(row.options) if row.options else {}
    except Exception as exc:
        logger.warning(
            "Ignoring unparseable options for STT model %r (%s)",
            getattr(row, "name", None),
            type(exc).__name__,
        )
        raw = {}

    if not isinstance(raw, dict):
        return {}

    try:
        raw = decrypt_sensitive_options(raw, LLM_SENSITIVE_KEYS)
    except Exception as exc:
        # Keep going with the undecrypted options rather than failing the
        # whole request; the provider will reject bad credentials itself.
        logger.warning(
            "Could not decrypt options for STT model %r (%s); using them as stored",
            getattr(row, "name", None),
            type(exc).__name__,
        )
    return raw if isinstance(raw, dict) else {}


def list_available_stt_models(db_wrapper) -> list[str]:
    """Return the names of all registry rows whose `enabled` flag is truthy."""
    rows = db_wrapper.get_speech_to_text()
    return [r.name for r in rows if r.enabled]


def _import_provider_transcribe(class_name: str):
    """Lazily import and return the `transcribe` function for an external
    provider. `class_name` must be one of `_EXTERNAL_PROVIDERS`."""
    if class_name == "openai":
        from restai.speech_to_text.providers.openai import transcribe
    elif class_name == "google":
        from restai.speech_to_text.providers.google import transcribe
    elif class_name == "deepgram":
        from restai.speech_to_text.providers.deepgram import transcribe
    else:
        from restai.speech_to_text.providers.assemblyai import transcribe
    return transcribe


def _transcribe_local(
    name: str,
    audio_path: str,
    filename: str,
    language: str | None,
    brain,
) -> str:
    """Run a `local` model through the legacy audio runner."""
    manager = getattr(brain, "audio_manager", None) or getattr(brain, "image_manager", None)
    generators = brain.get_audio_generators([name]) if hasattr(brain, "get_audio_generators") else []
    if not generators:
        raise UnknownModelError(name)
    if manager is None:
        raise RuntimeError(
            f"Local STT model '{name}' needs the torch multiprocessing manager "
            "(GPU mode). Start the API with RESTAI_GPU=true."
        )

    # The legacy runner expects a FastAPI UploadFile. Wrap the open file
    # in a minimal stand-in so we don't have to rewrite it.
    from restai.audio.runner import generate as _runner

    # The `with` block guarantees the handle is closed once the runner
    # returns or raises.
    with open(audio_path, "rb") as audio_file:
        upload = _DiskUpload(audio_file, filename or os.path.basename(audio_path))
        return _runner(manager, generators[0], language or "", upload)


def transcribe_audio(
    name: str,
    audio_path: str,
    filename: str,
    language: str | None,
    brain,
    db_wrapper,
) -> str:
    """Resolve `name` to a model row and run transcription on the file at
    `audio_path`.

    Args:
        name: Registry name of the speech-to-text model.
        audio_path: Path of the audio file on disk.
        filename: Name reported to the provider; for the local branch a
            falsy value falls back to the basename of `audio_path`.
        language: Optional language hint. Passed through unchanged to
            external providers and as `""` to the local runner when unset.
        brain: Object supplying `audio_manager` / `image_manager` and
            `get_audio_generators` for the local branch.
        db_wrapper: Registry accessor providing
            `get_speech_to_text_by_name`.

    Returns:
        The transcript string.

    Raises:
        UnknownModelError: No row matches `name`, no local generator
            matches it, or the row's `class_name` is not recognised.
        ModelDisabledError: The row exists but is not enabled.
        RuntimeError: A local model was requested but `brain` has no
            multiprocessing manager.
    """
    row = db_wrapper.get_speech_to_text_by_name(name)
    if row is None:
        raise UnknownModelError(name)
    if not row.enabled:
        raise ModelDisabledError(name)

    options = _load_options(row)

    # External providers operate on raw bytes — slurp the file once.
    if row.class_name in _EXTERNAL_PROVIDERS:
        with open(audio_path, "rb") as f:
            audio_bytes = f.read()
        provider_transcribe = _import_provider_transcribe(row.class_name)
        return provider_transcribe(options, audio_bytes, filename, language)

    if row.class_name == "local":
        return _transcribe_local(name, audio_path, filename, language, brain)

    raise UnknownModelError(name)