/ services / pipecat-agent / wake_word_gate.py
wake_word_gate.py
  1  """Wake word gate processor for Pipecat.
  2  
  3  Streams mic audio to openWakeWord via the Wyoming protocol. Only forwards
  4  audio downstream (to STT) after the wake word is detected. Returns to
  5  waiting state after the bot finishes responding.
  6  
  7  States:
  8    WAITING  — audio sent to Wyoming only, not forwarded to STT
  9    ACTIVE   — audio forwarded to STT (wake word was detected)
 10  """
 11  
 12  import asyncio
 13  import json
 14  import struct
 15  from enum import Enum
 16  from typing import Optional
 17  
 18  from loguru import logger
 19  
 20  from pipecat.frames.frames import (
 21      Frame,
 22      InputAudioRawFrame,
 23      OutputTransportMessageUrgentFrame,
 24      StartFrame,
 25      EndFrame,
 26      CancelFrame,
 27  )
 28  from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
 29  
 30  
 31  class GateState(Enum):
 32      WAITING = "waiting"
 33      ACTIVE = "active"
 34  
 35  
 36  class WakeWordGateProcessor(FrameProcessor):
 37      """Gates audio through openWakeWord — only forwards to STT after wake word detection."""
 38  
 39      def __init__(
 40          self,
 41          *,
 42          wyoming_host: str = "127.0.0.1",
 43          wyoming_port: int = 10500,
 44          wake_word_name: str = "hey_bob",
 45          idle_timeout: float = 15.0,
 46          **kwargs,
 47      ):
 48          super().__init__(**kwargs)
 49          self._host = wyoming_host
 50          self._port = wyoming_port
 51          self._wake_word_name = wake_word_name
 52          self._idle_timeout = idle_timeout
 53          self._state = GateState.WAITING
 54          self._reader: Optional[asyncio.StreamReader] = None
 55          self._writer: Optional[asyncio.StreamWriter] = None
 56          self._detection_task: Optional[asyncio.Task] = None
 57          self._idle_task: Optional[asyncio.Task] = None
 58          self._connected = False
 59          self._audio_frame_count = 0
 60          self._audio_buffer = bytearray()
 61          self._wyoming_chunk_size = 3200  # 100ms at 16kHz 16-bit mono
 62  
 63      async def process_frame(self, frame: Frame, direction: FrameDirection):
 64          await super().process_frame(frame, direction)
 65  
 66          if isinstance(frame, StartFrame):
 67              await self.push_frame(frame, direction)
 68              await self._start_waiting()
 69              return
 70  
 71          if isinstance(frame, (EndFrame, CancelFrame)):
 72              await self._disconnect()
 73              await self.push_frame(frame, direction)
 74              return
 75  
 76          if isinstance(frame, InputAudioRawFrame):
 77              self._audio_frame_count += 1
 78              if self._audio_frame_count % 100 == 1:
 79                  logger.debug(f"WakeWordGate: frame #{self._audio_frame_count}, {len(frame.audio)}B, state={self._state.value}, wyoming={self._connected}")
 80              if self._state == GateState.WAITING:
 81                  # Send audio to Wyoming for wake word detection, don't forward
 82                  await self._send_audio_to_wyoming(frame.audio)
 83                  return
 84              elif self._state == GateState.ACTIVE:
 85                  # Forward audio downstream to STT
 86                  await self.push_frame(frame, direction)
 87                  # Reset idle timer
 88                  self._reset_idle_timer()
 89                  return
 90  
 91          # Check for bot-stopped-speaking to return to waiting
 92          # The frame type name varies by Pipecat version
 93          frame_type = type(frame).__name__
 94          if frame_type in ("BotStoppedSpeakingFrame", "TTSStoppedFrame"):
 95              if self._state == GateState.ACTIVE:
 96                  # Start idle timer — if no speech for idle_timeout, go back to waiting
 97                  self._reset_idle_timer()
 98  
 99          # Pass all non-audio frames through
