/ restai / routers / speech_to_text.py
speech_to_text.py
  1  """Speech-to-text registry CRUD endpoints.
  2  
  3  Mirrors `restai/routers/image_generators.py`. The registry holds:
  4  
  5  - **Local** workers (auto-seeded on startup from `restai/audio/workers/*`).
  6    Always selectable; admin can flip `enabled` and rename for display, but
  7    cannot delete (re-seeded next boot).
  8  - **External** providers — `openai` (Whisper API + OpenAI-compat via
  9    `options.base_url`), `google`, `deepgram`, `assemblyai`. Created freely
 10    by the admin with per-row encrypted credentials.
 11  
 12  API key fields in `options` are masked as `"********"` on read; PATCH
 13  preserves the existing value when it sees that sentinel.
 14  """
 15  import json
 16  import logging
 17  import traceback
 18  from typing import Optional
 19  
 20  from fastapi import APIRouter, Depends, HTTPException, Path
 21  
 22  from restai import config
 23  from restai.auth import get_current_username, get_current_username_admin
 24  from restai.database import DBWrapper, get_db_wrapper
 25  from restai.models.databasemodels import SpeechToTextDatabase
 26  from restai.models.models import (
 27      SpeechToTextModel,
 28      SpeechToTextModelCreate,
 29      SpeechToTextModelUpdate,
 30      User,
 31  )
 32  
 33  logging.basicConfig(level=config.LOG_LEVEL)
 34  
 35  router = APIRouter()
 36  
 37  
 38  _SENSITIVE_OPT_KEYS = {"api_key", "key", "password", "secret"}
 39  
 40  
 41  def _mask_options(options: Optional[dict]) -> Optional[dict]:
 42      if not options:
 43          return options
 44      try:
 45          masked = dict(options)
 46          for k in _SENSITIVE_OPT_KEYS:
 47              if k in masked and masked[k]:
 48                  masked[k] = "********"
 49          return masked
 50      except Exception:
 51          return options
 52  
 53  
 54  @router.get("/speech_to_text", response_model=list[SpeechToTextModel])
 55  async def list_speech_to_text(
 56      user: User = Depends(get_current_username),
 57      db_wrapper: DBWrapper = Depends(get_db_wrapper),
 58  ):
 59      """List speech-to-text models. Non-admins see only those granted to a
 60      team they're a member of (via `teams_audio_generators`)."""
 61      rows = db_wrapper.get_speech_to_text()
 62  
 63      if not user.is_admin:
 64          allowed_names = set()
 65          for team in user.teams or []:
 66              for ag in (team.audio_generators or []):
 67                  allowed_names.add(getattr(ag, "generator_name", ag))
 68          rows = [r for r in rows if r.name in allowed_names]
 69  
 70      out: list[SpeechToTextModel] = []
 71      for r in rows:
 72          m = SpeechToTextModel.model_validate(r)
 73          m.options = _mask_options(m.options)
 74          out.append(m)
 75      return out
 76  
 77  
 78  @router.get("/speech_to_text/{model_id}", response_model=SpeechToTextModel)
 79  async def get_speech_to_text(
 80      model_id: int = Path(description="Speech-to-text model ID"),
 81      _: User = Depends(get_current_username),
 82      db_wrapper: DBWrapper = Depends(get_db_wrapper),
 83  ):
 84      row = db_wrapper.get_speech_to_text_by_id(model_id)
 85      if row is None:
 86          raise HTTPException(status_code=404, detail="Speech-to-text model not found")
 87      m = SpeechToTextModel.model_validate(row)
 88      m.options = _mask_options(m.options)
 89      return m
 90  
 91  
 92  @router.post("/speech_to_text", status_code=201, response_model=SpeechToTextModel)
 93  async def create_speech_to_text(
 94      body: SpeechToTextModelCreate,
 95      _: User = Depends(get_current_username_admin),
 96      db_wrapper: DBWrapper = Depends(get_db_wrapper),
 97  ):
 98      """Register a new STT model (admin only)."""
 99      if db_wrapper.get_speech_to_text_by_name(body.name):
