# services/pipecat-agent/speaker_id.py
  1  """Speaker identification processor for Pipecat.
  2  
  3  Extracts speaker embeddings from audio segments using the voice enrollment
  4  service, identifies which family member is speaking, and tags transcriptions
  5  with speaker identity.
  6  
  7  Works by buffering audio during speech (VAD-detected), then sending the
  8  buffer to the enrollment service's /identify endpoint after speech ends.
  9  """
 10  
 11  import asyncio
 12  import io
 13  import json
 14  import os
 15  import struct
 16  import urllib.request
 17  import urllib.error
 18  from typing import Optional
 19  
 20  from loguru import logger
 21  
 22  from pipecat.frames.frames import (
 23      Frame,
 24      InputAudioRawFrame,
 25      TranscriptionFrame,
 26      InterimTranscriptionFrame,
 27  )
 28  from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
 29  
# Base URL of the voice enrollment service (must expose POST /identify).
ENROLLMENT_URL = os.getenv("ENROLLMENT_URL", "http://127.0.0.1:10800")
# Feature flag: any value other than "true" (case-insensitive) disables speaker ID.
SPEAKER_ID_ENABLED = os.getenv("SPEAKER_ID_ENABLED", "true").lower() == "true"
MIN_SPEECH_SECS = float(os.getenv("MIN_SPEECH_SECS", "1.0"))  # Min speech to attempt ID

# Map full names to short/preferred names for transcriptions
SHORT_NAMES = {
    "Cameron Hunt": "Cam",
    "Adriane Hunt": "AJ",
    "Hailen Hunt": "Hailen",
}
 40  
 41  
 42  class SpeakerIdentifier(FrameProcessor):
 43      """Identifies the current speaker by sending audio to the enrollment service.
 44  
 45      Buffers audio during speech, sends to /identify after speech ends.
 46      Tags subsequent TranscriptionFrames with the identified speaker name.
 47      """
 48  
 49      def __init__(self, *, enrollment_url: str = ENROLLMENT_URL, **kwargs):
 50          super().__init__(**kwargs)
 51          self._enrollment_url = enrollment_url.rstrip("/")
 52          self._audio_buffer = bytearray()
 53          self._is_speaking = False
 54          self._current_speaker: Optional[str] = None
 55          self._sample_rate = 16000
 56  
 57      async def process_frame(self, frame: Frame, direction: FrameDirection):
 58          await super().process_frame(frame, direction)
 59  
 60          if not SPEAKER_ID_ENABLED:
 61              await self.push_frame(frame, direction)
 62              return
 63  
 64          frame_type = type(frame).__name__
 65  
 66          # Buffer audio during speech
 67          if isinstance(frame, InputAudioRawFrame):
 68              if self._is_speaking:
 69                  self._audio_buffer.extend(frame.audio)
 70              await self.push_frame(frame, direction)
 71              return
 72  
 73          # Track speech start/stop from VAD
 74          if frame_type == "UserStartedSpeakingFrame":
 75              self._is_speaking = True
 76              self._audio_buffer = bytearray()
 77              await self.push_frame(frame, direction)
 78              return
 79  
 80          if frame_type == "UserStoppedSpeakingFrame":
 81              self._is_speaking = False
 82              # Attempt identification if we have enough audio
 83              duration = len(self._audio_buffer) / 2 / self._sample_rate
 84              if duration >= MIN_SPEECH_SECS:
 85                  speaker = await self._identify_speaker()
 86                  if speaker:
 87                      self._current_speaker = speaker
 88                      logger.info(f"Speaker identified: {speaker}")
 89              await self.push_frame(frame, direction)
 90              return
 91  
 92          # Tag transcriptions with speaker identity
 93          if isinstance(frame, TranscriptionFrame) and self._current_speaker:
 94              frame.text = f"[{self._current_speaker}] {frame.text}"
 95              await self.push_frame(frame, direction)
 96              return
 97  
 98          # Pass everything else through
 99          await self.push_frame(frame, direction)
100  
101      async def _identify_speaker(self) -> Optional[str]:
102          """Send buffered audio to enrollment service for identification."""
103          try:
104              # Encode as WAV
105              wav_data = self._encode_wav(bytes(self._audio_buffer), self._sample_rate)
106  
107              # Send to enrollment service
108              req = urllib.request.Request(
109                  f"{self._enrollment_url}/identify",
110                  data=wav_data,
111                  headers={"Content-Type": "audio/wav"},
112                  method="POST",
113              )
114              with urllib.request.urlopen(req, timeout=5) as resp:
115                  result = json.loads(resp.read().decode())
116  
117              if result.get("identified"):
118                  speaker = result["speaker"]
119                  full_name = speaker["name"]
120                  short_name = SHORT_NAMES.get(full_name, full_name.split()[0])
121                  logger.debug(f"Speaker ID: {short_name} (score={speaker['score']})")
122                  return short_name
123              else:
124                  best = result.get("best_match")
125                  if best:
126                      logger.debug(f"Speaker not confident: best={best['name']} score={best['score']}")
127                  return None
128  
129          except Exception as e:
130              logger.debug(f"Speaker ID failed: {e}")
131              return None
132  
133      @staticmethod
134      def _encode_wav(pcm_data: bytes, sample_rate: int) -> bytes:
135          """Encode raw PCM16 as WAV."""
136          n_samples = len(pcm_data) // 2
137          buf = io.BytesIO()
138          with io.BufferedWriter(buf) as bw:
139              # RIFF header
140              data_size = n_samples * 2
141              bw.write(b"RIFF")
142              bw.write(struct.pack("<I", 36 + data_size))
143              bw.write(b"WAVE")
144              # fmt chunk
145              bw.write(b"fmt ")
146              bw.write(struct.pack("<I", 16))
147              bw.write(struct.pack("<HHIIHH", 1, 1, sample_rate, sample_rate * 2, 2, 16))
148              # data chunk
149              bw.write(b"data")
150              bw.write(struct.pack("<I", data_size))
151              bw.write(pcm_data)
152          return buf.getvalue()