/ transcription_models / whisper_utils.py
whisper_utils.py
  1  """
  2  Whisper usage module for audio transcription
  3  ----------------------------------------------------------
  4  This module provides functions for audio transcription with the Whisper model.
  5  """
  6  
  7  import os
  8  import gc
  9  import logging
 10  import traceback
 11  from typing import Dict, Any, Optional, Callable, Union
 12  import torch
 13  import whisper
 14  # Logging
 15  logger = logging.getLogger("transcription.whisper")
 16  WHISPER_AVAILABLE = True
 17  
 18  
 19  # Global model for reuse
 20  whisper_model = None
 21  current_model_size = None
 22  
 23  # Configuration
 24  WHISPER_MODEL_SIZE = os.environ.get("WHISPER_MODEL_SIZE", "medium")
 25  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 26  
 27  def get_whisper_model(model_size: Optional[str] = None) -> Any:
 28      """
 29      Loads or retrieves the Whisper model
 30      
 31      Args:
 32          model_size: Size of the Whisper model to use ('tiny', 'base', 'small', 'medium', 'large')
 33          
 34      Returns:
 35          Whisper model instance
 36          
 37      Raises:
 38          ImportError: If Whisper is not available
 39      """
 40      global whisper_model, current_model_size
 41      
 42      if not WHISPER_AVAILABLE:
 43          raise ImportError("Whisper is required for transcription")
 44      
 45      # Define model size
 46      selected_size = model_size or WHISPER_MODEL_SIZE
 47      
 48      # Check if we need to load a new model
 49      if whisper_model is None or current_model_size != selected_size:
 50          # Free memory if a model was already loaded
 51          if whisper_model is not None:
 52              del whisper_model
 53              gc.collect()
 54              if torch.cuda.is_available():
 55                  torch.cuda.empty_cache()
 56          
 57          logger.info(f"Loading Whisper model {selected_size}...")
 58          whisper_model = whisper.load_model(selected_size, device=DEVICE)
 59          current_model_size = selected_size
 60          
 61      return whisper_model
 62  
 63  def transcribe_audio(
 64      audio_path: str, 
 65      model_size: Optional[str] = None, 
 66      language: Optional[str] = None,
 67      progress: Optional[Callable] = None,
 68      **whisper_options
 69  ) -> Dict[str, Any]:
 70      """
 71      Transcribes an audio file to text
 72      
 73      Args:
 74          audio_path: Path to the audio file
 75          model_size: Size of the Whisper model to use
 76          language: Language code for transcription (e.g., 'fr', 'en')
 77          progress: Progress tracking function (optional)
 78          whisper_options: Additional options to pass to Whisper
 79          
 80      Returns:
 81          Dictionary containing the transcription results
 82          
 83      Raises:
 84          ImportError: If Whisper is not available
 85          Exception: If an error occurs during transcription
 86      """
 87      try:
 88          if progress:
 89              progress(0.4, desc="Loading transcription model...")
 90          
 91          # Load the Whisper model
 92          model = get_whisper_model(model_size)
 93          
 94          if progress:
 95              progress(0.5, desc="Audio transcription in progress...")
 96          
 97          # Prepare transcription options
 98          options = {
 99              "fp16": torch.cuda.is_available(),
100              "verbose": False
101          }
102          
103          # Add language if specified
104          if language:
105              options["language"] = language
106              
107          # Add additional options
108          options.update(whisper_options)
109          
110          # Transcribe the audio
111          result = model.transcribe(audio_path, **options)
112          
113          if progress:
114              progress(0.8, desc="Transcription completed")
115          
116          return result
117          
118      except Exception as e:
119          error_msg = f"Error during transcription: {str(e)}"
120          logger.error(error_msg)
121          logger.error(traceback.format_exc())
122          raise Exception(error_msg)
123  
def cleanup_whisper_model() -> bool:
    """
    Release the cached Whisper model and reclaim its memory.

    Clears the module-level model cache and runs the garbage collector. When
    CUDA is available, it also empties the CUDA cache. Exceptions raised during
    cleanup are logged instead of propagated.

    Returns:
        True if a cached model was released. False if nothing was cached or
        cleanup failed.
    """
    global whisper_model, current_model_size

    # Nothing cached: nothing to release.
    if whisper_model is None:
        return False

    try:
        # Dropping the last module reference lets the collector reclaim it.
        whisper_model = None
        current_model_size = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    except Exception as exc:
        logger.error(f"Error when freeing the Whisper model: {str(exc)}")
        return False

    return True
148  
def get_available_whisper_models() -> Dict[str, Dict[str, Any]]:
    """
    Describe the Whisper model sizes offered by this module.

    Returns:
        A newly built dict mapping each size name, from 'tiny' up to 'large',
        to its parameter count, language flags and approximate VRAM
        requirement. Returns an empty dict when ``WHISPER_AVAILABLE`` is false.
    """
    if WHISPER_AVAILABLE:
        # (name, parameters, required_vram), in ascending size order.
        catalogue = (
            ("tiny", "39M", "1 GB"),
            ("base", "74M", "1 GB"),
            ("small", "244M", "2 GB"),
            ("medium", "769M", "5 GB"),
            ("large", "1550M", "10 GB"),
        )
        return {
            name: {
                "parameters": params,
                "english_only": False,
                "multilingual": True,
                "required_vram": vram,
            }
            for name, params, vram in catalogue
        }
    return {}
166  
def format_whisper_result(result: Dict[str, Any], include_timestamps: bool = True) -> str:
    """
    Render a Whisper transcription result as readable text.

    Args:
        result: Whisper transcription result. Reads ``"text"`` when timestamps
            are off; otherwise reads ``"segments"``, each with ``"start"``,
            ``"end"`` and ``"text"``.
        include_timestamps: When True, emit one ``[start-end] text`` line per
            segment, with times formatted by ``format_time``. When False,
            return the whole transcript with surrounding whitespace stripped.

    Returns:
        Formatted text
    """
    if include_timestamps:
        lines = [
            f"[{format_time(seg['start'])}-{format_time(seg['end'])}] {seg['text'].strip()}"
            for seg in result["segments"]
        ]
        return "\n".join(lines)
    return result["text"].strip()
189  
def format_time(seconds: float) -> str:
    """
    Format a number of seconds as an ``hh:mm:ss`` string.

    Each field is zero-padded to two digits. Fractional seconds are dropped,
    and hours are not wrapped at 24, so 100 hours renders as "100:00:00".

    Args:
        seconds: Number of seconds

    Returns:
        Formatted string
    """
    whole_minutes, remainder = divmod(seconds, 60)
    hours, minutes = divmod(whole_minutes, 60)
    return ":".join(f"{int(part):02d}" for part in (hours, minutes, remainder))