# transcription_models/diarization.py
  1  """
  2  Speaker diarization module for video transcription
  3  ------------------------------------------------
  4  This module provides functions for speaker identification (diarization)
  5  in audio files and for attributing speakers to transcription segments.
  6  """
  7  
  8  import os
  9  import logging
 10  import traceback
 11  from typing import List, Dict, Tuple, Optional, Callable, Any, Union
 12  
 13  # Logging configuration
 14  logger = logging.getLogger("transcription.diarization")
 15  
 16  # Check optional dependencies
 17  try:
 18      from pyannote.audio import Pipeline
 19      PYANNOTE_AVAILABLE = True
 20  except ImportError:
 21      PYANNOTE_AVAILABLE = False
 22      logger.warning("Pyannote.audio not available. Diarization will be disabled.")
 23  
 24  # Configuration
 25  HUGGINGFACE_TOKEN = os.environ.get("HUGGINGFACE_TOKEN", "")
 26  DIARIZATION_MODEL = "pyannote/speaker-diarization-3.1"
 27  
 28  # Global variable for the model
 29  diarization_pipeline = None
 30  
 31  def load_diarization_model(huggingface_token: Optional[str] = None):
 32      """
 33      Loads the diarization model
 34      
 35      Args:
 36          huggingface_token: Hugging Face token for model access
 37          
 38      Returns:
 39          Diarization pipeline
 40          
 41      Raises:
 42          ImportError: If pyannote.audio is not available
 43          ValueError: If no token is provided
 44      """
 45      global diarization_pipeline
 46      
 47      if not PYANNOTE_AVAILABLE:
 48          raise ImportError("Pyannote.audio is required for diarization")
 49      
 50      token = huggingface_token or HUGGINGFACE_TOKEN
 51      if not token:
 52          raise ValueError("A Hugging Face token is required for diarization")
 53      
 54      if diarization_pipeline is None:
 55          diarization_pipeline = Pipeline.from_pretrained(
 56              DIARIZATION_MODEL,
 57              use_auth_token=token
 58          )
 59      
 60      return diarization_pipeline
 61  
 62  def diarize_audio(
 63      audio_path: str, 
 64      huggingface_token: Optional[str] = None, 
 65      progress: Optional[Callable] = None
 66  ) -> List[Tuple[float, float, str]]:
 67      """
 68      Identifies speakers in an audio file
 69      
 70      Args:
 71          audio_path: Path to the audio file
 72          huggingface_token: Hugging Face token for model access
 73          progress: Progress tracking function (optional)
 74          
 75      Returns:
 76          List of segments with speaker information (start, end, speaker)
 77          
 78      Raises:
 79          ImportError: If pyannote.audio is not available
 80          ValueError: If no token is provided
 81          Exception: If an error occurs during diarization
 82      """
 83      try:
 84          if progress:
 85              progress(0.6, desc="Loading diarization model...")
 86          
 87          # Initialize diarization pipeline
 88          pipeline = load_diarization_model(huggingface_token)
 89          
 90          if progress:
 91              progress(0.7, desc="Speaker identification in progress...")
 92          
 93          # Perform diarization
 94          diarization = pipeline(audio_path)
 95          
 96          # Extract segments with speakers
 97          speaker_segments = []
 98          for segment, _, speaker in diarization.itertracks(yield_label=True):
 99              speaker_segments.append((segment.start, segment.end, speaker))
100          
101          if progress:
102              progress(0.9, desc="Speaker identification completed")
103          
104          return speaker_segments
105          
106      except Exception as e:
107          error_msg = f"Error during diarization: {str(e)}"
108          logger.error(error_msg)
109          logger.error(traceback.format_exc())
110          raise Exception(error_msg)
111  
def assign_speakers(
    transcription: Dict[str, Any],
    diarization: List[Tuple[float, float, str]]
) -> List[Dict[str, Any]]:
    """
    Associates identified speakers with transcription segments.

    Each transcription segment is matched against the diarization turns;
    the speaker with the largest total temporal overlap wins. Segments
    with no overlapping diarization turn are labeled "Unknown".

    Args:
        transcription: Transcription result from Whisper; must contain a
            "segments" list of dicts with "start", "end" and "text" keys.
        diarization: Diarization result from Pyannote as
            (start, end, speaker) tuples.

    Returns:
        List of segments with text and assigned speaker.
    """
    final_transcription: List[Dict[str, Any]] = []

    for segment in transcription["segments"]:
        start, end, text = segment["start"], segment["end"], segment["text"]

        # Accumulate overlap duration per speaker for this segment.
        speaker_times: Dict[str, float] = {}
        for d_start, d_end, d_speaker in diarization:
            overlap = min(d_end, end) - max(d_start, start)
            if overlap > 0:
                speaker_times[d_speaker] = speaker_times.get(d_speaker, 0.0) + overlap

        # Dominant speaker = largest accumulated overlap; "Unknown" if no
        # diarization turn intersects this segment at all.
        speaker = max(speaker_times, key=speaker_times.get) if speaker_times else "Unknown"

        final_transcription.append({
            "start": start,
            "end": end,
            "speaker": speaker,
            "text": text
        })

    return final_transcription
163  
def format_diarized_transcription(
    transcription: List[Dict[str, Any]],
    include_timestamps: bool = True
) -> str:
    """
    Formats the diarized transcription into readable text.

    A speaker label is emitted (on a new line) only when the speaker
    changes; consecutive segments from the same speaker are appended to
    the same line.

    Args:
        transcription: List of segments with text and speaker
        include_timestamps: Include timestamps in the output

    Returns:
        Formatted text with speakers.
    """
    pieces: List[str] = []
    last_speaker = None

    for seg in transcription:
        who = seg["speaker"]
        content = seg["text"].strip()
        # Optional "[hh:mm:ss-hh:mm:ss] " prefix for every segment.
        stamp = (
            f"[{format_time(seg['start'])}-{format_time(seg['end'])}] "
            if include_timestamps
            else ""
        )

        if who != last_speaker:
            # Speaker change: start a new line with the speaker label.
            last_speaker = who
            pieces.append(f"\n{stamp}{who}: {content}")
        else:
            # Same speaker: continue on the current line.
            pieces.append(f" {stamp}{content}")

    return "".join(pieces).strip()
201  
def format_time(seconds: float) -> str:
    """
    Formats seconds into hh:mm:ss format.

    Args:
        seconds: Number of seconds

    Returns:
        Zero-padded "hh:mm:ss" string.
    """
    minutes, secs = divmod(seconds, 60)
    hours, minutes = divmod(minutes, 60)
    # Truncate each component to an integer for display.
    return ":".join(f"{int(part):02d}" for part in (hours, minutes, secs))