transcription_router.py
1 """ 2 Router for audio and video transcription features 3 -------------------------------------------------------------- 4 This module implements routes for transcribing audio and video files, 5 with or without speaker identification. 6 """ 7 8 import os 9 import logging 10 import time 11 import traceback 12 from typing import Optional, Dict, Any, List 13 from fastapi import APIRouter, UploadFile, File, Form, HTTPException, BackgroundTasks, Depends, Request 14 from fastapi.responses import JSONResponse 15 from pydantic import BaseModel 16 17 # Import response models 18 from .response_models import ( 19 TranscriptionResponse, 20 ErrorResponse, 21 SuccessResponse, 22 TaskResponse, 23 TaskStatusResponse 24 ) 25 26 # Import transcription functions 27 from transcription_models import ( 28 process_monologue, 29 process_multiple_speakers, 30 transcribe_external_audio, 31 get_available_models, 32 analyze_transcript 33 ) 34 35 # Import for authentication 36 from auth import get_current_active_user, User 37 38 # Import configuration 39 from config import api_config, model_config, system_prompts 40 41 # Import task manager 42 from inference_engine import ( 43 TaskType, 44 create_task, 45 update_task, 46 get_task_status, 47 ProgressTracker 48 ) 49 50 # Import prompt manager 51 from utils.prompt_manager import get_prompt_manager 52 53 # Logging configuration 54 logger = logging.getLogger("api.transcription") 55 56 # Create router 57 transcription_router = APIRouter( 58 prefix="/transcription", 59 tags=["Transcription"], 60 responses={ 61 400: {"model": ErrorResponse, "description": "Invalid request"}, 62 401: {"model": ErrorResponse, "description": "Unauthorized"}, 63 404: {"model": ErrorResponse, "description": "File not found"}, 64 500: {"model": ErrorResponse, "description": "Server error"} 65 } 66 ) 67 68 # Configuration 69 ALLOWED_EXTENSIONS = {'mp4', 'mov', 'avi', 'mkv', 'webm', 'mp3', 'wav', 'ogg', 'flac', 'm4a'} 70 UPLOAD_FOLDER = 'uploads' 71 RESULTS_FOLDER = 'results/transcriptions' 72 os.makedirs(UPLOAD_FOLDER, exist_ok=True) 73 os.makedirs(RESULTS_FOLDER, exist_ok=True) 74 75 # Pydantic models for responses 76 class ModelInfo(BaseModel): 77 name: str 78 description: str 79 languages: List[str] 80 size_mb: float 81 is_multilingual: bool 82 83 class ModelsResponse(BaseModel): 84 whisper: Dict[str, ModelInfo] 85 diarization: Optional[Dict[str, ModelInfo]] = None 86 87 class AnalysisRequest(BaseModel): 88 transcription: str 89 language: Optional[str] = None 90 analysis_type: Optional[str] = "general" 91 92 class TranscriptionAnalysisResponse(BaseModel): 93 analysis: Dict[str, Any] 94 plain_explanation: Optional[str] = None 95 message: str 96 97 def allowed_file(filename: str) -> bool: 98 """Checks if the file has an allowed extension""" 99 return '.' in filename and \ 100 filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS 101 102 async def save_uploaded_file(file: UploadFile) -> str: 103 """Saves an uploaded file and returns its path""" 104 if not file: 105 raise HTTPException(status_code=400, detail="No file was provided") 106 107 if not allowed_file(file.filename): 108 raise HTTPException(status_code=400, detail="Unsupported file type") 109 110 # Secure filename 111 filename = "".join(c for c in file.filename if c.isalnum() or c in "._- ") 112 timestamp = int(time.time()) 113 safe_path = os.path.join(UPLOAD_FOLDER, f"{timestamp}_{filename}") 114 115 # Save file 116 contents = await file.read() 117 with open(safe_path, "wb") as f: 118 f.write(contents) 119 120 return safe_path 121 122 def create_output_filename(original_filename: str) -> str: 123 """Creates a filename for the transcription output""" 124 base_name = os.path.basename(original_filename) 125 name_without_ext = os.path.splitext(base_name)[0] 126 timestamp = int(time.time()) 127 return os.path.join(RESULTS_FOLDER, f"{name_without_ext}_{timestamp}.txt") 128 129 def formatted_analyze_transcript(transcription: str, language: Optional[str] = None, analysis_type: str = "general"): 130 """Adapted version of analyze_transcript using prompt manager""" 131 prompt_manager = get_prompt_manager() 132 133 # Select prompt type based on analysis type 134 prompt_key = f"transcription_{analysis_type}_analysis" 135 fallback_prompt = "Analyze the following transcription: {text}" 136 137 # Use formatted prompt with placeholder {text} 138 if prompt_key in system_prompts: 139 prompt = prompt_manager.format_prompt_direct( 140 system_prompts[prompt_key], 141 text=transcription, 142 language=language or "unknown" 143 ) 144 else: 145 # Fallback to default prompt if not found 146 prompt = prompt_manager.format_prompt_direct( 147 fallback_prompt, 148 text=transcription 149 ) 150 151 # Call analysis function with formatted prompt 152 return analyze_transcript(prompt, language) 153 154 async def process_transcription_task(task_id: str, file_path: str, output_txt: str, 155 model_size: str, is_diarization: bool = False, 156 huggingface_token: Optional[str] = None, 157 analyze: bool = False, 158 analysis_type: str = "general"): 159 """Asynchronous function to process a transcription task in the background""" 160 try: 161 # Update status 162 update_task(task_id, { 163 "status": "running", 164 "message": "Transcription in progress..." 165 }) 166 167 # Initialize progress tracker 168 progress_tracker = ProgressTracker(task_id) 169 170 # Call appropriate function based on transcription type 171 if is_diarization: 172 result = process_multiple_speakers( 173 file_path, 174 output_txt=output_txt, 175 model_size=model_size, 176 huggingface_token=huggingface_token, 177 progress=progress_tracker 178 ) 179 else: 180 result = process_monologue( 181 file_path, 182 output_txt=output_txt, 183 model_size=model_size, 184 progress=progress_tracker 185 ) 186 187 # If analysis is requested, perform it 188 if analyze and "transcription" in result: 189 update_task(task_id, { 190 "message": "Analyzing transcription..." 191 }) 192 193 # Use formatted version of analysis 194 analysis_result = formatted_analyze_transcript( 195 result["transcription"], 196 language=result.get("language"), 197 analysis_type=analysis_type 198 ) 199 200 # Add analysis to results 201 result["analysis"] = analysis_result 202 203 # Update with results 204 update_task(task_id, { 205 "status": "completed", 206 "results": result, 207 "message": "Transcription completed successfully" 208 }) 209 210 logger.info(f"Transcription task {task_id} completed successfully") 211 212 except Exception as e: 213 logger.error(f"Error during transcription task {task_id}: {str(e)}") 214 logger.error(traceback.format_exc()) 215 update_task(task_id, { 216 "status": "failed", 217 "error": str(e), 218 "message": f"Error: {str(e)}" 219 }) 220 221 @transcription_router.post('/monologue', response_model=TranscriptionResponse) 222 async def transcribe_monologue( 223 request: Request, 224 file: UploadFile = File(...), 225 model_size: str = Form("medium"), 226 current_user: User = Depends(get_current_active_user) 227 ): 228 """Transcribes a video or audio file (monologue mode)""" 229 try: 230 # Save uploaded file 231 file_path = await save_uploaded_file(file) 232 logger.info(f"File saved to {file_path}") 233 234 # Create output file 235 output_txt = create_output_filename(file_path) 236 237 # Process transcription 238 result = process_monologue( 239 file_path, 240 output_txt=output_txt, 241 model_size=model_size, 242 progress=lambda progress, desc: logger.debug(f"Progress: {progress*100:.1f}% - {desc}") 243 ) 244 245 # Apply JSONSimplifier post-processor if available 246 json_simplifier = getattr(request.app.state, "json_simplifier", None) 247 if json_simplifier and json_simplifier.should_process("transcription"): 248 result_dict = {"result": result} 249 processed = json_simplifier.process(result_dict, "transcription") 250 result = processed.get("result", result) 251 252 # If plain text explanation is available, add it to results 253 if "plain_explanation" in processed: 254 result["plain_explanation"] = processed["plain_explanation"] 255 256 # Prepare response 257 response = TranscriptionResponse( 258 transcription=result["transcription"], 259 language=result.get("language", ""), 260 duration=result.get("duration", 0), 261 segments=result.get("segments", []), 262 file_path=output_txt, 263 message="Transcription successful", 264 plain_explanation=result.get("plain_explanation") 265 ) 266 267 return response 268 269 except HTTPException: 270 raise 271 except Exception as e: 272 logger.error(f"Error during transcription: {str(e)}") 273 raise HTTPException(status_code=500, detail=f"Error during transcription: {str(e)}") 274 275 @transcription_router.post('/multiple_speakers', response_model=TranscriptionResponse) 276 async def transcribe_multiple_speakers( 277 request: Request, 278 file: UploadFile = File(...), 279 model_size: str = Form("medium"), 280 huggingface_token: Optional[str] = Form(None), 281 current_user: User = Depends(get_current_active_user) 282 ): 283 """Transcribes a video or audio file with speaker identification""" 284 # Use provided token or environment one 285 token = huggingface_token or os.environ.get('HUGGINGFACE_TOKEN') or model_config["diarization"]["huggingface_token"] 286 287 if not token: 288 raise HTTPException( 289 status_code=400, 290 detail="A Hugging Face token is required for speaker identification" 291 ) 292 293 try: 294 # Save uploaded file 295 file_path = await save_uploaded_file(file) 296 logger.info(f"File saved to {file_path}") 297 298 # Create output file 299 output_txt = create_output_filename(file_path) 300 301 # Process transcription with speaker identification 302 result = process_multiple_speakers( 303 file_path, 304 output_txt=output_txt, 305 model_size=model_size, 306 huggingface_token=token, 307 progress=lambda progress, desc: logger.debug(f"Progress: {progress*100:.1f}% - {desc}") 308 ) 309 310 # Apply JSONSimplifier post-processor if available 311 json_simplifier = getattr(request.app.state, "json_simplifier", None) 312 if json_simplifier and json_simplifier.should_process("transcription"): 313 result_dict = {"result": result} 314 processed = json_simplifier.process(result_dict, "transcription") 315 result = processed.get("result", result) 316 317 # If plain text explanation is available, add it to results 318 if "plain_explanation" in processed: 319 result["plain_explanation"] = processed["plain_explanation"] 320 321 # Prepare response 322 response = TranscriptionResponse( 323 transcription=result["transcription"], 324 language=result.get("language", ""), 325 duration=result.get("duration", 0), 326 segments=result.get("segments", []), 327 file_path=output_txt, 328 speakers=result.get("speakers", []), 329 message="Transcription with speaker identification successful", 330 plain_explanation=result.get("plain_explanation") 331 ) 332 333 return response 334 335 except HTTPException: 336 raise 337 except Exception as e: 338 logger.error(f"Error during transcription with speaker identification: {str(e)}") 339 raise HTTPException( 340 status_code=500, 341 detail=f"Error during transcription with speaker identification: {str(e)}" 342 ) 343 344 @transcription_router.post('/audio', response_model=TranscriptionResponse) 345 async def transcribe_audio( 346 request: Request, 347 file: UploadFile = File(...), 348 model_size: str = Form("medium"), 349 current_user: User = Depends(get_current_active_user) 350 ): 351 """Transcribes an existing audio file""" 352 # Check extension 353 if not file.filename.lower().endswith(('.mp3', '.wav', '.ogg', '.flac', '.m4a')): 354 raise HTTPException( 355 status_code=400, 356 detail="File must be in audio format (mp3, wav, ogg, flac, m4a)" 357 ) 358 359 try: 360 # Save uploaded file 361 audio_path = await save_uploaded_file(file) 362 logger.info(f"Audio file saved to {audio_path}") 363 364 # Create output file 365 output_txt = create_output_filename(audio_path) 366 367 # Transcribe audio 368 result = transcribe_external_audio( 369 audio_path, 370 model_size=model_size, 371 output_txt=output_txt, 372 progress=lambda progress, desc: logger.debug(f"Progress: {progress*100:.1f}% - {desc}") 373 ) 374 375 # Apply JSONSimplifier post-processor if available 376 json_simplifier = getattr(request.app.state, "json_simplifier", None) 377 if json_simplifier and json_simplifier.should_process("transcription"): 378 result_dict = {"result": result} 379 processed = json_simplifier.process(result_dict, "transcription") 380 result = processed.get("result", result) 381 382 # If plain text explanation is available, add it to results 383 if "plain_explanation" in processed: 384 result["plain_explanation"] = processed["plain_explanation"] 385 386 # Prepare response 387 response = TranscriptionResponse( 388 transcription=result["transcription"], 389 language=result.get("language", ""), 390 duration=result.get("duration", 0), 391 segments=result.get("segments", []), 392 file_path=output_txt, 393 message="Audio transcription successful", 394 plain_explanation=result.get("plain_explanation") 395 ) 396 397 return response 398 399 except HTTPException: 400 raise 401 except Exception as e: 402 logger.error(f"Error during audio transcription: {str(e)}") 403 raise HTTPException(status_code=500, detail=f"Error during audio transcription: {str(e)}") 404 405 @transcription_router.post('/analyze', response_model=TranscriptionAnalysisResponse) 406 async def analyze_transcription( 407 request: Request, 408 analysis_req: AnalysisRequest, 409 current_user: User = Depends(get_current_active_user) 410 ): 411 """Analyzes an existing transcription""" 412 try: 413 # Use formatted version of analysis with prompt manager 414 analysis = formatted_analyze_transcript( 415 analysis_req.transcription, 416 language=analysis_req.language, 417 analysis_type=analysis_req.analysis_type 418 ) 419 420 # Apply JSONSimplifier post-processor if available 421 result = {"analysis": analysis, "message": "Transcription analysis successful"} 422 json_simplifier = getattr(request.app.state, "json_simplifier", None) 423 if json_simplifier and json_simplifier.should_process("transcription"): 424 processed = json_simplifier.process(result, "transcription") 425 426 # If plain text explanation is available, add it to results 427 if "plain_explanation" in processed: 428 result["plain_explanation"] = processed["plain_explanation"] 429 430 # Prepare response 431 response = TranscriptionAnalysisResponse(**result) 432 433 return response 434 435 except HTTPException: 436 raise 437 except Exception as e: 438 logger.error(f"Error during transcription analysis: {str(e)}") 439 raise HTTPException(status_code=500, detail=f"Error during transcription analysis: {str(e)}") 440 441 @transcription_router.post('/async_transcribe', response_model=TaskResponse) 442 async def async_transcribe( 443 background_tasks: BackgroundTasks, 444 file: UploadFile = File(...), 445 model_size: str = Form("medium"), 446 enable_diarization: bool = Form(False), 447 analyze: bool = Form(False), 448 analysis_type: str = Form("general"), 449 huggingface_token: Optional[str] = Form(None), 450 current_user: User = Depends(get_current_active_user) 451 ): 452 """Starts an asynchronous transcription (in background)""" 453 try: 454 # Save uploaded file 455 file_path = await save_uploaded_file(file) 456 logger.info(f"File saved to {file_path} for asynchronous transcription") 457 458 # Create output file 459 output_txt = create_output_filename(file_path) 460 461 # Check if diarization is requested and token is available 462 if enable_diarization: 463 token = huggingface_token or os.environ.get('HUGGINGFACE_TOKEN') or model_config["diarization"]["huggingface_token"] 464 if not token: 465 raise HTTPException( 466 status_code=400, 467 detail="A Hugging Face token is required for speaker identification" 468 ) 469 470 # Define task type 471 task_type = TaskType.TRANSCRIPTION_MULTISPEAKER 472 else: 473 token = None 474 task_type = TaskType.TRANSCRIPTION_MONOLOGUE 475 476 # Task parameters 477 task_params = { 478 "file_path": file_path, 479 "output_txt": output_txt, 480 "model_size": model_size, 481 "is_diarization": enable_diarization, 482 "huggingface_token": token, 483 "analyze": analyze, 484 "analysis_type": analysis_type 485 } 486 487 # Create task 488 task_id = create_task( 489 task_type=task_type, 490 user_id=current_user.username, 491 params=task_params 492 ) 493 494 # Launch task in background 495 background_tasks.add_task( 496 process_transcription_task, 497 task_id=task_id, 498 file_path=file_path, 499 output_txt=output_txt, 500 model_size=model_size, 501 is_diarization=enable_diarization, 502 huggingface_token=token, 503 analyze=analyze, 504 analysis_type=analysis_type 505 ) 506 507 return TaskResponse( 508 task_id=task_id, 509 status="pending", 510 message="Transcription task launched successfully" 511 ) 512 513 except HTTPException: 514 raise 515 except Exception as e: 516 logger.error(f"Error launching asynchronous transcription: {str(e)}") 517 raise HTTPException( 518 status_code=500, 519 detail=f"Error launching asynchronous transcription: {str(e)}" 520 ) 521 522 @transcription_router.get('/task/{task_id}/result', response_model=None) 523 async def get_task_result( 524 task_id: str, 525 request: Request, 526 current_user: User = Depends(get_current_active_user) 527 ): 528 """Retrieves the result of a transcription task""" 529 try: 530 # Get task status 531 task = get_task_status(task_id) 532 533 if not task: 534 raise HTTPException(status_code=404, detail="Task not found") 535 536 if task["status"] != "completed": 537 return { 538 "status": task["status"], 539 "message": task.get("message", "Task is being processed") 540 } 541 542 # Get results 543 result = task.get("results", {}) 544 545 # Apply JSONSimplifier post-processor if available and not already applied 546 if "plain_explanation" not in result: 547 json_simplifier = getattr(request.app.state, "json_simplifier", None) 548 if json_simplifier and json_simplifier.should_process("transcription"): 549 result_dict = {"result": result} 550 processed = json_simplifier.process(result_dict, "transcription") 551 552 # If plain text explanation is available, add it to results 553 if "plain_explanation" in processed: 554 result["plain_explanation"] = processed["plain_explanation"] 555 556 return { 557 "status": "completed", 558 "result": result, 559 "message": task.get("message", "Task completed successfully") 560 } 561 562 except HTTPException: 563 raise 564 except Exception as e: 565 logger.error(f"Error retrieving task result: {str(e)}") 566 raise HTTPException( 567 status_code=500, 568 detail=f"Error retrieving task result: {str(e)}" 569 ) 570 571 @transcription_router.get('/models', response_model=ModelsResponse) 572 async def get_models(): 573 """Retrieves information about available transcription models""" 574 try: 575 models_info = get_available_models() 576 577 return ModelsResponse( 578 whisper=models_info.get("whisper", {}), 579 diarization=models_info.get("diarization", {}) 580 ) 581 582 except Exception as e: 583 logger.error(f"Error retrieving models: {str(e)}") 584 raise HTTPException(status_code=500, detail=f"Error retrieving models: {str(e)}") 585 586 @transcription_router.get('/allowed_extensions') 587 async def get_allowed_extensions(): 588 """Retrieves the list of allowed audio file extensions""" 589 return { 590 "allowed_extensions": list(ALLOWED_EXTENSIONS), 591 "max_upload_size_mb": model_config["audio"]["max_upload_size_mb"] 592 }