100          raise HTTPException(status_code=409, detail=f"Speech-to-text model '{body.name}' already exists")
101      if body.class_name == "local":
102          raise HTTPException(
103              status_code=400,
104              detail="Local models are auto-discovered from restai/audio/workers/*; you can't create them manually.",
105          )
106      try:
107          opts = body.options if isinstance(body.options, dict) else (json.loads(body.options) if body.options else {})
108          row = db_wrapper.create_speech_to_text(
109              name=body.name,
110              class_name=body.class_name,
111              options=opts,
112              privacy=body.privacy,
113              description=body.description,
114              enabled=body.enabled,
115          )
116          m = SpeechToTextModel.model_validate(row)
117          m.options = _mask_options(m.options)
118          return m
119      except HTTPException:
120          raise
121      except Exception as e:
122          logging.error(e)
123          traceback.print_tb(e.__traceback__)
124          raise HTTPException(status_code=500, detail=f"Failed to create speech-to-text model '{body.name}'")
125  
126  
@router.patch("/speech_to_text/{model_id}", response_model=SpeechToTextModel)
async def update_speech_to_text(
    model_id: int = Path(description="Speech-to-text model ID"),
    body: SpeechToTextModelUpdate = ...,
    _: User = Depends(get_current_username_admin),
    db_wrapper: DBWrapper = Depends(get_db_wrapper),
):
    """Update a speech-to-text model (admin only).

    Local rows ignore provider/options changes — those come from the worker
    module — so ``class_name`` and ``options`` are cleared on the request
    before it is applied.

    Renaming to a name another model already uses is rejected up front,
    matching the duplicate check in the create endpoint. Credential values
    are masked in the response.

    Raises:
        HTTPException(404): no model with ``model_id``.
        HTTPException(409): the requested new name is already taken.
    """
    row: Optional[SpeechToTextDatabase] = db_wrapper.get_speech_to_text_by_id(model_id)
    if row is None:
        raise HTTPException(status_code=404, detail="Speech-to-text model not found")

    # getattr tolerates an update schema that has no `name` field.
    new_name = getattr(body, "name", None)
    if new_name and new_name != row.name and db_wrapper.get_speech_to_text_by_name(new_name):
        raise HTTPException(status_code=409, detail=f"Speech-to-text model '{new_name}' already exists")

    if row.class_name == "local":
        # NOTE(review): clearing via None relies on edit_speech_to_text skipping
        # None fields. Under pydantic v2, attribute assignment also adds the
        # field to model_fields_set, so an exclude_unset-based edit would write
        # these Nones through. Confirm against DBWrapper.edit_speech_to_text.
        body.class_name = None
        body.options = None

    db_wrapper.edit_speech_to_text(row, body)
    m = SpeechToTextModel.model_validate(row)
    m.options = _mask_options(m.options)
    return m
148  
149  
@router.delete("/speech_to_text/{model_id}")
async def delete_speech_to_text(
    model_id: int = Path(description="Speech-to-text model ID"),
    _: User = Depends(get_current_username_admin),
    db_wrapper: DBWrapper = Depends(get_db_wrapper),
):
    """Remove an external speech-to-text model (admin only).

    Rows whose ``class_name`` is ``"local"`` are refused with a 400; the
    caller is told to set ``enabled=false`` instead. On success the response
    echoes the removed model's name.
    """
    target: Optional[SpeechToTextDatabase] = db_wrapper.get_speech_to_text_by_id(model_id)
    if target is None:
        raise HTTPException(status_code=404, detail="Speech-to-text model not found")

    is_local = target.class_name == "local"
    if is_local:
        raise HTTPException(
            status_code=400,
            detail=f"Cannot delete local model '{target.name}'. Set enabled=false instead.",
        )

    # Capture the name before deletion so the response does not touch the removed row.
    deleted_name = target.name
    db_wrapper.delete_speech_to_text(target)
    return {"deleted": deleted_name}