speech_to_text.py
"""Speech-to-text registry CRUD endpoints.

Mirrors `restai/routers/image_generators.py`. The registry holds:

- **Local** workers (auto-seeded on startup from `restai/audio/workers/*`).
  Always selectable; admin can flip `enabled` and rename for display, but
  cannot delete (re-seeded next boot).
- **External** providers — `openai` (Whisper API + OpenAI-compat via
  `options.base_url`), `google`, `deepgram`, `assemblyai`. Created freely
  by the admin with per-row encrypted credentials.

API key fields in `options` are masked as `"********"` on read; PATCH
preserves the existing value when it sees that sentinel.
"""
import json
import logging
import traceback
from typing import Optional

from fastapi import APIRouter, Depends, HTTPException, Path

from restai import config
from restai.auth import get_current_username, get_current_username_admin
from restai.database import DBWrapper, get_db_wrapper
from restai.models.databasemodels import SpeechToTextDatabase
from restai.models.models import (
    SpeechToTextModel,
    SpeechToTextModelCreate,
    SpeechToTextModelUpdate,
    User,
)

logging.basicConfig(level=config.LOG_LEVEL)

router = APIRouter()


# Option keys whose values are credentials. Matching is deliberately exact:
# the module docstring says PATCH recognises the mask sentinel, and that
# handling is not in this file, so masking a wider set of keys here could
# hand clients a sentinel the update path does not know how to preserve.
_SENSITIVE_OPT_KEYS = {"api_key", "key", "password", "secret"}

# Placeholder returned to clients in place of a stored credential.
_MASK = "********"


def _mask_mapping(options: dict) -> dict:
    """Return a shallow copy of ``options`` with non-empty credentials masked."""
    return {
        key: _MASK if key in _SENSITIVE_OPT_KEYS and value else value
        for key, value in options.items()
    }


def _mask_options(options: Optional[dict]) -> Optional[dict]:
    """Hide credential values in ``options`` before they leave the API.

    Only top-level keys listed in ``_SENSITIVE_OPT_KEYS`` are masked, and only
    when they hold a truthy value. The input is never mutated.

    Args:
        options: Provider options. Normally a dict; a JSON-encoded object
            string is also accepted and is returned as a string.

    Returns:
        The masked options, the input itself when it is empty, or ``None``
        when the payload cannot be inspected. Withholding the payload is
        intentional: echoing back something we could not check for secrets
        would risk leaking a credential.
    """
    if not options:
        return options

    if isinstance(options, str):
        try:
            decoded = json.loads(options)
        except ValueError:
            decoded = None
        if not isinstance(decoded, dict):
            logging.warning(
                "Withholding speech-to-text options: not a JSON object, "
                "cannot be checked for secrets"
            )
            return None
        masked = _mask_mapping(decoded)
        # Keep the caller's exact string when there was nothing to hide.
        return options if masked == decoded else json.dumps(masked)

    try:
        plain = dict(options)
    except (TypeError, ValueError):
        logging.warning(
            "Withholding speech-to-text options of type %s: cannot be "
            "checked for secrets",
            type(options).__name__,
        )
        return None
    return _mask_mapping(plain)


def _to_response(row) -> SpeechToTextModel:
    """Build the API model for a registry row with its credentials masked."""
    model = SpeechToTextModel.model_validate(row)
    model.options = _mask_options(model.options)
    return model


def _parse_options(raw) -> dict:
    """Normalise the ``options`` of a create request into a dict.

    Args:
        raw: A dict, a JSON-encoded object string, or an empty value.

    Returns:
        The options as a dict; ``{}`` when nothing was supplied.

    Raises:
        HTTPException: 400 when ``raw`` is not valid JSON or does not decode
            to an object. This is a client error, not a server failure.
    """
    if isinstance(raw, dict):
        return raw
    if not raw:
        return {}
    try:
        decoded = json.loads(raw)
    except (TypeError, ValueError) as exc:
        raise HTTPException(
            status_code=400,
            detail="Invalid 'options': expected a JSON object",
        ) from exc
    if decoded is None:
        return {}
    if not isinstance(decoded, dict):
        raise HTTPException(
            status_code=400,
            detail="Invalid 'options': expected a JSON object",
        )
    return decoded


