/ restai / speech_to_text / dispatch.py
dispatch.py
  1  """Speech-to-text dispatch — look up a model by name in the registry and
  2  route to its provider.
  3  
  4  Callers:
  5  - `restai/routers/audio.py` — `POST /audio/{generator}/transcript`.
  6  
  7  The dispatch takes a path to an audio file on disk (the router writes the
  8  upload to a tempfile, ffmpeg-converts to mp3 if needed, and hands us the
  9  final path). External providers (OpenAI, Google, Deepgram, AssemblyAI)
 10  read the bytes; the local branch hands the path to the existing
 11  `restai.audio.runner.generate` which spawns a worker subprocess.
 12  """
 13  from __future__ import annotations
 14  
 15  import json
 16  import logging
 17  import os
 18  
 19  from restai.models.databasemodels import SpeechToTextDatabase
 20  
 21  logger = logging.getLogger(__name__)
 22  
 23  
 24  class UnknownModelError(Exception):
 25      """Raised when no enabled STT model matches the requested name."""
 26  
 27  
 28  class ModelDisabledError(Exception):
 29      """Raised when a matching model exists but `enabled=False`."""
 30  
 31  
 32  def _load_options(row: SpeechToTextDatabase) -> dict:
 33      from restai.utils.crypto import decrypt_sensitive_options, LLM_SENSITIVE_KEYS
 34  
 35      try:
 36          raw = json.loads(row.options) if row.options else {}
 37      except Exception:
 38          raw = {}
 39      if isinstance(raw, dict):
 40          try:
 41              raw = decrypt_sensitive_options(raw, LLM_SENSITIVE_KEYS)
 42          except Exception:
 43              pass
 44      return raw if isinstance(raw, dict) else {}
 45  
 46  
 47  def list_available_stt_models(db_wrapper) -> list[str]:
 48      rows = db_wrapper.get_speech_to_text()
 49      return [r.name for r in rows if r.enabled]
 50  
 51  
 52  def transcribe_audio(
 53      name: str,
 54      audio_path: str,
 55      filename: str,
 56      language: str | None,
 57      brain,
 58      db_wrapper,
 59  ) -> str:
 60      """Resolve `name` to a model row and run transcription on the file at
 61      `audio_path`. Returns the transcript string."""
 62      row = db_wrapper.get_speech_to_text_by_name(name)
 63      if row is None:
 64          raise UnknownModelError(name)
 65      if not row.enabled:
 66          raise ModelDisabledError(name)
 67  
 68      options = _load_options(row)
 69  
 70      # External providers operate on raw bytes — slurp the file once.
 71      if row.class_name in ("openai", "google", "deepgram", "assemblyai"):
 72          with open(audio_path, "rb") as f:
 73              audio_bytes = f.read()
 74          if row.class_name == "openai":
 75              from restai.speech_to_text.providers.openai import transcribe as _t
 76          elif row.class_name == "google":
 77              from restai.speech_to_text.providers.google import transcribe as _t
 78          elif row.class_name == "deepgram":
 79              from restai.speech_to_text.providers.deepgram import transcribe as _t
 80          else:
 81              from restai.speech_to_text.providers.assemblyai import transcribe as _t
 82          return _t(options, audio_bytes, filename, language)
 83  
 84      if row.class_name == "local":
 85          manager = getattr(brain, "audio_manager", None) or getattr(brain, "image_manager", None)
 86          generators = brain.get_audio_generators([name]) if hasattr(brain, "get_audio_generators") else []
 87          if not generators:
 88              raise UnknownModelError(name)
 89          if manager is None:
 90              raise RuntimeError(
 91                  f"Local STT model '{name}' needs the torch multiprocessing manager "
 92                  "(GPU mode). Start the API with RESTAI_GPU=true."
 93              )
 94          # The legacy runner expects a FastAPI UploadFile. Wrap the file
 95          # path in a minimal stand-in so we don't have to rewrite it.
 96          from restai.audio.runner import generate as _runner
 97  
 98          class _DiskUpload:
 99              def __init__(self, path: str, name: str):
100                  self.filename = name
101                  self.file = open(path, "rb")
102  
103              def __del__(self):
104                  try:
105                      self.file.close()
106                  except Exception:
107                      pass
108  
109          upload = _DiskUpload(audio_path, filename or os.path.basename(audio_path))
110          try:
111              return _runner(manager, generators[0], language or "", upload)
112          finally:
113              try:
114                  upload.file.close()
115              except Exception:
116                  pass
117  
118      raise UnknownModelError(name)