# services/diarization/service.py
  1  """Bob Diarization Service — real-time speaker diarization + identification.
  2  
  3  Combines diart (streaming diarization) with CAM++ (speaker identification).
  4  Exposes a WebSocket API that accepts audio and returns identified speaker segments.
  5  
  6  Architecture:
  7    Audio in (PCM16 16kHz) → diart (who's talking when) → CAM++ (that's Cam/AJ/Hailen) → JSON out
  8  
  9  WebSocket protocol:
 10    Client sends: raw PCM16 int16 bytes (16kHz mono)
 11    Server sends: JSON {"segments": [{"speaker": "Cam", "start": 1.5, "end": 3.2}]}
 12  """
 13  
 14  import asyncio
 15  import base64
 16  import json
 17  import os
 18  import sys
 19  import threading
 20  import time
 21  
 22  import numpy as np
 23  import torch
 24  
 25  LISTEN_HOST = os.getenv("LISTEN_HOST", "0.0.0.0")
 26  LISTEN_PORT = int(os.getenv("LISTEN_PORT", "7007"))
 27  HF_TOKEN = os.getenv("HF_TOKEN", "")
 28  ENROLLMENT_DB = os.getenv("ENROLLMENT_DB", "/srv/bob/voice-enrollment/speakers.db")
 29  DEVICE = os.getenv("DEVICE", "cuda")
 30  SAMPLE_RATE = 16000
 31  
 32  
 33  def load_diart_pipeline():
 34      """Load diart streaming diarization pipeline."""
 35      from diart import SpeakerDiarization
 36      from diart.blocks import SpeakerDiarizationConfig
 37      from diart.models import SegmentationModel, EmbeddingModel
 38  
 39      print("Loading diart models...")
 40      # diart's from_pretrained passes kwargs to pyannote which uses 'use_auth_token'
 41      # but the diart wrapper may use a different parameter name
 42      try:
 43          seg = SegmentationModel.from_pretrained("pyannote/segmentation-3.0", HF_TOKEN or True)
 44      except Exception as e:
 45          print(f"segmentation-3.0 failed ({e}), trying segmentation")
 46          seg = SegmentationModel.from_pretrained("pyannote/segmentation", HF_TOKEN or True)
 47      emb = EmbeddingModel.from_pretrained("pyannote/wespeaker-voxceleb-resnet34-LM")
 48  
 49      config = SpeakerDiarizationConfig(
 50          segmentation=seg,
 51          embedding=emb,
 52          duration=5,
 53          step=0.5,
 54          latency=0.5,
 55          tau_active=0.5,
 56          rho_update=0.3,
 57          delta_new=1.0,
 58          max_speakers=5,
 59          sample_rate=SAMPLE_RATE,
 60      )
 61      pipeline = SpeakerDiarization(config)
 62      print("diart pipeline loaded")
 63      return pipeline
 64  
 65  
 66  def load_campplus():
 67      """Load CAM++ speaker embedding model from ModelScope."""
 68      try:
 69          from modelscope.hub.snapshot_download import snapshot_download
 70          from speakerlab.models.campplus.DTDNN import CAMPPlus
 71          from speakerlab.process.processor import FBank
 72  
 73          model_id = "iic/speech_campplus_sv_en_voxceleb_16k"
 74          print(f"Downloading CAM++ model: {model_id}...")
 75          cache_dir = snapshot_download(model_id, revision="v1.0.2")
 76  
 77          model = CAMPPlus(feat_dim=80, embedding_size=512)
 78          state = torch.load(f"{cache_dir}/campplus_voxceleb.bin", map_location="cpu")
 79          model.load_state_dict(state)
 80          model.eval()
 81  
 82          device = torch.device(DEVICE if torch.cuda.is_available() else "cpu")
 83          model.to(device)
 84  
 85          fbank = FBank(80, sample_rate=SAMPLE_RATE, mean_nor=True)
 86          print(f"CAM++ loaded on {device}")
 87          return model, fbank, device
 88      except Exception as e:
 89          print(f"WARNING: CAM++ not available: {e}")
 90          print("Speaker identification will be anonymous only.")
 91          return None, None, None
 92  
 93  
 94  def load_enrolled_embeddings():
 95      """Load enrolled speaker embeddings from SQLite."""
 96      import sqlite3
 97      enrolled = {}
 98      try:
 99          conn = sqlite3.connect(ENROLLMENT_DB)
