# service.py
"""Bob Diarization Service — real-time speaker diarization + identification.

Combines diart (streaming diarization) with CAM++ (speaker identification).
Exposes a WebSocket API that accepts audio and returns identified speaker segments.

Architecture:
    Audio in (PCM16 16kHz) → diart (who's talking when) → CAM++ (that's Cam/AJ/Hailen) → JSON out

WebSocket protocol:
    Client sends: raw PCM16 int16 bytes (16kHz mono)
    Server sends: JSON {"segments": [{"speaker": "Cam", "start": 1.5, "end": 3.2}]}
"""

import asyncio
import base64
import json
import os
import sys
import threading
import time

import numpy as np
import torch

LISTEN_HOST = os.getenv("LISTEN_HOST", "0.0.0.0")
LISTEN_PORT = int(os.getenv("LISTEN_PORT", "7007"))
HF_TOKEN = os.getenv("HF_TOKEN", "")
ENROLLMENT_DB = os.getenv("ENROLLMENT_DB", "/srv/bob/voice-enrollment/speakers.db")
DEVICE = os.getenv("DEVICE", "cuda")
SAMPLE_RATE = 16000
# ~0.5 s at 16 kHz: the minimum audio length worth attempting identification on.
MIN_ID_SAMPLES = 8000


def load_diart_pipeline():
    """Load and configure the diart streaming diarization pipeline.

    Tries pyannote/segmentation-3.0 first and falls back to the older
    pyannote/segmentation checkpoint if that fails (e.g. gated model access).

    Returns:
        A configured ``diart.SpeakerDiarization`` pipeline.
    """
    from diart import SpeakerDiarization
    from diart.blocks import SpeakerDiarizationConfig
    from diart.models import SegmentationModel, EmbeddingModel

    print("Loading diart models...")
    # diart's from_pretrained passes kwargs to pyannote which uses 'use_auth_token'
    # but the diart wrapper may use a different parameter name — pass the token
    # positionally (or True to use the cached login) to stay compatible.
    try:
        seg = SegmentationModel.from_pretrained("pyannote/segmentation-3.0", HF_TOKEN or True)
    except Exception as e:
        print(f"segmentation-3.0 failed ({e}), trying segmentation")
        seg = SegmentationModel.from_pretrained("pyannote/segmentation", HF_TOKEN or True)
    emb = EmbeddingModel.from_pretrained("pyannote/wespeaker-voxceleb-resnet34-LM")

    config = SpeakerDiarizationConfig(
        segmentation=seg,
        embedding=emb,
        duration=5,        # sliding-window length (s)
        step=0.5,          # hop between windows (s)
        latency=0.5,       # output latency budget (s)
        tau_active=0.5,    # speaker-activity threshold
        rho_update=0.3,    # centroid update rate
        delta_new=1.0,     # distance to spawn a new speaker
        max_speakers=5,
        sample_rate=SAMPLE_RATE,
    )
    pipeline = SpeakerDiarization(config)
    print("diart pipeline loaded")
    return pipeline


def load_campplus():
    """Load the CAM++ speaker-embedding model from ModelScope.

    Returns:
        Tuple ``(model, fbank, device)`` on success, or ``(None, None, None)``
        when the model stack is unavailable — the service then runs with
        anonymous diarization labels only.
    """
    try:
        from modelscope.hub.snapshot_download import snapshot_download
        from speakerlab.models.campplus.DTDNN import CAMPPlus
        from speakerlab.process.processor import FBank

        model_id = "iic/speech_campplus_sv_en_voxceleb_16k"
        print(f"Downloading CAM++ model: {model_id}...")
        cache_dir = snapshot_download(model_id, revision="v1.0.2")

        model = CAMPPlus(feat_dim=80, embedding_size=512)
        state = torch.load(f"{cache_dir}/campplus_voxceleb.bin", map_location="cpu")
        model.load_state_dict(state)
        model.eval()

        device = torch.device(DEVICE if torch.cuda.is_available() else "cpu")
        model.to(device)

        fbank = FBank(80, sample_rate=SAMPLE_RATE, mean_nor=True)
        print(f"CAM++ loaded on {device}")
        return model, fbank, device
    except Exception as e:
        # Deliberate best-effort: missing modelscope/speakerlab degrades the
        # service to anonymous labels instead of crashing it.
        print(f"WARNING: CAM++ not available: {e}")
        print("Speaker identification will be anonymous only.")
        return None, None, None


