/ services / voice-enrollment / service.py
service.py
  1  """Bob Voice Enrollment Service — records family voice samples and extracts speaker embeddings.
  2  
  3  Provides HTTP API for:
  4  - Recording voice samples (wake word + natural speech)
  5  - Extracting CAM++ speaker embeddings
  6  - Storing enrollment data (embeddings + raw audio for wake word retraining)
  7  - Listing enrolled speakers
  8  - Identifying speaker from audio sample
  9  
 10  Storage layout:
 11    /srv/bob/voice-enrollment/
 12      speakers.db          — SQLite: speaker metadata + embeddings
 13      recordings/
 14        cam/
        wakeword_<unix_ts>.wav    — wake word samples
        enrollment_<unix_ts>.wav  — enrollment speech samples
 17        aj/
 18          ...
 19  """
 20  
 21  import asyncio
 22  import io
 23  import json
 24  import os
 25  import sqlite3
 26  import struct
 27  import time
 28  import wave
 29  from datetime import datetime, timezone
 30  from http.server import HTTPServer, BaseHTTPRequestHandler
 31  from pathlib import Path
 32  from urllib.parse import parse_qs, urlparse
 33  
 34  import numpy as np
 35  
# Root directory holding speakers.db and the per-speaker recordings/ tree.
DATA_DIR = Path(os.getenv("DATA_DIR", "/srv/bob/voice-enrollment"))
# TCP port the HTTP API listens on.
LISTEN_PORT = int(os.getenv("LISTEN_PORT", "10800"))
# Model label; within this file it is only echoed back by the /health endpoint
# (init_embedding_model() loads CAM++ regardless of this value).
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "cam++")

# Set by init_embedding_model() to a dict with "model", "fbank" and "device"
# entries; stays None when the model could not be loaded.
embedding_extractor = None
 42  
 43  
 44  def init_db():
 45      """Initialize SQLite database for speaker enrollment."""
 46      db_path = DATA_DIR / "speakers.db"
 47      DATA_DIR.mkdir(parents=True, exist_ok=True)
 48      conn = sqlite3.connect(str(db_path))
 49      conn.execute("""
 50          CREATE TABLE IF NOT EXISTS speakers (
 51              id TEXT PRIMARY KEY,
 52              name TEXT NOT NULL,
 53              family_role TEXT,
 54              embedding BLOB,
 55              embedding_dim INTEGER,
 56              enrollment_samples INTEGER DEFAULT 0,
 57              wakeword_samples INTEGER DEFAULT 0,
 58              created_at TEXT,
 59              updated_at TEXT
 60          )
 61      """)
 62      conn.execute("""
 63          CREATE TABLE IF NOT EXISTS recordings (
 64              id INTEGER PRIMARY KEY AUTOINCREMENT,
 65              speaker_id TEXT NOT NULL,
 66              type TEXT NOT NULL,  -- 'wakeword' or 'enrollment'
 67              filename TEXT NOT NULL,
 68              duration_secs REAL,
 69              created_at TEXT,
 70              FOREIGN KEY (speaker_id) REFERENCES speakers(id)
 71          )
 72      """)
 73      conn.commit()
 74      return conn
 75  
 76  
 77  def init_embedding_model():
 78      """Initialize the CAM++ speaker embedding model from 3D-Speaker."""
 79      global embedding_extractor
 80      try:
 81          import torch
 82          from modelscope.hub.snapshot_download import snapshot_download
 83          from speakerlab.models.campplus.DTDNN import CAMPPlus
 84          from speakerlab.process.processor import FBank
 85  
 86          model_id = "iic/speech_campplus_sv_en_voxceleb_16k"
 87          print(f"Loading CAM++ model: {model_id}...")
 88          cache_dir = snapshot_download(model_id, revision="v1.0.2")
 89  
 90          model = CAMPPlus(feat_dim=80, embedding_size=512)
 91          state = torch.load(f"{cache_dir}/campplus_voxceleb.bin", map_location="cpu", weights_only=True)
 92          model.load_state_dict(state)
 93          model.eval()
 94  
 95          fbank = FBank(80, sample_rate=16000, mean_nor=True)
 96          embedding_extractor = {"model": model, "fbank": fbank, "device": torch.device("cpu")}
 97          print(f"Loaded CAM++ speaker embedding model (CPU, 512-dim)")
 98      except Exception as e:
 99          print(f"WARNING: Could not load CAM++ model: {e}")
