# api/transcription_router.py
  1  """
  2  Router for audio and video transcription features
  3  --------------------------------------------------------------
  4  This module implements routes for transcribing audio and video files,
  5  with or without speaker identification.
  6  """
  7  
  8  import os
  9  import logging
 10  import time
 11  import traceback
 12  from typing import Optional, Dict, Any, List
 13  from fastapi import APIRouter, UploadFile, File, Form, HTTPException, BackgroundTasks, Depends, Request
 14  from fastapi.responses import JSONResponse
 15  from pydantic import BaseModel
 16  
 17  # Import response models
 18  from .response_models import (
 19      TranscriptionResponse,
 20      ErrorResponse,
 21      SuccessResponse,
 22      TaskResponse,
 23      TaskStatusResponse
 24  )
 25  
 26  # Import transcription functions
 27  from transcription_models import (
 28      process_monologue,
 29      process_multiple_speakers,
 30      transcribe_external_audio,
 31      get_available_models,
 32      analyze_transcript
 33  )
 34  
 35  # Import for authentication
 36  from auth import get_current_active_user, User
 37  
 38  # Import configuration
 39  from config import api_config, model_config, system_prompts
 40  
 41  # Import task manager
 42  from inference_engine import (
 43      TaskType, 
 44      create_task, 
 45      update_task, 
 46      get_task_status,
 47      ProgressTracker
 48  )
 49  
 50  # Import prompt manager
 51  from utils.prompt_manager import get_prompt_manager
 52  
 53  # Logging configuration
 54  logger = logging.getLogger("api.transcription")
 55  
 56  # Create router
 57  transcription_router = APIRouter(
 58      prefix="/transcription",
 59      tags=["Transcription"],
 60      responses={
 61          400: {"model": ErrorResponse, "description": "Invalid request"},
 62          401: {"model": ErrorResponse, "description": "Unauthorized"},
 63          404: {"model": ErrorResponse, "description": "File not found"},
 64          500: {"model": ErrorResponse, "description": "Server error"}
 65      }
 66  )
 67  
 68  # Configuration
 69  ALLOWED_EXTENSIONS = {'mp4', 'mov', 'avi', 'mkv', 'webm', 'mp3', 'wav', 'ogg', 'flac', 'm4a'}
 70  UPLOAD_FOLDER = 'uploads'
 71  RESULTS_FOLDER = 'results/transcriptions'
 72  os.makedirs(UPLOAD_FOLDER, exist_ok=True)
 73  os.makedirs(RESULTS_FOLDER, exist_ok=True)
 74  
 75  # Pydantic models for responses
class ModelInfo(BaseModel):
    """Metadata describing a single available transcription model."""
    name: str
    description: str
    # Languages the model supports (format defined by get_available_models).
    languages: List[str]
    # Approximate model size in megabytes.
    size_mb: float
    is_multilingual: bool
 82  
class ModelsResponse(BaseModel):
    """Response body for GET /models: models keyed by identifier."""
    whisper: Dict[str, ModelInfo]
    # Diarization models may be absent when diarization is not configured.
    diarization: Optional[Dict[str, ModelInfo]] = None
 86  
class AnalysisRequest(BaseModel):
    """Request body for POST /analyze: a transcript to analyze."""
    transcription: str
    # Optional language hint forwarded to the analysis backend.
    language: Optional[str] = None
    # Selects the prompt template: "transcription_{analysis_type}_analysis".
    analysis_type: Optional[str] = "general"
 91  
class TranscriptionAnalysisResponse(BaseModel):
    """Response body for POST /analyze."""
    analysis: Dict[str, Any]
    # Plain-language summary added by the JSONSimplifier post-processor, if any.
    plain_explanation: Optional[str] = None
    message: str
 96  
 97  def allowed_file(filename: str) -> bool:
 98      """Checks if the file has an allowed extension"""
 99      return '.' in filename and \
100             filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
101  
async def save_uploaded_file(file: UploadFile) -> str:
    """Validate, sanitize and persist an uploaded file.

    Args:
        file: Multipart upload received by FastAPI.

    Returns:
        Path of the stored copy inside UPLOAD_FOLDER.

    Raises:
        HTTPException: 400 when no file was provided or its extension is
            not in ALLOWED_EXTENSIONS.
    """
    if not file:
        raise HTTPException(status_code=400, detail="No file was provided")
    
    if not allowed_file(file.filename):
        raise HTTPException(status_code=400, detail="Unsupported file type")
    
    # Secure filename: keep only alphanumerics and a small safe character set
    # so the client-supplied name cannot contain path separators.
    filename = "".join(c for c in file.filename if c.isalnum() or c in "._- ")
    timestamp = int(time.time())
    # BUG FIX: the sanitized name was previously discarded and every upload
    # was stored as "<timestamp>_(unknown)", losing the original name and the
    # extension. Use the sanitized filename in the stored path instead.
    safe_path = os.path.join(UPLOAD_FOLDER, f"{timestamp}_{filename}")
    
    # Save file (content is fully read into memory before writing).
    contents = await file.read()
    with open(safe_path, "wb") as f:
        f.write(contents)
    
    return safe_path
