# transcription_models/whisper_utils.py
1 """ 2 Whisper usage module for audio transcription 3 ---------------------------------------------------------- 4 This module provides functions for audio transcription with the Whisper model. 5 """ 6 7 import os 8 import gc 9 import logging 10 import traceback 11 from typing import Dict, Any, Optional, Callable, Union 12 import torch 13 import whisper 14 # Logging 15 logger = logging.getLogger("transcription.whisper") 16 WHISPER_AVAILABLE = True 17 18 19 # Global model for reuse 20 whisper_model = None 21 current_model_size = None 22 23 # Configuration 24 WHISPER_MODEL_SIZE = os.environ.get("WHISPER_MODEL_SIZE", "medium") 25 DEVICE = "cuda" if torch.cuda.is_available() else "cpu" 26 27 def get_whisper_model(model_size: Optional[str] = None) -> Any: 28 """ 29 Loads or retrieves the Whisper model 30 31 Args: 32 model_size: Size of the Whisper model to use ('tiny', 'base', 'small', 'medium', 'large') 33 34 Returns: 35 Whisper model instance 36 37 Raises: 38 ImportError: If Whisper is not available 39 """ 40 global whisper_model, current_model_size 41 42 if not WHISPER_AVAILABLE: 43 raise ImportError("Whisper is required for transcription") 44 45 # Define model size 46 selected_size = model_size or WHISPER_MODEL_SIZE 47 48 # Check if we need to load a new model 49 if whisper_model is None or current_model_size != selected_size: 50 # Free memory if a model was already loaded 51 if whisper_model is not None: 52 del whisper_model 53 gc.collect() 54 if torch.cuda.is_available(): 55 torch.cuda.empty_cache() 56 57 logger.info(f"Loading Whisper model {selected_size}...") 58 whisper_model = whisper.load_model(selected_size, device=DEVICE) 59 current_model_size = selected_size 60 61 return whisper_model 62 63 def transcribe_audio( 64 audio_path: str, 65 model_size: Optional[str] = None, 66 language: Optional[str] = None, 67 progress: Optional[Callable] = None, 68 **whisper_options 69 ) -> Dict[str, Any]: 70 """ 71 Transcribes an audio file to text 72 73 Args: 74 audio_path: Path to the audio file 75 model_size: Size of the Whisper model to use 76 language: Language code for transcription (e.g., 'fr', 'en') 77 progress: Progress tracking function (optional) 78 whisper_options: Additional options to pass to Whisper 79 80 Returns: 81 Dictionary containing the transcription results 82 83 Raises: 84 ImportError: If Whisper is not available 85 Exception: If an error occurs during transcription 86 """ 87 try: 88 if progress: 89 progress(0.4, desc="Loading transcription model...") 90 91 # Load the Whisper model 92 model = get_whisper_model(model_size) 93 94 if progress: 95 progress(0.5, desc="Audio transcription in progress...") 96 97 # Prepare transcription options 98 options = { 99 "fp16": torch.cuda.is_available(), 100 "verbose": False 101 } 102 103 # Add language if specified 104 if language: 105 options["language"] = language 106 107 # Add additional options 108 options.update(whisper_options) 109 110 # Transcribe the audio 111 result = model.transcribe(audio_path, **options) 112 113 if progress: 114 progress(0.8, desc="Transcription completed") 115 116 return result 117 118 except Exception as e: 119 error_msg = f"Error during transcription: {str(e)}" 120 logger.error(error_msg) 121 logger.error(traceback.format_exc()) 122 raise Exception(error_msg) 123 124 def cleanup_whisper_model() -> bool: 125 """ 126 Frees the Whisper model memory 127 128 Returns: 129 True if the model was freed, False otherwise 130 """ 131 global whisper_model, current_model_size 132 133 if whisper_model is not None: 134 try: 
135 del whisper_model 136 whisper_model = None 137 current_model_size = None 138 gc.collect() 139 140 if torch.cuda.is_available(): 141 torch.cuda.empty_cache() 142 143 return True 144 except Exception as e: 145 logger.error(f"Error when freeing the Whisper model: {str(e)}") 146 147 return False 148 149 def get_available_whisper_models() -> Dict[str, Dict[str, Any]]: 150 """ 151 Returns available Whisper models with their characteristics 152 153 Returns: 154 Dictionary of available models 155 """ 156 if not WHISPER_AVAILABLE: 157 return {} 158 159 return { 160 "tiny": {"parameters": "39M", "english_only": False, "multilingual": True, "required_vram": "1 GB"}, 161 "base": {"parameters": "74M", "english_only": False, "multilingual": True, "required_vram": "1 GB"}, 162 "small": {"parameters": "244M", "english_only": False, "multilingual": True, "required_vram": "2 GB"}, 163 "medium": {"parameters": "769M", "english_only": False, "multilingual": True, "required_vram": "5 GB"}, 164 "large": {"parameters": "1550M", "english_only": False, "multilingual": True, "required_vram": "10 GB"} 165 } 166 167 def format_whisper_result(result: Dict[str, Any], include_timestamps: bool = True) -> str: 168 """ 169 Formats the Whisper result into readable text 170 171 Args: 172 result: Whisper transcription result 173 include_timestamps: Include timestamps in the output 174 175 Returns: 176 Formatted text 177 """ 178 if not include_timestamps: 179 return result["text"].strip() 180 181 formatted_text = [] 182 for segment in result["segments"]: 183 start = format_time(segment["start"]) 184 end = format_time(segment["end"]) 185 text = segment["text"].strip() 186 formatted_text.append(f"[{start}-{end}] {text}") 187 188 return "\n".join(formatted_text) 189 190 def format_time(seconds: float) -> str: 191 """ 192 Formats seconds into hh:mm:ss format 193 194 Args: 195 seconds: Number of seconds 196 197 Returns: 198 Formatted string 199 """ 200 m, s = divmod(seconds, 60) 201 h, m = divmod(m, 60) 202 return f"{int(h):02d}:{int(m):02d}:{int(s):02d}"
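

# Minimal usage sketch: run the module directly to transcribe a single file and
# print a timestamped transcript. The path "sample.wav" and the model/language
# choices below are placeholders, not part of the module's API; adjust them to
# your environment. Requires the `whisper` package and, ideally, a CUDA GPU.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    sample_audio = "sample.wav"  # placeholder path, replace with a real audio file

    if WHISPER_AVAILABLE and os.path.exists(sample_audio):
        # Transcribe with the small model and print the formatted result
        transcription = transcribe_audio(sample_audio, model_size="small", language="en")
        print(format_whisper_result(transcription, include_timestamps=True))

        # Release the model once we are done
        cleanup_whisper_model()
    else:
        print("Whisper is not installed or the sample audio file is missing; nothing to do.")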