@router.get("/speech_to_text", response_model=list[SpeechToTextModel])
async def list_speech_to_text(
    user: User = Depends(get_current_username),
    db_wrapper: DBWrapper = Depends(get_db_wrapper),
):
    """List speech-to-text models. Non-admins see only those granted to a
    team they're a member of (via `teams_audio_generators`)."""
    rows = db_wrapper.get_speech_to_text()

    if not user.is_admin:
        # Team grants may be objects carrying `generator_name` or the bare
        # name itself; the getattr fallback covers both.
        allowed_names = {
            getattr(grant, "generator_name", grant)
            for team in (user.teams or [])
            for grant in (team.audio_generators or [])
        }
        rows = [row for row in rows if row.name in allowed_names]

    return [_to_response(row) for row in rows]


@router.get("/speech_to_text/{model_id}", response_model=SpeechToTextModel)
async def get_speech_to_text(
    model_id: int = Path(description="Speech-to-text model ID"),
    _: User = Depends(get_current_username),
    db_wrapper: DBWrapper = Depends(get_db_wrapper),
):
    # NOTE(review): unlike the list endpoint, no team filter is applied here,
    # so any authenticated user can read any model by id (options are still
    # masked). Confirm this is intended.
    row = db_wrapper.get_speech_to_text_by_id(model_id)
    if row is None:
        raise HTTPException(status_code=404, detail="Speech-to-text model not found")
    return _to_response(row)


@router.post("/speech_to_text", status_code=201, response_model=SpeechToTextModel)
async def create_speech_to_text(
    body: SpeechToTextModelCreate,
    _: User = Depends(get_current_username_admin),
    db_wrapper: DBWrapper = Depends(get_db_wrapper),
):
    """Register a new STT model (admin only)."""
    if db_wrapper.get_speech_to_text_by_name(body.name):
        raise HTTPException(status_code=409, detail=f"Speech-to-text model '{body.name}' already exists")
    if body.class_name == "local":
        raise HTTPException(
            status_code=400,
            detail="Local models are auto-discovered from restai/audio/workers/*; you can't create them manually.",
        )

    # Parsed outside the try below so malformed input surfaces as a 400
    # instead of being converted into a generic 500.
    opts = _parse_options(body.options)

    try:
        row = db_wrapper.create_speech_to_text(
            name=body.name,
            class_name=body.class_name,
            options=opts,
            privacy=body.privacy,
            description=body.description,
            enabled=body.enabled,
        )
        return _to_response(row)
    except HTTPException:
        raise
    except Exception as e:
        # logging.exception records the message and the stack trace together.
        logging.exception("Failed to create speech-to-text model '%s'", body.name)
        raise HTTPException(
            status_code=500,
            detail=f"Failed to create speech-to-text model '{body.name}'",
        ) from e


@router.patch("/speech_to_text/{model_id}", response_model=SpeechToTextModel)
async def update_speech_to_text(
    model_id: int = Path(description="Speech-to-text model ID"),
    body: SpeechToTextModelUpdate = ...,
    _: User = Depends(get_current_username_admin),
    db_wrapper: DBWrapper = Depends(get_db_wrapper),
):
    """Update a speech-to-text model (admin only). Local rows ignore
    provider/options changes — those come from the worker module."""
    row: Optional[SpeechToTextDatabase] = db_wrapper.get_speech_to_text_by_id(model_id)
    if row is None:
        raise HTTPException(status_code=404, detail="Speech-to-text model not found")

    if row.class_name == "local":
        # Drop the fields a local row must not change before applying the edit.
        body.class_name = None
        body.options = None

    db_wrapper.edit_speech_to_text(row, body)
    return _to_response(row)


@router.delete("/speech_to_text/{model_id}")
async def delete_speech_to_text(
    model_id: int = Path(description="Speech-to-text model ID"),
    _: User = Depends(get_current_username_admin),
    db_wrapper: DBWrapper = Depends(get_db_wrapper),
):
    """Delete a speech-to-text model (admin only). Local models cannot be
    deleted — disable them via `enabled=false` instead."""
    row: Optional[SpeechToTextDatabase] = db_wrapper.get_speech_to_text_by_id(model_id)
    if row is None:
        raise HTTPException(status_code=404, detail="Speech-to-text model not found")
    if row.class_name == "local":
        raise HTTPException(
            status_code=400,
            detail=f"Cannot delete local model '{row.name}'. Set enabled=false instead.",
        )
    # Read the name before deleting; the row may not be readable afterwards.
    name = row.name
    db_wrapper.delete_speech_to_text(row)
    return {"deleted": name}