121  
def create_output_filename(original_filename: str) -> str:
    """Build a timestamped .txt output path in RESULTS_FOLDER."""
    stem = os.path.splitext(os.path.basename(original_filename))[0]
    return os.path.join(RESULTS_FOLDER, f"{stem}_{int(time.time())}.txt")
128  
def formatted_analyze_transcript(transcription: str, language: Optional[str] = None, analysis_type: str = "general"):
    """Run analyze_transcript on *transcription*, building the prompt via the
    prompt manager.

    Looks up the template "transcription_{analysis_type}_analysis" in
    system_prompts; when it is missing, a generic fallback prompt is used.
    """
    manager = get_prompt_manager()
    prompt_key = f"transcription_{analysis_type}_analysis"

    if prompt_key in system_prompts:
        # Configured template: fill in the transcript and language.
        prompt = manager.format_prompt_direct(
            system_prompts[prompt_key],
            text=transcription,
            language=language or "unknown",
        )
    else:
        # No template for this analysis type: fall back to a generic prompt.
        prompt = manager.format_prompt_direct(
            "Analyze the following transcription: {text}",
            text=transcription,
        )

    # Delegate the actual analysis to the transcription backend.
    return analyze_transcript(prompt, language)
153  
async def process_transcription_task(task_id: str, file_path: str, output_txt: str, 
                                     model_size: str, is_diarization: bool = False,
                                     huggingface_token: Optional[str] = None,
                                     analyze: bool = False,
                                     analysis_type: str = "general"):
    """Background worker that runs one transcription task end to end.

    Updates the task record while it runs and finally stores either the
    results (status "completed") or the error (status "failed"). Never
    raises: all exceptions are captured into the task record.
    """
    try:
        update_task(task_id, {
            "status": "running",
            "message": "Transcription in progress...",
        })

        tracker = ProgressTracker(task_id)

        # Diarization and plain monologue transcription use different pipelines.
        if is_diarization:
            result = process_multiple_speakers(
                file_path,
                output_txt=output_txt,
                model_size=model_size,
                huggingface_token=huggingface_token,
                progress=tracker,
            )
        else:
            result = process_monologue(
                file_path,
                output_txt=output_txt,
                model_size=model_size,
                progress=tracker,
            )

        # Optional analysis step, only when the pipeline produced a transcript.
        if analyze and "transcription" in result:
            update_task(task_id, {"message": "Analyzing transcription..."})
            result["analysis"] = formatted_analyze_transcript(
                result["transcription"],
                language=result.get("language"),
                analysis_type=analysis_type,
            )

        update_task(task_id, {
            "status": "completed",
            "results": result,
            "message": "Transcription completed successfully",
        })
        logger.info(f"Transcription task {task_id} completed successfully")

    except Exception as e:
        # Record the failure on the task so pollers see it; full traceback
        # goes to the server log only.
        logger.error(f"Error during transcription task {task_id}: {str(e)}")
        logger.error(traceback.format_exc())
        update_task(task_id, {
            "status": "failed",
            "error": str(e),
            "message": f"Error: {str(e)}",
        })