100          await self.push_frame(frame, direction)
101  
102      def _reset_idle_timer(self):
103          """Reset the idle timer. If no speech for idle_timeout seconds, return to WAITING."""
104          if self._idle_task:
105              self._idle_task.cancel()
106          self._idle_task = asyncio.create_task(self._idle_timeout_handler())
107  
108      async def _idle_timeout_handler(self):
109          """After idle_timeout seconds of no activity, return to waiting for wake word."""
110          try:
111              await asyncio.sleep(self._idle_timeout)
112              if self._state == GateState.ACTIVE:
113                  logger.info("Idle timeout — returning to wake word listening")
114                  await self._start_waiting()
115          except asyncio.CancelledError:
116              pass
117  
118      async def _notify_state(self, state: str):
119          """Send wake word state to the web client via RTVI message."""
120          msg = {"label": "rtvi-ai", "type": "wake-word-state", "data": {"state": state}}
121          await self.push_frame(OutputTransportMessageUrgentFrame(message=msg))
122  
123      async def _start_waiting(self):
124          """Enter WAITING state — connect to Wyoming and start listening for wake word."""
125          self._state = GateState.WAITING
126          if self._idle_task:
127              self._idle_task.cancel()
128              self._idle_task = None
129          await self._connect_wyoming()
130          await self._notify_state("waiting")
131          logger.info(f"Waiting for wake word '{self._wake_word_name}'...")
132  
133      async def _connect_wyoming(self):
134          """Open a TCP connection to the Wyoming wake word server."""
135          # Guard against concurrent reconnection attempts
136          if self._connected and self._writer and not self._writer.is_closing():
137              return
138  
139          self._audio_buffer.clear()
140          await self._disconnect()
141          try:
142              self._reader, self._writer = await asyncio.wait_for(
143                  asyncio.open_connection(self._host, self._port),
144                  timeout=5.0,
145              )
146              self._connected = True
147  
148              # Send Detect event
149              await self._write_wyoming_event("detect", {"names": [self._wake_word_name]})
150  
151              # Send AudioStart event
152              await self._write_wyoming_event("audio-start", {
153                  "rate": 16000, "width": 2, "channels": 1,
154              })
155  
156              # Start background task reading for detection events
157              self._detection_task = asyncio.create_task(self._read_detection())
158              logger.debug(f"Wyoming connected to {self._host}:{self._port}")
159  
160          except Exception as e:
161              logger.error(f"Failed to connect to Wyoming at {self._host}:{self._port}: {e}")
162              self._connected = False
163              self._reader = None
164              self._writer = None
165  
166      async def _disconnect(self):
167          """Close the Wyoming TCP connection."""
168          if self._detection_task:
169              self._detection_task.cancel()
170              self._detection_task = None
171          if self._writer:
172              try:
173                  self._writer.close()
174                  await self._writer.wait_closed()
175              except Exception:
176                  pass
177          self._reader = None
178          self._writer = None
179          self._connected = False
180  
181      async def _write_wyoming_event(self, event_type: str, data: dict, payload: bytes = b""):
182          """Write a Wyoming protocol event to the TCP connection."""
183          if not self._writer:
184              return
185  
186          data_bytes = json.dumps(data).encode("utf-8")
187          header = {
188              "type": event_type,
189              "version": "1.0.0",
190              "data_length": len(data_bytes),
191          }
192          if payload:
193              header["payload_length"] = len(payload)
194  
195          header_line = json.dumps(header).encode("utf-8") + b"\n"
196          try:
197              self._writer.write(header_line + data_bytes + payload)
198              await self._writer.drain()
199          except Exception as e:
200              logger.warning(f"Wyoming write error: {e}, will reconnect")
201              self._connected = False
202              self._reader = None
203              self._writer = None
204  
205      async def _send_audio_to_wyoming(self, audio_bytes: bytes):
206          """Buffer small audio frames and send to Wyoming in larger chunks.
207  
208          openWakeWord needs ~80-100ms chunks for meaningful analysis.
209          Pipecat sends tiny ~8ms (256 byte) frames, so we buffer them.
210          """
211          if not self._connected or not self._writer or self._writer.is_closing():
212              try:
213                  await self._connect_wyoming()
214              except Exception as e:
215                  logger.debug(f"Wyoming reconnect failed: {e}")
216              if not self._connected:
217                  return
218  
219          self._audio_buffer.extend(audio_bytes)
220  
221          while len(self._audio_buffer) >= self._wyoming_chunk_size:
222              chunk = bytes(self._audio_buffer[:self._wyoming_chunk_size])
223              del self._audio_buffer[:self._wyoming_chunk_size]
224              data = {"rate": 16000, "width": 2, "channels": 1, "timestamp": 0}
225              await self._write_wyoming_event("audio-chunk", data, payload=chunk)
226  
227      async def _read_detection(self):
228          """Background task: read Wyoming events and detect wake word."""
229          try:
230              while self._reader and self._connected:
231                  # Long timeout — Wyoming only sends events on detection/not-detected.
232                  # Audio keeps flowing via _send_audio_to_wyoming, so the connection
233                  # stays alive. We just need to wait patiently for a detection event.
234                  line = await self._reader.readline()
235                  if not line:
236                      break
237  
238                  try:
239                      header = json.loads(line.decode("utf-8").strip())
240                  except (json.JSONDecodeError, UnicodeDecodeError):
241                      continue
242  
243                  event_type = header.get("type", "")
244                  data_length = header.get("data_length", 0)
245                  payload_length = header.get("payload_length", 0)
246  
247                  # Read data and payload if present
248                  data = {}
249                  if data_length > 0:
250                      data_bytes = await self._reader.readexactly(data_length)
251                      try:
252                          data = json.loads(data_bytes.decode("utf-8"))
253                      except (json.JSONDecodeError, UnicodeDecodeError):
254                          pass
255                  if payload_length > 0:
256                      await self._reader.readexactly(payload_length)  # discard
257  
258                  if event_type == "detection":
259                      wake_name = data.get("name", "unknown")
260                      logger.info(f"Wake word detected: '{wake_name}' — activating pipeline")
261                      self._state = GateState.ACTIVE
262                      await self._notify_state("activated")
263                      self._reset_idle_timer()
264                      return
265  
266                  elif event_type == "not-detected":
267                      logger.debug("Wyoming: not-detected, reconnecting...")
268                      await self._connect_wyoming()
269                      return
270  
271          except asyncio.CancelledError:
272              pass
273          except (ConnectionResetError, BrokenPipeError, asyncio.IncompleteReadError):
274              logger.debug("Wyoming connection closed, reconnecting...")
275              if self._state == GateState.WAITING:
276                  await asyncio.sleep(1)
277                  await self._connect_wyoming()
278          except Exception as e:
279              logger.warning(f"Wyoming read error: {e}")
280              if self._state == GateState.WAITING:
281                  await asyncio.sleep(1)
282                  await self._connect_wyoming()