# transcription_models/diarization.py
1 """ 2 Speaker diarization module for video transcription 3 ------------------------------------------------ 4 This module provides functions for speaker identification (diarization) 5 in audio files and for attributing speakers to transcription segments. 6 """ 7 8 import os 9 import logging 10 import traceback 11 from typing import List, Dict, Tuple, Optional, Callable, Any, Union 12 13 # Logging configuration 14 logger = logging.getLogger("transcription.diarization") 15 16 # Check optional dependencies 17 try: 18 from pyannote.audio import Pipeline 19 PYANNOTE_AVAILABLE = True 20 except ImportError: 21 PYANNOTE_AVAILABLE = False 22 logger.warning("Pyannote.audio not available. Diarization will be disabled.") 23 24 # Configuration 25 HUGGINGFACE_TOKEN = os.environ.get("HUGGINGFACE_TOKEN", "") 26 DIARIZATION_MODEL = "pyannote/speaker-diarization-3.1" 27 28 # Global variable for the model 29 diarization_pipeline = None 30 31 def load_diarization_model(huggingface_token: Optional[str] = None): 32 """ 33 Loads the diarization model 34 35 Args: 36 huggingface_token: Hugging Face token for model access 37 38 Returns: 39 Diarization pipeline 40 41 Raises: 42 ImportError: If pyannote.audio is not available 43 ValueError: If no token is provided 44 """ 45 global diarization_pipeline 46 47 if not PYANNOTE_AVAILABLE: 48 raise ImportError("Pyannote.audio is required for diarization") 49 50 token = huggingface_token or HUGGINGFACE_TOKEN 51 if not token: 52 raise ValueError("A Hugging Face token is required for diarization") 53 54 if diarization_pipeline is None: 55 diarization_pipeline = Pipeline.from_pretrained( 56 DIARIZATION_MODEL, 57 use_auth_token=token 58 ) 59 60 return diarization_pipeline 61 62 def diarize_audio( 63 audio_path: str, 64 huggingface_token: Optional[str] = None, 65 progress: Optional[Callable] = None 66 ) -> List[Tuple[float, float, str]]: 67 """ 68 Identifies speakers in an audio file 69 70 Args: 71 audio_path: Path to the audio file 72 huggingface_token: Hugging Face token for model access 73 progress: Progress tracking function (optional) 74 75 Returns: 76 List of segments with speaker information (start, end, speaker) 77 78 Raises: 79 ImportError: If pyannote.audio is not available 80 ValueError: If no token is provided 81 Exception: If an error occurs during diarization 82 """ 83 try: 84 if progress: 85 progress(0.6, desc="Loading diarization model...") 86 87 # Initialize diarization pipeline 88 pipeline = load_diarization_model(huggingface_token) 89 90 if progress: 91 progress(0.7, desc="Speaker identification in progress...") 92 93 # Perform diarization 94 diarization = pipeline(audio_path) 95 96 # Extract segments with speakers 97 speaker_segments = [] 98 for segment, _, speaker in diarization.itertracks(yield_label=True): 99 speaker_segments.append((segment.start, segment.end, speaker)) 100 101 if progress: 102 progress(0.9, desc="Speaker identification completed") 103 104 return speaker_segments 105 106 except Exception as e: 107 error_msg = f"Error during diarization: {str(e)}" 108 logger.error(error_msg) 109 logger.error(traceback.format_exc()) 110 raise Exception(error_msg) 111 112 def assign_speakers( 113 transcription: Dict[str, Any], 114 diarization: List[Tuple[float, float, str]] 115 ) -> List[Dict[str, Any]]: 116 """ 117 Associates identified speakers with transcription segments 118 119 Args: 120 transcription: Transcription result from Whisper 121 diarization: Diarization result from Pyannote 122 123 Returns: 124 List of segments with text and assigned speaker 
125 """ 126 final_transcription = [] 127 128 segments = transcription["segments"] 129 130 for segment in segments: 131 start, end, text = segment["start"], segment["end"], segment["text"] 132 speaker = "Unknown" 133 134 # Find the main speaker for this segment 135 speaker_times = {} 136 137 for d_start, d_end, d_speaker in diarization: 138 # Calculate overlap 139 overlap_start = max(d_start, start) 140 overlap_end = min(d_end, end) 141 142 if overlap_start < overlap_end: 143 overlap_duration = overlap_end - overlap_start 144 145 if d_speaker in speaker_times: 146 speaker_times[d_speaker] += overlap_duration 147 else: 148 speaker_times[d_speaker] = overlap_duration 149 150 # Select the speaker with the most speaking time in this segment 151 if speaker_times: 152 speaker = max(speaker_times, key=speaker_times.get) 153 154 # Add the segment with its speaker 155 final_transcription.append({ 156 "start": start, 157 "end": end, 158 "speaker": speaker, 159 "text": text 160 }) 161 162 return final_transcription 163 164 def format_diarized_transcription( 165 transcription: List[Dict[str, Any]], 166 include_timestamps: bool = True 167 ) -> str: 168 """ 169 Formats the diarized transcription into readable text 170 171 Args: 172 transcription: List of segments with text and speaker 173 include_timestamps: Include timestamps in the output 174 175 Returns: 176 Formatted text with speakers 177 """ 178 formatted_text = [] 179 current_speaker = None 180 181 for segment in transcription: 182 speaker = segment["speaker"] 183 text = segment["text"].strip() 184 start = segment["start"] 185 end = segment["end"] 186 187 # Format the segment text 188 if speaker != current_speaker: 189 current_speaker = speaker 190 if include_timestamps: 191 formatted_text.append(f"\n[{format_time(start)}-{format_time(end)}] {speaker}: {text}") 192 else: 193 formatted_text.append(f"\n{speaker}: {text}") 194 else: 195 if include_timestamps: 196 formatted_text.append(f" [{format_time(start)}-{format_time(end)}] {text}") 197 else: 198 formatted_text.append(f" {text}") 199 200 return "".join(formatted_text).strip() 201 202 def format_time(seconds: float) -> str: 203 """ 204 Formats seconds into hh:mm:ss format 205 206 Args: 207 seconds: Number of seconds 208 209 Returns: 210 Formatted string 211 """ 212 m, s = divmod(seconds, 60) 213 h, m = divmod(m, 60) 214 return f"{int(h):02d}:{int(m):02d}:{int(s):02d}"