def load_enrolled_embeddings():
    """Load enrolled speaker embeddings from the SQLite enrollment DB.

    Returns:
        Dict mapping short speaker name -> float32 numpy embedding.
        Empty dict when the DB is missing or unreadable (best-effort).
    """
    import sqlite3
    enrolled = {}
    # Map full enrollment names to the short names used elsewhere in Bob.
    short_names = {"Cameron Hunt": "Cam", "Adriane Hunt": "AJ", "Hailen Hunt": "Hailen"}
    try:
        conn = sqlite3.connect(ENROLLMENT_DB)
        try:
            rows = conn.execute(
                "SELECT id, name, embedding FROM speakers WHERE embedding IS NOT NULL"
            ).fetchall()
        finally:
            # Always release the handle, even if the query raises
            # (the original leaked the connection on query failure).
            conn.close()
        for _speaker_id, full_name, blob in rows:
            emb = np.frombuffer(blob, dtype=np.float32)
            # Use short names; unknown speakers fall back to their first name.
            name = short_names.get(full_name, full_name.split()[0])
            enrolled[name] = emb
            print(f" Enrolled: {name} ({len(emb)}-dim embedding)")
    except Exception as e:
        print(f" Could not load enrollments: {e}")
    return enrolled


def identify_speaker(audio_f32, campplus_model, fbank, device, enrolled, threshold=0.45):
    """Identify the speaker of an audio segment using CAM++.

    Args:
        audio_f32: Mono float32 waveform (expected 16 kHz, per SAMPLE_RATE).
        campplus_model: CAM++ model, or None when identification is disabled.
        fbank: Feature extractor producing the model's input features.
        device: Torch device to run feature extraction / inference on.
        enrolled: Dict of name -> enrolled embedding (float32 numpy array).
        threshold: Minimum cosine similarity to accept a match.

    Returns:
        The best-matching enrolled name, or None when identification is
        disabled, the segment is too short, no match clears the threshold,
        or inference fails.
    """
    if campplus_model is None or not enrolled:
        return None

    if len(audio_f32) < MIN_ID_SAMPLES:  # need at least ~0.5 s of audio
        return None

    try:
        wav = torch.from_numpy(audio_f32).unsqueeze(0).float()
        feat = fbank(wav).unsqueeze(0).to(device)
        with torch.no_grad():
            emb = campplus_model(feat).detach().cpu().numpy().squeeze()

        # Cosine similarity against every enrolled speaker; keep the best.
        best_name = None
        best_score = -1
        for name, enrolled_emb in enrolled.items():
            # Skip embeddings from a different model (ECAPA-TDNN=192, CAM++=512).
            if len(emb) != len(enrolled_emb):
                continue
            score = np.dot(emb, enrolled_emb) / (np.linalg.norm(emb) * np.linalg.norm(enrolled_emb) + 1e-8)
            if score > best_score:
                best_score = score
                best_name = name

        if best_score > threshold:
            return best_name
    except Exception as e:
        # Best-effort: failures fall back to anonymous labels, but log them
        # so a persistently broken identification path is visible.
        print(f"identify_speaker failed: {e}")

    return None


async def run_server():
    """Run the diarization WebSocket server (blocks until the source closes)."""
    import websockets  # noqa: F401 — fail fast here if the websocket stack is missing

    # Load models
    pipeline = load_diart_pipeline()
    campplus_model, fbank, device = load_campplus()
    enrolled = load_enrolled_embeddings()

    print(f"\nDiarization service starting on ws://{LISTEN_HOST}:{LISTEN_PORT}")
    print(f" Enrolled speakers: {list(enrolled.keys()) if enrolled else 'none'}")

    # Use diart's built-in streaming via its WebSocket server,
    # wrapped so CAM++ identification is layered on top of the anonymous labels.
    from diart.inference import StreamingInference
    from diart.sources import WebSocketAudioSource

    source = WebSocketAudioSource(SAMPLE_RATE, LISTEN_HOST, LISTEN_PORT)

    inference = StreamingInference(
        pipeline, source,
        batch_size=1,
        do_profile=False,
        do_plot=False,
        show_progress=False,
    )

    # Anonymous diart label (e.g. "speaker0") → identified real name.
    label_map = {}

    def on_result(ann_wav):
        """Hook: enrich one diarization result with identified names, send JSON."""
        annotation, waveform = ann_wav
        segments = []
        for segment, track, speaker in annotation.itertracks(yield_label=True):
            # Try to identify the speaker; cache hits so each anonymous
            # label is only resolved once per session.
            real_name = label_map.get(speaker)
            if real_name is None and campplus_model is not None:
                start_sample = int(segment.start * SAMPLE_RATE)
                end_sample = int(segment.end * SAMPLE_RATE)
                if hasattr(waveform, 'data'):
                    audio_seg = waveform.data[start_sample:end_sample].flatten()
                else:
                    audio_seg = np.array([])
                if len(audio_seg) > MIN_ID_SAMPLES:
                    real_name = identify_speaker(audio_seg, campplus_model, fbank, device, enrolled)
                    if real_name:
                        label_map[speaker] = real_name

            segments.append({
                "speaker": real_name or speaker,
                "start": round(segment.start, 3),
                "end": round(segment.end, 3),
            })

        if segments:
            # Send JSON via the source's WebSocket. Best-effort: the client
            # may have disconnected between results — log instead of crashing
            # the inference loop.
            try:
                source.send(json.dumps({"segments": segments}))
            except Exception as e:
                print(f"send failed: {e}")

    inference.attach_hooks(on_result)

    print("Running diarization inference...")
    inference()


if __name__ == "__main__":
    asyncio.run(run_server())