100          print("Speaker identification will be unavailable. Enrollment recordings will still be saved.")
101  
102  
103  def extract_embedding(audio_bytes: bytes) -> np.ndarray | None:
104      """Extract 512-dim speaker embedding from WAV audio bytes using CAM++."""
105      if embedding_extractor is None:
106          return None
107  
108      try:
109          import torch
110          import soundfile as sf
111  
112          model = embedding_extractor["model"]
113          fbank = embedding_extractor["fbank"]
114          device = embedding_extractor["device"]
115  
116          # Load audio from bytes using soundfile (avoids torchcodec dep)
117          buf = io.BytesIO(audio_bytes)
118          data, sr = sf.read(buf, dtype="float32")
119  
120          # Resample to 16kHz if needed
121          if sr != 16000:
122              import torchaudio
123              waveform = torch.from_numpy(data).unsqueeze(0) if data.ndim == 1 else torch.from_numpy(data.T)
124              waveform = torchaudio.functional.resample(waveform, sr, 16000)
125          else:
126              waveform = torch.from_numpy(data).unsqueeze(0) if data.ndim == 1 else torch.from_numpy(data.T)
127  
128          # Mono
129          if waveform.shape[0] > 1:
130              waveform = waveform.mean(dim=0, keepdim=True)
131  
132          # Extract features and embedding
133          feat = fbank(waveform).unsqueeze(0).to(device)
134          with torch.no_grad():
135              emb = model(feat).detach().cpu().numpy().squeeze()
136  
137          return emb  # shape: (512,)
138      except Exception as e:
139          print(f"Embedding extraction error: {e}")
140          return None
141  
142  
def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Return the cosine similarity of vectors *a* and *b* as a Python float.

    A small epsilon (1e-8) is added to the denominator so that zero-length
    vectors yield 0.0 instead of a division by zero.
    """
    dot_product = np.dot(a, b)
    magnitude = np.linalg.norm(a) * np.linalg.norm(b) + 1e-8
    return float(dot_product / magnitude)
146  
147  
148  class EnrollmentHandler(BaseHTTPRequestHandler):
149      """HTTP handler for voice enrollment API."""
150  
151      def _send_json(self, status: int, data: dict):
152          self.send_response(status)
153          self.send_header("Content-Type", "application/json")
154          self.send_header("Access-Control-Allow-Origin", "*")
155          self.send_header("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
156          self.send_header("Access-Control-Allow-Headers", "Content-Type")
157          self.end_headers()
158          self.wfile.write(json.dumps(data).encode())
159  
160      def do_OPTIONS(self):
161          self._send_json(200, {})
162  
163      def do_GET(self):
164          path = urlparse(self.path).path
165  
166          if path == "/speakers":
167              conn = init_db()
168              rows = conn.execute(
169                  "SELECT id, name, family_role, enrollment_samples, wakeword_samples FROM speakers"
170              ).fetchall()
171              conn.close()
172              speakers = [
173                  {"id": r[0], "name": r[1], "family_role": r[2],
174                   "enrollment_samples": r[3], "wakeword_samples": r[4]}
175                  for r in rows
176              ]
177              self._send_json(200, {"speakers": speakers})
178  
179          elif path == "/prompts":
180              # Return text prompts for enrollment recordings
181              self._send_json(200, {
182                  "wakeword_prompts": [
183                      "Hey Bob",
184                      "Hey Bob!",
185                      "Hey Bob, are you there?",
186                      "Hey Bob, help me out",
187                      "Hey Bob, what's up",
188                  ],
189                  "enrollment_prompts": [
190                      "The quick brown fox jumps over the lazy dog. I need Bob to recognize my voice so he can help me personally.",
191                      "My name is {name} and I live in Tampa, Florida. Bob is our family assistant and I want him to know who I am.",
192                      "Good morning Bob, can you tell me what the weather is like today? I'd also like to know what's on my calendar.",
193                  ],
194                  "instructions": {
195                      "wakeword": "Say each prompt clearly, at a natural pace. Record 5-10 samples.",
196                      "enrollment": "Read each paragraph naturally. These help Bob learn your voice.",
197                  },
198              })
199  
200          elif path == "/health":
201              self._send_json(200, {
202                  "status": "ok",
203                  "embedding_model": EMBEDDING_MODEL,
204                  "model_loaded": embedding_extractor is not None,
205              })
206  
207          else:
208              self._send_json(404, {"error": "Not found"})
209  
210      def do_POST(self):
211          path = urlparse(self.path).path
212          content_length = int(self.headers.get("Content-Length", 0))
213  
214          if path == "/enroll":
215              # Register a new speaker
216              body = json.loads(self.rfile.read(content_length))
217              speaker_id = body.get("id", "").lower().replace(" ", "_")
218              name = body.get("name", "")
219              family_role = body.get("family_role", "")
220  
221              if not speaker_id or not name:
222                  self._send_json(400, {"error": "id and name required"})
223                  return
224  
225              conn = init_db()
226              # Create speaker record directory
227              (DATA_DIR / "recordings" / speaker_id).mkdir(parents=True, exist_ok=True)
228  
229              conn.execute(
230                  "INSERT OR REPLACE INTO speakers (id, name, family_role, created_at, updated_at) VALUES (?, ?, ?, ?, ?)",
231                  (speaker_id, name, family_role,
232                   datetime.now(timezone.utc).isoformat(),
233                   datetime.now(timezone.utc).isoformat()),
234              )
235              conn.commit()
236              conn.close()
237              print(f"Enrolled speaker: {name} ({speaker_id})")
238              self._send_json(200, {"status": "enrolled", "id": speaker_id, "name": name})
239  
240          elif path == "/upload":
241              # Upload a voice recording
242              # Expect multipart or raw WAV with query params
243              params = parse_qs(urlparse(self.path).query)
244              speaker_id = params.get("speaker_id", [""])[0]
245              rec_type = params.get("type", ["enrollment"])[0]  # 'wakeword' or 'enrollment'
246  
247              if not speaker_id:
248                  self._send_json(400, {"error": "speaker_id query param required"})
249                  return
250  
251              audio_data = self.rfile.read(content_length)
252  
253              # Save recording
254              rec_dir = DATA_DIR / "recordings" / speaker_id
255              rec_dir.mkdir(parents=True, exist_ok=True)
256              timestamp = int(time.time())
257              filename = f"{rec_type}_{timestamp}.wav"
258              filepath = rec_dir / filename
259  
260              with open(filepath, "wb") as f:
261                  f.write(audio_data)
262  
263              # Get duration
264              try:
265                  with wave.open(str(filepath), "rb") as w:
266                      duration = w.getnframes() / w.getframerate()
267              except Exception:
268                  duration = 0
269  
270              # Store recording metadata
271              conn = init_db()
272              conn.execute(
273                  "INSERT INTO recordings (speaker_id, type, filename, duration_secs, created_at) VALUES (?, ?, ?, ?, ?)",
274                  (speaker_id, rec_type, filename, duration,
275                   datetime.now(timezone.utc).isoformat()),
276              )
277  
278              # Update sample counts
279              if rec_type == "wakeword":
280                  conn.execute(
281                      "UPDATE speakers SET wakeword_samples = wakeword_samples + 1, updated_at = ? WHERE id = ?",
282                      (datetime.now(timezone.utc).isoformat(), speaker_id),
283                  )
284              else:
285                  conn.execute(
286                      "UPDATE speakers SET enrollment_samples = enrollment_samples + 1, updated_at = ? WHERE id = ?",
287                      (datetime.now(timezone.utc).isoformat(), speaker_id),
288                  )
289  
290              conn.commit()
291  
292              # Extract embedding if we have enough enrollment samples
293              enrollment_count = conn.execute(
294                  "SELECT enrollment_samples FROM speakers WHERE id = ?", (speaker_id,)
295              ).fetchone()[0]
296  
297              embedding_status = "skipped"
298              if rec_type == "enrollment" and embedding_extractor is not None:
299                  emb = extract_embedding(audio_data)
300                  if emb is not None:
301                      # Average with existing embedding if present
302                      existing = conn.execute(
303                          "SELECT embedding, embedding_dim FROM speakers WHERE id = ?", (speaker_id,)
304                      ).fetchone()
305  
306                      if existing[0] is not None:
307                          old_emb = np.frombuffer(existing[0], dtype=np.float32)
308                          # Running average
309                          new_emb = (old_emb * (enrollment_count - 1) + emb) / enrollment_count
310                      else:
311                          new_emb = emb
312  
313                      conn.execute(
314                          "UPDATE speakers SET embedding = ?, embedding_dim = ?, updated_at = ? WHERE id = ?",
315                          (new_emb.astype(np.float32).tobytes(), len(new_emb),
316                           datetime.now(timezone.utc).isoformat(), speaker_id),
317                      )
318                      conn.commit()
319                      embedding_status = "updated"
320  
321              conn.close()
322              print(f"Saved {rec_type} recording for {speaker_id}: {filename} ({duration:.1f}s)")
323              self._send_json(200, {
324                  "status": "saved",
325                  "filename": filename,
326                  "duration_secs": round(duration, 1),
327                  "type": rec_type,
328                  "embedding_status": embedding_status,
329                  "enrollment_samples": enrollment_count if rec_type == "enrollment" else None,
330              })
331  
332          elif path == "/identify":
333              # Identify speaker from audio
334              audio_data = self.rfile.read(content_length)
335  
336              if embedding_extractor is None:
337                  self._send_json(503, {"error": "Embedding model not loaded"})
338                  return
339  
340              emb = extract_embedding(audio_data)
341              if emb is None:
342                  self._send_json(500, {"error": "Could not extract embedding"})
343                  return
344  
345              conn = init_db()
346              rows = conn.execute(
347                  "SELECT id, name, embedding FROM speakers WHERE embedding IS NOT NULL"
348              ).fetchall()
349              conn.close()
350  
351              best_match = None
352              best_score = 0
353              for row in rows:
354                  stored_emb = np.frombuffer(row[2], dtype=np.float32)
355                  score = cosine_similarity(emb, stored_emb)
356                  if score > best_score:
357                      best_score = score
358                      best_match = {"id": row[0], "name": row[1], "score": round(score, 4)}
359  
360              if best_match and best_score > 0.65:
361                  self._send_json(200, {"identified": True, "speaker": best_match})
362              else:
363                  self._send_json(200, {"identified": False, "best_match": best_match})
364  
365          else:
366              self._send_json(404, {"error": "Not found"})
367  
368      def log_message(self, format, *args):
369          pass  # Suppress default logging
370  
371  
def main():
    """Start the voice enrollment service.

    Creates the data directory and database schema, loads the embedding
    model (best effort), pre-registers the family members and then serves the
    HTTP API on all interfaces until interrupted.
    """
    DATA_DIR.mkdir(parents=True, exist_ok=True)
    # Create the schema up front so a broken database fails fast, before the
    # slow model load; close the connection right away instead of leaking it.
    init_db().close()
    init_embedding_model()

    # Pre-register family members
    family = [
        ("cam", "Cameron Hunt", "Dad"),
        ("aj", "Adriane Hunt", "Mom"),
        ("hailen", "Hailen Hunt", "Son"),
    ]
    conn = init_db()
    try:
        for speaker_id, name, family_role in family:
            # INSERT OR IGNORE keeps existing rows (and their embeddings) intact.
            conn.execute(
                "INSERT OR IGNORE INTO speakers (id, name, family_role, created_at, updated_at) VALUES (?, ?, ?, ?, ?)",
                (speaker_id, name, family_role,
                 datetime.now(timezone.utc).isoformat(),
                 datetime.now(timezone.utc).isoformat()),
            )
        conn.commit()
    finally:
        conn.close()
    print(f"Family members pre-registered: {', '.join(member[0] for member in family)}")

    server = HTTPServer(("0.0.0.0", LISTEN_PORT), EnrollmentHandler)
    print(f"Voice enrollment service listening on :{LISTEN_PORT}")
    try:
        server.serve_forever()
    except KeyboardInterrupt:
        # Ctrl-C is the normal way to stop the service; exit without a traceback.
        pass
    finally:
        # Release the listening socket.
        server.server_close()


if __name__ == "__main__":
    main()