220  
@transcription_router.post('/monologue', response_model=TranscriptionResponse)
async def transcribe_monologue(
    request: Request,
    file: UploadFile = File(...),
    model_size: str = Form("medium"),
    current_user: User = Depends(get_current_active_user)
):
    """Transcribes a video or audio file (monologue mode)"""
    try:
        file_path = await save_uploaded_file(file)
        logger.info(f"File saved to {file_path}")

        output_txt = create_output_filename(file_path)

        # Progress callbacks from the pipeline go to the debug log.
        def _log_progress(progress, desc):
            logger.debug(f"Progress: {progress*100:.1f}% - {desc}")

        result = process_monologue(
            file_path,
            output_txt=output_txt,
            model_size=model_size,
            progress=_log_progress,
        )

        # When the JSONSimplifier post-processor is enabled, let it rewrite
        # the result and attach a plain-language explanation.
        simplifier = getattr(request.app.state, "json_simplifier", None)
        if simplifier and simplifier.should_process("transcription"):
            processed = simplifier.process({"result": result}, "transcription")
            result = processed.get("result", result)
            if "plain_explanation" in processed:
                result["plain_explanation"] = processed["plain_explanation"]

        return TranscriptionResponse(
            transcription=result["transcription"],
            language=result.get("language", ""),
            duration=result.get("duration", 0),
            segments=result.get("segments", []),
            file_path=output_txt,
            message="Transcription successful",
            plain_explanation=result.get("plain_explanation"),
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error during transcription: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error during transcription: {str(e)}")
274  
@transcription_router.post('/multiple_speakers', response_model=TranscriptionResponse)
async def transcribe_multiple_speakers(
    request: Request,
    file: UploadFile = File(...),
    model_size: str = Form("medium"),
    huggingface_token: Optional[str] = Form(None),
    current_user: User = Depends(get_current_active_user)
):
    """Transcribes a video or audio file with speaker identification"""
    # Token precedence: request form > environment > static configuration.
    token = huggingface_token or os.environ.get('HUGGINGFACE_TOKEN') or model_config["diarization"]["huggingface_token"]
    if not token:
        raise HTTPException(
            status_code=400, 
            detail="A Hugging Face token is required for speaker identification"
        )

    try:
        file_path = await save_uploaded_file(file)
        logger.info(f"File saved to {file_path}")

        output_txt = create_output_filename(file_path)

        # Progress callbacks from the pipeline go to the debug log.
        def _log_progress(progress, desc):
            logger.debug(f"Progress: {progress*100:.1f}% - {desc}")

        result = process_multiple_speakers(
            file_path,
            output_txt=output_txt,
            model_size=model_size,
            huggingface_token=token,
            progress=_log_progress,
        )

        # Optional JSONSimplifier pass: may rewrite the result and attach a
        # plain-language explanation.
        simplifier = getattr(request.app.state, "json_simplifier", None)
        if simplifier and simplifier.should_process("transcription"):
            processed = simplifier.process({"result": result}, "transcription")
            result = processed.get("result", result)
            if "plain_explanation" in processed:
                result["plain_explanation"] = processed["plain_explanation"]

        return TranscriptionResponse(
            transcription=result["transcription"],
            language=result.get("language", ""),
            duration=result.get("duration", 0),
            segments=result.get("segments", []),
            file_path=output_txt,
            speakers=result.get("speakers", []),
            message="Transcription with speaker identification successful",
            plain_explanation=result.get("plain_explanation"),
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error during transcription with speaker identification: {str(e)}")
        raise HTTPException(
            status_code=500, 
            detail=f"Error during transcription with speaker identification: {str(e)}"
        )
343  
@transcription_router.post('/audio', response_model=TranscriptionResponse)
async def transcribe_audio(
    request: Request,
    file: UploadFile = File(...),
    model_size: str = Form("medium"),
    current_user: User = Depends(get_current_active_user)
):
    """Transcribes an existing audio file"""
    # Reject non-audio uploads up front; str.endswith accepts a tuple.
    if not file.filename.lower().endswith(('.mp3', '.wav', '.ogg', '.flac', '.m4a')):
        raise HTTPException(
            status_code=400, 
            detail="File must be in audio format (mp3, wav, ogg, flac, m4a)"
        )

    try:
        audio_path = await save_uploaded_file(file)
        logger.info(f"Audio file saved to {audio_path}")

        output_txt = create_output_filename(audio_path)

        # Progress callbacks from the pipeline go to the debug log.
        def _log_progress(progress, desc):
            logger.debug(f"Progress: {progress*100:.1f}% - {desc}")

        result = transcribe_external_audio(
            audio_path,
            model_size=model_size,
            output_txt=output_txt,
            progress=_log_progress,
        )

        # Optional JSONSimplifier pass: may rewrite the result and attach a
        # plain-language explanation.
        simplifier = getattr(request.app.state, "json_simplifier", None)
        if simplifier and simplifier.should_process("transcription"):
            processed = simplifier.process({"result": result}, "transcription")
            result = processed.get("result", result)
            if "plain_explanation" in processed:
                result["plain_explanation"] = processed["plain_explanation"]

        return TranscriptionResponse(
            transcription=result["transcription"],
            language=result.get("language", ""),
            duration=result.get("duration", 0),
            segments=result.get("segments", []),
            file_path=output_txt,
            message="Audio transcription successful",
            plain_explanation=result.get("plain_explanation"),
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error during audio transcription: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error during audio transcription: {str(e)}")
404  
@transcription_router.post('/analyze', response_model=TranscriptionAnalysisResponse)
async def analyze_transcription(
    request: Request,
    analysis_req: AnalysisRequest,
    current_user: User = Depends(get_current_active_user)
):
    """Analyzes an existing transcription"""
    try:
        analysis = formatted_analyze_transcript(
            analysis_req.transcription,
            language=analysis_req.language,
            analysis_type=analysis_req.analysis_type,
        )

        payload = {"analysis": analysis, "message": "Transcription analysis successful"}

        # Optionally attach a plain-language explanation via JSONSimplifier.
        simplifier = getattr(request.app.state, "json_simplifier", None)
        if simplifier and simplifier.should_process("transcription"):
            processed = simplifier.process(payload, "transcription")
            if "plain_explanation" in processed:
                payload["plain_explanation"] = processed["plain_explanation"]

        return TranscriptionAnalysisResponse(**payload)

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error during transcription analysis: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error during transcription analysis: {str(e)}")
440  
@transcription_router.post('/async_transcribe', response_model=TaskResponse)
async def async_transcribe(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(...),
    model_size: str = Form("medium"),
    enable_diarization: bool = Form(False),
    analyze: bool = Form(False),
    analysis_type: str = Form("general"),
    huggingface_token: Optional[str] = Form(None),
    current_user: User = Depends(get_current_active_user)
):
    """Starts an asynchronous transcription (in background)"""
    try:
        file_path = await save_uploaded_file(file)
        logger.info(f"File saved to {file_path} for asynchronous transcription")

        output_txt = create_output_filename(file_path)

        # Diarization needs a Hugging Face token; resolve it (form value >
        # environment > static configuration) and fail fast when unavailable.
        token = None
        task_type = TaskType.TRANSCRIPTION_MONOLOGUE
        if enable_diarization:
            token = huggingface_token or os.environ.get('HUGGINGFACE_TOKEN') or model_config["diarization"]["huggingface_token"]
            if not token:
                raise HTTPException(
                    status_code=400, 
                    detail="A Hugging Face token is required for speaker identification"
                )
            task_type = TaskType.TRANSCRIPTION_MULTISPEAKER

        task_params = {
            "file_path": file_path,
            "output_txt": output_txt,
            "model_size": model_size,
            "is_diarization": enable_diarization,
            "huggingface_token": token,
            "analyze": analyze,
            "analysis_type": analysis_type,
        }

        task_id = create_task(
            task_type=task_type,
            user_id=current_user.username,
            params=task_params,
        )

        # Hand the heavy work to the background worker and return immediately.
        background_tasks.add_task(
            process_transcription_task,
            task_id=task_id,
            file_path=file_path,
            output_txt=output_txt,
            model_size=model_size,
            is_diarization=enable_diarization,
            huggingface_token=token,
            analyze=analyze,
            analysis_type=analysis_type,
        )

        return TaskResponse(
            task_id=task_id,
            status="pending",
            message="Transcription task launched successfully",
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error launching asynchronous transcription: {str(e)}")
        raise HTTPException(
            status_code=500, 
            detail=f"Error launching asynchronous transcription: {str(e)}"
        )
521  
@transcription_router.get('/task/{task_id}/result', response_model=None)
async def get_task_result(
    task_id: str,
    request: Request,
    current_user: User = Depends(get_current_active_user)
):
    """Retrieves the result of a transcription task"""
    try:
        task = get_task_status(task_id)
        if not task:
            raise HTTPException(status_code=404, detail="Task not found")

        # Not finished yet (pending/running/failed): report status only.
        if task["status"] != "completed":
            return {
                "status": task["status"],
                "message": task.get("message", "Task is being processed")
            }

        result = task.get("results", {})

        # Apply JSONSimplifier lazily, only when the background worker did
        # not already attach a plain-language explanation.
        if "plain_explanation" not in result:
            simplifier = getattr(request.app.state, "json_simplifier", None)
            if simplifier and simplifier.should_process("transcription"):
                processed = simplifier.process({"result": result}, "transcription")
                if "plain_explanation" in processed:
                    result["plain_explanation"] = processed["plain_explanation"]

        return {
            "status": "completed",
            "result": result,
            "message": task.get("message", "Task completed successfully")
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error retrieving task result: {str(e)}")
        raise HTTPException(
            status_code=500, 
            detail=f"Error retrieving task result: {str(e)}"
        )
570  
@transcription_router.get('/models', response_model=ModelsResponse)
async def get_models():
    """Retrieves information about available transcription models"""
    try:
        info = get_available_models()
        return ModelsResponse(
            whisper=info.get("whisper", {}),
            diarization=info.get("diarization", {}),
        )
    except Exception as e:
        logger.error(f"Error retrieving models: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error retrieving models: {str(e)}")
585  
@transcription_router.get('/allowed_extensions')
async def get_allowed_extensions():
    """Retrieves the list of allowed audio file extensions"""
    # Surface the server-side upload constraints so clients can validate early.
    payload = {"allowed_extensions": list(ALLOWED_EXTENSIONS)}
    payload["max_upload_size_mb"] = model_config["audio"]["max_upload_size_mb"]
    return payload