service.py
1 """Bob Voice Enrollment Service — records family voice samples and extracts speaker embeddings. 2 3 Provides HTTP API for: 4 - Recording voice samples (wake word + natural speech) 5 - Extracting CAM++ speaker embeddings 6 - Storing enrollment data (embeddings + raw audio for wake word retraining) 7 - Listing enrolled speakers 8 - Identifying speaker from audio sample 9 10 Storage layout: 11 /srv/bob/voice-enrollment/ 12 speakers.db — SQLite: speaker metadata + embeddings 13 recordings/ 14 cam/ 15 hey_bob_001.wav — wake word samples 16 enroll_001.wav — enrollment speech samples 17 aj/ 18 ... 19 """ 20 21 import asyncio 22 import io 23 import json 24 import os 25 import sqlite3 26 import struct 27 import time 28 import wave 29 from datetime import datetime, timezone 30 from http.server import HTTPServer, BaseHTTPRequestHandler 31 from pathlib import Path 32 from urllib.parse import parse_qs, urlparse 33 34 import numpy as np 35 36 DATA_DIR = Path(os.getenv("DATA_DIR", "/srv/bob/voice-enrollment")) 37 LISTEN_PORT = int(os.getenv("LISTEN_PORT", "10800")) 38 EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "cam++") 39 40 # CAM++ or ECAPA-TDNN 41 embedding_extractor = None 42 43 44 def init_db(): 45 """Initialize SQLite database for speaker enrollment.""" 46 db_path = DATA_DIR / "speakers.db" 47 DATA_DIR.mkdir(parents=True, exist_ok=True) 48 conn = sqlite3.connect(str(db_path)) 49 conn.execute(""" 50 CREATE TABLE IF NOT EXISTS speakers ( 51 id TEXT PRIMARY KEY, 52 name TEXT NOT NULL, 53 family_role TEXT, 54 embedding BLOB, 55 embedding_dim INTEGER, 56 enrollment_samples INTEGER DEFAULT 0, 57 wakeword_samples INTEGER DEFAULT 0, 58 created_at TEXT, 59 updated_at TEXT 60 ) 61 """) 62 conn.execute(""" 63 CREATE TABLE IF NOT EXISTS recordings ( 64 id INTEGER PRIMARY KEY AUTOINCREMENT, 65 speaker_id TEXT NOT NULL, 66 type TEXT NOT NULL, -- 'wakeword' or 'enrollment' 67 filename TEXT NOT NULL, 68 duration_secs REAL, 69 created_at TEXT, 70 FOREIGN KEY (speaker_id) REFERENCES speakers(id) 71 ) 72 """) 73 conn.commit() 74 return conn 75 76 77 def init_embedding_model(): 78 """Initialize the CAM++ speaker embedding model from 3D-Speaker.""" 79 global embedding_extractor 80 try: 81 import torch 82 from modelscope.hub.snapshot_download import snapshot_download 83 from speakerlab.models.campplus.DTDNN import CAMPPlus 84 from speakerlab.process.processor import FBank 85 86 model_id = "iic/speech_campplus_sv_en_voxceleb_16k" 87 print(f"Loading CAM++ model: {model_id}...") 88 cache_dir = snapshot_download(model_id, revision="v1.0.2") 89 90 model = CAMPPlus(feat_dim=80, embedding_size=512) 91 state = torch.load(f"{cache_dir}/campplus_voxceleb.bin", map_location="cpu", weights_only=True) 92 model.load_state_dict(state) 93 model.eval() 94 95 fbank = FBank(80, sample_rate=16000, mean_nor=True) 96 embedding_extractor = {"model": model, "fbank": fbank, "device": torch.device("cpu")} 97 print(f"Loaded CAM++ speaker embedding model (CPU, 512-dim)") 98 except Exception as e: 99 print(f"WARNING: Could not load CAM++ model: {e}") 100 print("Speaker identification will be unavailable. Enrollment recordings will still be saved.") 101 102 103 def extract_embedding(audio_bytes: bytes) -> np.ndarray | None: 104 """Extract 512-dim speaker embedding from WAV audio bytes using CAM++.""" 105 if embedding_extractor is None: 106 return None 107 108 try: 109 import torch 110 import soundfile as sf 111 112 model = embedding_extractor["model"] 113 fbank = embedding_extractor["fbank"] 114 device = embedding_extractor["device"] 115 116 # Load audio from bytes using soundfile (avoids torchcodec dep) 117 buf = io.BytesIO(audio_bytes) 118 data, sr = sf.read(buf, dtype="float32") 119 120 # Resample to 16kHz if needed 121 if sr != 16000: 122 import torchaudio 123 waveform = torch.from_numpy(data).unsqueeze(0) if data.ndim == 1 else torch.from_numpy(data.T) 124 waveform = torchaudio.functional.resample(waveform, sr, 16000) 125 else: 126 waveform = torch.from_numpy(data).unsqueeze(0) if data.ndim == 1 else torch.from_numpy(data.T) 127 128 # Mono 129 if waveform.shape[0] > 1: 130 waveform = waveform.mean(dim=0, keepdim=True) 131 132 # Extract features and embedding 133 feat = fbank(waveform).unsqueeze(0).to(device) 134 with torch.no_grad(): 135 emb = model(feat).detach().cpu().numpy().squeeze() 136 137 return emb # shape: (512,) 138 except Exception as e: 139 print(f"Embedding extraction error: {e}") 140 return None 141 142 143 def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float: 144 """Compute cosine similarity between two vectors.""" 145 return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-8)) 146 147 148 class EnrollmentHandler(BaseHTTPRequestHandler): 149 """HTTP handler for voice enrollment API.""" 150 151 def _send_json(self, status: int, data: dict): 152 self.send_response(status) 153 self.send_header("Content-Type", "application/json") 154 self.send_header("Access-Control-Allow-Origin", "*") 155 self.send_header("Access-Control-Allow-Methods", "GET, POST, OPTIONS") 156 self.send_header("Access-Control-Allow-Headers", "Content-Type") 157 self.end_headers() 158 self.wfile.write(json.dumps(data).encode()) 159 160 def do_OPTIONS(self): 161 self._send_json(200, {}) 162 163 def do_GET(self): 164 path = urlparse(self.path).path 165 166 if path == "/speakers": 167 conn = init_db() 168 rows = conn.execute( 169 "SELECT id, name, family_role, enrollment_samples, wakeword_samples FROM speakers" 170 ).fetchall() 171 conn.close() 172 speakers = [ 173 {"id": r[0], "name": r[1], "family_role": r[2], 174 "enrollment_samples": r[3], "wakeword_samples": r[4]} 175 for r in rows 176 ] 177 self._send_json(200, {"speakers": speakers}) 178 179 elif path == "/prompts": 180 # Return text prompts for enrollment recordings 181 self._send_json(200, { 182 "wakeword_prompts": [ 183 "Hey Bob", 184 "Hey Bob!", 185 "Hey Bob, are you there?", 186 "Hey Bob, help me out", 187 "Hey Bob, what's up", 188 ], 189 "enrollment_prompts": [ 190 "The quick brown fox jumps over the lazy dog. I need Bob to recognize my voice so he can help me personally.", 191 "My name is {name} and I live in Tampa, Florida. Bob is our family assistant and I want him to know who I am.", 192 "Good morning Bob, can you tell me what the weather is like today? I'd also like to know what's on my calendar.", 193 ], 194 "instructions": { 195 "wakeword": "Say each prompt clearly, at a natural pace. Record 5-10 samples.", 196 "enrollment": "Read each paragraph naturally. These help Bob learn your voice.", 197 }, 198 }) 199 200 elif path == "/health": 201 self._send_json(200, { 202 "status": "ok", 203 "embedding_model": EMBEDDING_MODEL, 204 "model_loaded": embedding_extractor is not None, 205 }) 206 207 else: 208 self._send_json(404, {"error": "Not found"}) 209 210 def do_POST(self): 211 path = urlparse(self.path).path 212 content_length = int(self.headers.get("Content-Length", 0)) 213 214 if path == "/enroll": 215 # Register a new speaker 216 body = json.loads(self.rfile.read(content_length)) 217 speaker_id = body.get("id", "").lower().replace(" ", "_") 218 name = body.get("name", "") 219 family_role = body.get("family_role", "") 220 221 if not speaker_id or not name: 222 self._send_json(400, {"error": "id and name required"}) 223 return 224 225 conn = init_db() 226 # Create speaker record directory 227 (DATA_DIR / "recordings" / speaker_id).mkdir(parents=True, exist_ok=True) 228 229 conn.execute( 230 "INSERT OR REPLACE INTO speakers (id, name, family_role, created_at, updated_at) VALUES (?, ?, ?, ?, ?)", 231 (speaker_id, name, family_role, 232 datetime.now(timezone.utc).isoformat(), 233 datetime.now(timezone.utc).isoformat()), 234 ) 235 conn.commit() 236 conn.close() 237 print(f"Enrolled speaker: {name} ({speaker_id})") 238 self._send_json(200, {"status": "enrolled", "id": speaker_id, "name": name}) 239 240 elif path == "/upload": 241 # Upload a voice recording 242 # Expect multipart or raw WAV with query params 243 params = parse_qs(urlparse(self.path).query) 244 speaker_id = params.get("speaker_id", [""])[0] 245 rec_type = params.get("type", ["enrollment"])[0] # 'wakeword' or 'enrollment' 246 247 if not speaker_id: 248 self._send_json(400, {"error": "speaker_id query param required"}) 249 return 250 251 audio_data = self.rfile.read(content_length) 252 253 # Save recording 254 rec_dir = DATA_DIR / "recordings" / speaker_id 255 rec_dir.mkdir(parents=True, exist_ok=True) 256 timestamp = int(time.time()) 257 filename = f"{rec_type}_{timestamp}.wav" 258 filepath = rec_dir / filename 259 260 with open(filepath, "wb") as f: 261 f.write(audio_data) 262 263 # Get duration 264 try: 265 with wave.open(str(filepath), "rb") as w: 266 duration = w.getnframes() / w.getframerate() 267 except Exception: 268 duration = 0 269 270 # Store recording metadata 271 conn = init_db() 272 conn.execute( 273 "INSERT INTO recordings (speaker_id, type, filename, duration_secs, created_at) VALUES (?, ?, ?, ?, ?)", 274 (speaker_id, rec_type, filename, duration, 275 datetime.now(timezone.utc).isoformat()), 276 ) 277 278 # Update sample counts 279 if rec_type == "wakeword": 280 conn.execute( 281 "UPDATE speakers SET wakeword_samples = wakeword_samples + 1, updated_at = ? WHERE id = ?", 282 (datetime.now(timezone.utc).isoformat(), speaker_id), 283 ) 284 else: 285 conn.execute( 286 "UPDATE speakers SET enrollment_samples = enrollment_samples + 1, updated_at = ? WHERE id = ?", 287 (datetime.now(timezone.utc).isoformat(), speaker_id), 288 ) 289 290 conn.commit() 291 292 # Extract embedding if we have enough enrollment samples 293 enrollment_count = conn.execute( 294 "SELECT enrollment_samples FROM speakers WHERE id = ?", (speaker_id,) 295 ).fetchone()[0] 296 297 embedding_status = "skipped" 298 if rec_type == "enrollment" and embedding_extractor is not None: 299 emb = extract_embedding(audio_data) 300 if emb is not None: 301 # Average with existing embedding if present 302 existing = conn.execute( 303 "SELECT embedding, embedding_dim FROM speakers WHERE id = ?", (speaker_id,) 304 ).fetchone() 305 306 if existing[0] is not None: 307 old_emb = np.frombuffer(existing[0], dtype=np.float32) 308 # Running average 309 new_emb = (old_emb * (enrollment_count - 1) + emb) / enrollment_count 310 else: 311 new_emb = emb 312 313 conn.execute( 314 "UPDATE speakers SET embedding = ?, embedding_dim = ?, updated_at = ? WHERE id = ?", 315 (new_emb.astype(np.float32).tobytes(), len(new_emb), 316 datetime.now(timezone.utc).isoformat(), speaker_id), 317 ) 318 conn.commit() 319 embedding_status = "updated" 320 321 conn.close() 322 print(f"Saved {rec_type} recording for {speaker_id}: {filename} ({duration:.1f}s)") 323 self._send_json(200, { 324 "status": "saved", 325 "filename": filename, 326 "duration_secs": round(duration, 1), 327 "type": rec_type, 328 "embedding_status": embedding_status, 329 "enrollment_samples": enrollment_count if rec_type == "enrollment" else None, 330 }) 331 332 elif path == "/identify": 333 # Identify speaker from audio 334 audio_data = self.rfile.read(content_length) 335 336 if embedding_extractor is None: 337 self._send_json(503, {"error": "Embedding model not loaded"}) 338 return 339 340 emb = extract_embedding(audio_data) 341 if emb is None: 342 self._send_json(500, {"error": "Could not extract embedding"}) 343 return 344 345 conn = init_db() 346 rows = conn.execute( 347 "SELECT id, name, embedding FROM speakers WHERE embedding IS NOT NULL" 348 ).fetchall() 349 conn.close() 350 351 best_match = None 352 best_score = 0 353 for row in rows: 354 stored_emb = np.frombuffer(row[2], dtype=np.float32) 355 score = cosine_similarity(emb, stored_emb) 356 if score > best_score: 357 best_score = score 358 best_match = {"id": row[0], "name": row[1], "score": round(score, 4)} 359 360 if best_match and best_score > 0.65: 361 self._send_json(200, {"identified": True, "speaker": best_match}) 362 else: 363 self._send_json(200, {"identified": False, "best_match": best_match}) 364 365 else: 366 self._send_json(404, {"error": "Not found"}) 367 368 def log_message(self, format, *args): 369 pass # Suppress default logging 370 371 372 def main(): 373 DATA_DIR.mkdir(parents=True, exist_ok=True) 374 init_db() 375 init_embedding_model() 376 377 # Pre-register family members 378 conn = init_db() 379 for speaker in [ 380 ("cam", "Cameron Hunt", "Dad"), 381 ("aj", "Adriane Hunt", "Mom"), 382 ("hailen", "Hailen Hunt", "Son"), 383 ]: 384 conn.execute( 385 "INSERT OR IGNORE INTO speakers (id, name, family_role, created_at, updated_at) VALUES (?, ?, ?, ?, ?)", 386 (speaker[0], speaker[1], speaker[2], 387 datetime.now(timezone.utc).isoformat(), 388 datetime.now(timezone.utc).isoformat()), 389 ) 390 conn.commit() 391 conn.close() 392 print(f"Family members pre-registered: cam, aj, hailen") 393 394 server = HTTPServer(("0.0.0.0", LISTEN_PORT), EnrollmentHandler) 395 print(f"Voice enrollment service listening on :{LISTEN_PORT}") 396 server.serve_forever() 397 398 399 if __name__ == "__main__": 400 main()