# diarization_processor.py
"""Diarization processor for Pipecat — routes audio through diart for speaker separation.

Streams audio to the diart WebSocket service, receives speaker segments,
and tags transcriptions with identified speaker names.

Pipeline position: audio input → DiarizationProcessor → STT → LLM
"""

import asyncio
import base64
import json
import os
import struct
from typing import Optional

import numpy as np
from loguru import logger

from pipecat.frames.frames import (
    Frame,
    InputAudioRawFrame,
    TranscriptionFrame,
)
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor

DIARIZATION_URL = os.getenv("DIARIZATION_URL", "ws://127.0.0.1:7007")
DIARIZATION_ENABLED = os.getenv("DIARIZATION_ENABLED", "true").lower() == "true"


class DiarizationProcessor(FrameProcessor):
    """Streams audio to diart diarization service and tracks current speaker.

    The processor:
    1. Forwards all audio frames downstream (to STT) unchanged
    2. Simultaneously streams audio to diart via WebSocket
    3. Receives speaker segment events from diart
    4. Tracks the current active speaker
    5. Tags TranscriptionFrames with the current speaker name

    Diarization is best-effort: connection, send and receive failures are
    logged and never raised into the pipeline. Failed connection attempts
    are throttled so an unreachable service does not trigger a new attempt
    on every audio frame.

    NOTE(review): with the pipeline position documented in the module
    docstring (before STT), TranscriptionFrames produced by STT would flow
    downstream and may never pass through this processor — confirm where
    the tagging is expected to happen.
    """

    # PCM16 audio: two bytes per sample.
    _BYTES_PER_SAMPLE = 2
    # Minimum delay between connection attempts after a failure (seconds).
    _RETRY_INTERVAL_SECS = 5.0

    def __init__(self, *, diarization_url: str = DIARIZATION_URL, **kwargs):
        """Create the processor without opening any connection.

        Args:
            diarization_url: WebSocket URL of the diarization service.
            **kwargs: Forwarded unchanged to ``FrameProcessor.__init__``.
        """
        super().__init__(**kwargs)
        self._url = diarization_url
        self._ws = None
        self._current_speaker: Optional[str] = None
        self._connected = False
        # True while a connection attempt is awaiting; prevents overlapping
        # attempts from process_frame(), start() and _reconnect().
        self._connecting = False
        # Event-loop time before which process_frame() must not retry.
        self._next_connect_at = 0.0
        self._receive_task: Optional[asyncio.Task] = None
        # Kept on the instance so the pending reconnect task stays referenced
        # and can be cancelled in cleanup().
        self._reconnect_task: Optional[asyncio.Task] = None
        self._audio_buffer = bytearray()
        self._buffer_size = 16000  # 0.5s at 16kHz 16-bit = 16000 bytes

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Forward every frame; mirror audio to diart and tag transcriptions.

        Args:
            frame: The frame to process.
            direction: Direction the frame is travelling in the pipeline.
        """
        await super().process_frame(frame, direction)

        if isinstance(frame, InputAudioRawFrame):
            # Push downstream before doing any diarization work.
            await self.push_frame(frame, direction)
            await self._stream_to_diarizer(frame.audio)
            return

        # Tag transcriptions with the current speaker, when one is known.
        if isinstance(frame, TranscriptionFrame) and self._current_speaker:
            frame.text = f"[{self._current_speaker}] {frame.text}"

        # Everything (tagged or not) continues through the pipeline.
        await self.push_frame(frame, direction)

    async def _stream_to_diarizer(self, audio: bytes):
        """Buffer *audio* and send it to diart once enough has accumulated."""
        if not self._connected and DIARIZATION_ENABLED and self._connect_allowed():
            await self._connect()
        if not self._connected:
            return

        self._audio_buffer.extend(audio)
        buffered = len(self._audio_buffer)
        if buffered < self._buffer_size:
            return

        # Only whole int16 samples can be converted; a dangling odd byte stays
        # in the buffer and is completed by the next frame.
        usable = buffered - buffered % self._BYTES_PER_SAMPLE
        chunk = bytes(self._audio_buffer[:usable])
        del self._audio_buffer[:usable]
        await self._send_audio(chunk)

    def _connect_allowed(self) -> bool:
        """Return True when the retry back-off window has elapsed."""
        return asyncio.get_running_loop().time() >= self._next_connect_at

    def _start_backoff(self):
        """Postpone the next frame-triggered connection attempt."""
        now = asyncio.get_running_loop().time()
        self._next_connect_at = now + self._RETRY_INTERVAL_SECS

    async def _connect(self):
        """Connect to the diarization WebSocket service.

        Does nothing when already connected or when another attempt is in
        flight. Failures are logged and start the retry back-off; nothing is
        raised.
        """
        if self._connected or self._connecting:
            return
        self._connecting = True
        try:
            # Imported lazily, as in the original, so the module can be
            # imported without the websockets package installed.
            import websockets

            # Release any stale socket / reader task before replacing them.
            await self._close_connection()
            self._ws = await websockets.connect(self._url)
            self._connected = True
            self._receive_task = asyncio.create_task(self._receive_loop())
            logger.info(f"Connected to diarization service at {self._url}")
        except Exception as e:
            logger.warning(f"Could not connect to diarization service: {e}")
            self._connected = False
            self._start_backoff()
        finally:
            self._connecting = False

    async def _close_connection(self):
        """Cancel the receive task and close the socket, resetting state."""
        task, self._receive_task = self._receive_task, None
        ws, self._ws = self._ws, None
        self._connected = False

        if task is not None and not task.done():
            task.cancel()
            # gather(return_exceptions=True) collects the task's
            # CancelledError instead of raising it here.
            await asyncio.gather(task, return_exceptions=True)
        if ws is not None:
            try:
                await ws.close()
            except Exception as e:
                logger.debug(f"Diarization close error: {e}")

    async def _send_audio(self, pcm16_bytes: bytes):
        """Send audio chunk to diart as base64-encoded float32.

        Args:
            pcm16_bytes: Raw PCM16 audio. A trailing odd byte, if any, is
                ignored because it does not form a complete sample.
        """
        ws = self._ws
        if not ws:
            return

        n_samples = len(pcm16_bytes) // self._BYTES_PER_SAMPLE
        if n_samples == 0:
            return

        # Convert PCM16 int16 to float32 in [-1.0, 1.0).
        samples = (
            np.frombuffer(pcm16_bytes, dtype=np.int16, count=n_samples).astype(np.float32)
            / 32768.0
        )
        # diart expects base64-encoded float32
        b64 = base64.b64encode(samples.tobytes()).decode("utf-8")

        try:
            await ws.send(b64)
        except Exception as e:
            logger.debug(f"Diarization send error: {e}")
            self._connected = False
            # Leave the retry to the delayed reconnect task rather than to
            # the very next audio frame.
            self._start_backoff()
            self._schedule_reconnect()

    def _schedule_reconnect(self):
        """Start a single delayed reconnect task unless one is pending."""
        pending = self._reconnect_task
        if pending is not None and not pending.done():
            return
        self._reconnect_task = asyncio.create_task(self._reconnect())

    async def _receive_loop(self):
        """Receive speaker segment events from diart.

        Runs until the socket closes or errors. Malformed messages are
        skipped instead of terminating the loop.
        """
        ws = self._ws
        if ws is None:
            return
        try:
            async for message in ws:
                speaker = self._extract_speaker(message)
                if speaker is not None and speaker != self._current_speaker:
                    self._current_speaker = speaker
                    logger.info(f"Speaker changed: {speaker}")
        except Exception as e:
            logger.debug(f"Diarization receive error: {e}")
        finally:
            # Also reached on a clean close. Skip the update when this socket
            # has already been replaced or released by _close_connection().
            if self._ws is ws:
                self._connected = False

    @staticmethod
    def _segment_end(segment: dict) -> float:
        """Return the segment's numeric ``end`` value, or 0 when unusable."""
        end = segment.get("end", 0)
        if isinstance(end, bool) or not isinstance(end, (int, float)):
            return 0
        return end

    @staticmethod
    def _speaker_from_rttm(text: str) -> Optional[str]:
        """Extract the speaker label from an RTTM line, if *text* is one."""
        if "SPEAKER" not in text:
            return None
        # RTTM: SPEAKER <uri> 1 <start> <dur> <NA> <NA> <label> <NA> <NA>
        parts = text.strip().split()
        return parts[7] if len(parts) >= 8 else None

    @staticmethod
    def _extract_speaker(message) -> Optional[str]:
        """Return the speaker named by one service message, or None.

        Accepts a JSON object with a ``segments`` list (the segment with the
        greatest ``end`` wins; a missing ``speaker`` key yields "Unknown") or,
        when the payload is not JSON, an RTTM ``SPEAKER`` line.
        """
        if isinstance(message, (bytes, bytearray)):
            text = bytes(message).decode("utf-8", errors="replace")
        else:
            text = str(message)

        try:
            data = json.loads(text)
        except json.JSONDecodeError:
            # diart might send RTTM text instead of JSON
            return DiarizationProcessor._speaker_from_rttm(text)

        if not isinstance(data, dict):
            return None
        raw_segments = data.get("segments")
        if not isinstance(raw_segments, list):
            return None
        segments = [seg for seg in raw_segments if isinstance(seg, dict)]
        if not segments:
            return None

        latest = max(segments, key=DiarizationProcessor._segment_end)
        speaker = latest.get("speaker", "Unknown")
        return None if speaker is None else str(speaker)

    async def _reconnect(self):
        """Attempt to reconnect to the diarization service after a delay."""
        await asyncio.sleep(self._RETRY_INTERVAL_SECS)
        if not self._connected:
            await self._connect()

    async def start(self):
        """Connect to the diarization service if diarization is enabled.

        Intended to be called when the pipeline starts; connecting here is
        optional because process_frame() also connects on demand.
        """
        if DIARIZATION_ENABLED:
            await self._connect()

    async def cleanup(self):
        """Release the base processor and all diarization resources.

        Cancels the pending reconnect and the receive task, closes the
        WebSocket and clears buffered audio.
        """
        try:
            await super().cleanup()
        finally:
            reconnect, self._reconnect_task = self._reconnect_task, None
            if reconnect is not None and not reconnect.done():
                reconnect.cancel()
                await asyncio.gather(reconnect, return_exceptions=True)
            await self._close_connection()
            self._audio_buffer.clear()