100          rows = conn.execute(
101              "SELECT id, name, embedding FROM speakers WHERE embedding IS NOT NULL"
102          ).fetchall()
103          conn.close()
104          for row in rows:
105              emb = np.frombuffer(row[2], dtype=np.float32)
106              # Use short names
107              name = row[1]
108              short_names = {"Cameron Hunt": "Cam", "Adriane Hunt": "AJ", "Hailen Hunt": "Hailen"}
109              name = short_names.get(name, name.split()[0])
110              enrolled[name] = emb
111              print(f"  Enrolled: {name} ({len(emb)}-dim embedding)")
112      except Exception as e:
113          print(f"  Could not load enrollments: {e}")
114      return enrolled
115  
116  
def identify_speaker(audio_f32, campplus_model, fbank, device, enrolled, threshold=0.45):
    """Identify a speaker from an audio segment using CAM++.

    Args:
        audio_f32: mono float32 audio samples (service rate is 16 kHz).
        campplus_model: CAM++ network, or None when identification is off.
        fbank: feature extractor producing the model's input features.
        device: torch device (or device string) for inference.
        enrolled: dict of name -> enrolled embedding (numpy float32).
        threshold: minimum cosine similarity required to accept a match.

    Returns:
        Best-matching enrolled name, or None when no model/enrollments are
        available, the segment is too short, the best score is below the
        threshold, or inference fails.
    """
    if campplus_model is None or not enrolled:
        return None

    if len(audio_f32) < 8000:  # Need at least 0.5s at 16 kHz
        return None

    try:
        wav = torch.from_numpy(audio_f32).unsqueeze(0).float()
        feat = fbank(wav).unsqueeze(0).to(device)
        with torch.no_grad():
            emb = campplus_model(feat).detach().cpu().numpy().squeeze()

        # Cosine similarity against every enrolled speaker; keep the best.
        best_name = None
        best_score = -1.0
        for name, enrolled_emb in enrolled.items():
            # Skip embeddings from a different model family
            # (e.g. ECAPA-TDNN=192 dims vs CAM++=512 dims).
            if len(emb) != len(enrolled_emb):
                continue
            score = np.dot(emb, enrolled_emb) / (
                np.linalg.norm(emb) * np.linalg.norm(enrolled_emb) + 1e-8
            )
            if score > best_score:
                best_score = score
                best_name = name

        if best_score > threshold:
            return best_name
    except Exception as e:
        # Identification is best-effort, but don't swallow errors silently
        # (the original discarded `e`): log, then fall back to anonymous.
        print(f"  identify_speaker failed: {e}")

    return None
149  
150  
async def run_server():
    """Run the diarization WebSocket server.

    Loads all models, then hands control to diart's StreamingInference,
    which owns the WebSocket accept/read loop. Note that ``inference()`` is
    a blocking call, so this coroutine never yields — acceptable here only
    because nothing else runs on the event loop.
    """
    # Fail fast if the websockets dependency is missing before paying the
    # cost of loading models (diart's WebSocketAudioSource needs it).
    import websockets

    # Load models
    pipeline = load_diart_pipeline()
    campplus_model, fbank, device = load_campplus()
    enrolled = load_enrolled_embeddings()

    print(f"\nDiarization service starting on ws://{LISTEN_HOST}:{LISTEN_PORT}")
    print(f"  Enrolled speakers: {list(enrolled.keys()) if enrolled else 'none'}")

    # Use diart's built-in streaming via its WebSocket server, wrapped with
    # a result hook that adds CAM++ identification on top of the anonymous
    # diarization labels.
    from diart.inference import StreamingInference
    from diart.sources import WebSocketAudioSource

    source = WebSocketAudioSource(SAMPLE_RATE, LISTEN_HOST, LISTEN_PORT)

    inference = StreamingInference(
        pipeline, source,
        batch_size=1,
        do_profile=False,
        do_plot=False,
        show_progress=False,
    )

    # Once an anonymous diart label (e.g. "speaker0") has been identified,
    # remember the mapping so later segments reuse it without re-running
    # CAM++ on every chunk.
    label_map = {}

    def on_result(ann_wav):
        """Per-step hook: receives (Annotation, waveform) from diart."""
        annotation, waveform = ann_wav
        segments = []
        for segment, track, speaker in annotation.itertracks(yield_label=True):
            # Try to resolve the anonymous label to an enrolled speaker.
            real_name = label_map.get(speaker)
            if real_name is None and campplus_model is not None:
                # NOTE(review): segment times appear to be stream-absolute
                # while `waveform` covers only the current window, so this
                # slice assumes both share the same origin — verify against
                # diart's hook contract.
                start_sample = int(segment.start * SAMPLE_RATE)
                end_sample = int(segment.end * SAMPLE_RATE)
                if hasattr(waveform, 'data'):
                    audio_seg = waveform.data[start_sample:end_sample].flatten()
                else:
                    audio_seg = np.array([])
                if len(audio_seg) > 8000:  # need >0.5 s for a usable embedding
                    real_name = identify_speaker(audio_seg, campplus_model, fbank, device, enrolled)
                    if real_name:
                        label_map[speaker] = real_name

            segments.append({
                "speaker": real_name or speaker,
                "start": round(segment.start, 3),
                "end": round(segment.end, 3),
            })

        if segments:
            # Send JSON back over the source's WebSocket. Best-effort: the
            # client may already have disconnected. (A dead `to_rttm()` call
            # that built an unused RTTM string was removed here.)
            try:
                source.send(json.dumps({"segments": segments}))
            except Exception:
                pass

    inference.attach_hooks(on_result)

    print("Running diarization inference...")
    inference()
217  
218  
if __name__ == "__main__":
    # run_server never actually awaits anything; asyncio.run simply provides
    # an event loop wrapper around the blocking diart inference call.
    asyncio.run(run_server())