/ services / pipecat-agent / diarization_processor.py
diarization_processor.py
  1  """Diarization processor for Pipecat — routes audio through diart for speaker separation.
  2  
  3  Streams audio to the diart WebSocket service, receives speaker segments,
  4  and tags transcriptions with identified speaker names.
  5  
  6  Pipeline position: audio input → DiarizationProcessor → STT → LLM
  7  """
  8  
  9  import asyncio
 10  import base64
 11  import json
 12  import os
 13  import struct
 14  from typing import Optional
 15  
 16  import numpy as np
 17  from loguru import logger
 18  
 19  from pipecat.frames.frames import (
 20      Frame,
 21      InputAudioRawFrame,
 22      TranscriptionFrame,
 23  )
 24  from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
 25  
 26  DIARIZATION_URL = os.getenv("DIARIZATION_URL", "ws://127.0.0.1:7007")
 27  DIARIZATION_ENABLED = os.getenv("DIARIZATION_ENABLED", "true").lower() == "true"
 28  
 29  
 30  class DiarizationProcessor(FrameProcessor):
 31      """Streams audio to diart diarization service and tracks current speaker.
 32  
 33      The processor:
 34      1. Forwards all audio frames downstream (to STT) unchanged
 35      2. Simultaneously streams audio to diart via WebSocket
 36      3. Receives speaker segment events from diart
 37      4. Tracks the current active speaker
 38      5. Tags TranscriptionFrames with the current speaker name
 39      """
 40  
 41      def __init__(self, *, diarization_url: str = DIARIZATION_URL, **kwargs):
 42          super().__init__(**kwargs)
 43          self._url = diarization_url
 44          self._ws = None
 45          self._current_speaker: Optional[str] = None
 46          self._connected = False
 47          self._receive_task: Optional[asyncio.Task] = None
 48          self._audio_buffer = bytearray()
 49          self._buffer_size = 16000  # 0.5s at 16kHz 16-bit = 16000 bytes
 50  
 51      async def process_frame(self, frame: Frame, direction: FrameDirection):
 52          await super().process_frame(frame, direction)
 53  
 54          if isinstance(frame, InputAudioRawFrame):
 55              # Always forward audio downstream
 56              await self.push_frame(frame, direction)
 57  
 58              # Connect on first audio frame if not yet connected
 59              if not self._connected and DIARIZATION_ENABLED:
 60                  await self._connect()
 61  
 62              # Also send to diarization service
 63              if self._connected:
 64                  self._audio_buffer.extend(frame.audio)
 65                  if len(self._audio_buffer) >= self._buffer_size:
 66                      await self._send_audio(bytes(self._audio_buffer))
 67                      self._audio_buffer = bytearray()
 68              return
 69  
 70          # Tag transcriptions with current speaker
 71          if isinstance(frame, TranscriptionFrame) and self._current_speaker:
 72              frame.text = f"[{self._current_speaker}] {frame.text}"
 73              await self.push_frame(frame, direction)
 74              return
 75  
 76          # Pass everything else through
 77          await self.push_frame(frame, direction)
 78  
 79      async def _connect(self):
 80          """Connect to the diarization WebSocket service."""
 81          try:
 82              import websockets
 83              self._ws = await websockets.connect(self._url)
 84              self._connected = True
 85              self._receive_task = asyncio.create_task(self._receive_loop())
 86              logger.info(f"Connected to diarization service at {self._url}")
 87          except Exception as e:
 88              logger.warning(f"Could not connect to diarization service: {e}")
 89              self._connected = False
 90  
 91      async def _send_audio(self, pcm16_bytes: bytes):
 92          """Send audio chunk to diart as base64-encoded float32."""
 93          if not self._ws:
 94              return
 95          try:
 96              # Convert PCM16 int16 to float32
 97              n_samples = len(pcm16_bytes) // 2
 98              samples = np.frombuffer(pcm16_bytes, dtype=np.int16).astype(np.float32) / 32768.0
 99              # diart expects base64-encoded float32
100              b64 = base64.b64encode(samples.tobytes()).decode("utf-8")
101              await self._ws.send(b64)
102          except Exception as e:
103              logger.debug(f"Diarization send error: {e}")
104              self._connected = False
105              # Try to reconnect
106              asyncio.create_task(self._reconnect())
107  
108      async def _receive_loop(self):
109          """Receive speaker segment events from diart."""
110          try:
111              async for message in self._ws:
112                  try:
113                      data = json.loads(message)
114                      segments = data.get("segments", [])
115                      if segments:
116                          # Use the most recent/longest segment as current speaker
117                          latest = max(segments, key=lambda s: s.get("end", 0))
118                          speaker = latest.get("speaker", "Unknown")
119                          if speaker != self._current_speaker:
120                              self._current_speaker = speaker
121                              logger.info(f"Speaker changed: {speaker}")
122                  except (json.JSONDecodeError, KeyError):
123                      # diart might send RTTM text instead of JSON
124                      if "SPEAKER" in str(message):
125                          # Parse RTTM: SPEAKER <uri> 1 <start> <dur> <NA> <NA> <label> <NA> <NA>
126                          parts = str(message).strip().split()
127                          if len(parts) >= 8:
128                              speaker = parts[7]
129                              if speaker != self._current_speaker:
130                                  self._current_speaker = speaker
131                                  logger.info(f"Speaker changed: {speaker}")
132          except Exception as e:
133              logger.debug(f"Diarization receive error: {e}")
134              self._connected = False
135  
136      async def _reconnect(self):
137          """Attempt to reconnect to the diarization service."""
138          await asyncio.sleep(5)
139          if not self._connected:
140              await self._connect()
141  
142      async def start(self):
143          """Called when the pipeline starts."""
144          if DIARIZATION_ENABLED:
145              await self._connect()
146  
147      async def cleanup(self):
148          """Called when the pipeline stops."""
149          if self._receive_task:
150              self._receive_task.cancel()
151          if self._ws:
152              await self._ws.close()