bot.py
  1  """Bob voice agent — Pipecat pipeline wiring STT → LLM → TTS."""
  2  
  3  import asyncio
  4  import json
  5  import os
  6  import re
  7  import sys
  8  
  9  from loguru import logger
 10  
 11  from pipecat.audio.vad.silero import SileroVADAnalyzer
 12  from pipecat.audio.vad.vad_analyzer import VADParams
 13  from pipecat.frames.frames import Frame, LLMMessagesFrame, TextFrame
 14  from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
 15  from pipecat.pipeline.pipeline import Pipeline
 16  from pipecat.pipeline.runner import PipelineRunner
 17  from pipecat.pipeline.task import PipelineParams, PipelineTask
 18  from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
 19  from pipecat.services.openai.llm import OpenAILLMService
 20  from pipecat.services.openai.stt import OpenAISTTService
 21  from pipecat.services.openai import tts as openai_tts
 22  from pipecat.services.openai.tts import OpenAITTSService
 23  
 24  # Patch VALID_VOICES to accept any voice name (Kokoro has 67+ voices)
 25  class _AnyVoiceDict(dict):
 26      def __contains__(self, key): return True
 27      def __getitem__(self, key): return key
 28  
# Install the permissive mapping so any configured voice id passes validation.
openai_tts.VALID_VOICES = _AnyVoiceDict()
 30  from pipecat.transports.websocket.server import (
 31      WebsocketServerTransport,
 32      WebsocketServerParams,
 33  )
 34  from raw_pcm_serializer import RawPCMSerializer
 35  from tools import TOOL_DEFINITIONS, TOOL_HANDLERS
 36  
# Replace loguru's default handler (id 0) with a DEBUG-level stderr sink.
logger.remove(0)
logger.add(sys.stderr, level="DEBUG")
 39  
 40  
 41  class ThinkTagFilter(FrameProcessor):
 42      """Strips <think>...</think> blocks from LLM output before TTS.
 43      Passes ALL non-TextFrame frames through unchanged."""
 44  
 45      def __init__(self, **kwargs):
 46          super().__init__(**kwargs)
 47          self._in_think = False
 48  
 49      async def process_frame(self, frame: Frame, direction: FrameDirection):
 50          await super().process_frame(frame, direction)
 51  
 52          if not isinstance(frame, TextFrame):
 53              await self.push_frame(frame, direction)
 54              return
 55  
 56          text = frame.text
 57          if "<think>" in text:
 58              self._in_think = True
 59              text = text[:text.index("<think>")]
 60          if "</think>" in text:
 61              text = text[text.index("</think>") + len("</think>"):]
 62              self._in_think = False
 63          if self._in_think:
 64              return  # swallow think content
 65          text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL)
 66          if text.strip():  # has non-whitespace content
 67              await self.push_frame(TextFrame(text=text), direction)
 68          elif text:  # whitespace-only (spaces between words) — preserve
 69              await self.push_frame(TextFrame(text=text), direction)
 70  
# --- Configuration from environment ---
# Core model-serving endpoints (defaults target host.docker.internal ports).
LLM_BASE_URL = os.getenv("LLM_BASE_URL", "http://host.docker.internal:8000/v1")
LLM_MODEL = os.getenv("LLM_MODEL", "Qwen/Qwen3-32B-AWQ")
STT_BASE_URL = os.getenv("STT_BASE_URL", "http://host.docker.internal:10300/v1")
TTS_BASE_URL = os.getenv("TTS_BASE_URL", "http://host.docker.internal:10400/v1")
TTS_VOICE = os.getenv("TTS_VOICE", "bf_emma")
TTS_ENGINE = os.getenv("TTS_ENGINE", "kokoro")  # "kokoro" or "fish"
# Wake-word gating (Wyoming-protocol detector service).
WAKE_WORD_ENABLED = os.getenv("WAKE_WORD_ENABLED", "true").lower() == "true"
WAKE_WORD_HOST = os.getenv("WAKE_WORD_HOST", "127.0.0.1")
WAKE_WORD_PORT = int(os.getenv("WAKE_WORD_PORT", "10500"))
WAKE_WORD_NAME = os.getenv("WAKE_WORD_NAME", "hey_bob")
WAKE_WORD_IDLE_TIMEOUT = float(os.getenv("WAKE_WORD_IDLE_TIMEOUT", "15.0"))
# Fish Speech TTS (used only when TTS_ENGINE == "fish").
FISH_SPEECH_URL = os.getenv("FISH_SPEECH_URL", "http://host.docker.internal:10600")
FISH_REFERENCE_AUDIO = os.getenv("FISH_REFERENCE_AUDIO", "")
FISH_REFERENCE_TEXT = os.getenv("FISH_REFERENCE_TEXT", "")
# Speaker attribution: diarization is the primary path; the enrollment-based
# speaker ID is only used when diarization is disabled.
SPEAKER_ID_ENABLED = os.getenv("SPEAKER_ID_ENABLED", "false").lower() == "true"  # Replaced by diarization
ENROLLMENT_URL = os.getenv("ENROLLMENT_URL", "http://127.0.0.1:10800")
DIARIZATION_ENABLED = os.getenv("DIARIZATION_ENABLED", "true").lower() == "true"
DIARIZATION_URL = os.getenv("DIARIZATION_URL", "ws://127.0.0.1:7007")
# WebSocket server bind address for the bot itself.
BOT_HOST = os.getenv("BOT_HOST", "0.0.0.0")
BOT_PORT = int(os.getenv("BOT_PORT", "8765"))
 92  
# Default persona/instructions, overridable via the SYSTEM_PROMPT env var.
# NOTE(review): the trailing "/no_think" suffix is appended unconditionally —
# presumably Qwen3's soft switch to suppress <think> reasoning; confirm it is
# harmless if SYSTEM_PROMPT is overridden for a different model.
SYSTEM_PROMPT = os.getenv("SYSTEM_PROMPT", """\
You are Bob, the Hunt family's AI assistant in Tampa, FL. You speak naturally and concisely.
Keep responses brief — one to three sentences unless asked for more detail.

You have these tools:
- get_weather: current weather and forecast
- get_home_state / control_home_device: smart home status and control
- query_knowledge: family knowledge graph (SPARQL)
- execute_code: Python REPL for anything else — system diagnostics, calculations, data lookups, API queries

When a question can't be answered with the specific tools, use execute_code to write Python that investigates. You have access to Docker, Prometheus, HomeAssistant, and NATS APIs from the REPL.

IMPORTANT RULES:
- Do NOT use <think> tags or any internal reasoning. Respond directly.
- When you call a tool, wait for the result and give a natural spoken response based on the data.
- Keep responses conversational — you are being spoken aloud by text-to-speech.
- NEVER end with offers like "let me know if there's anything else." Just answer and stop.
- If you recognize the speaker (e.g. [Cam]), personalize your response.\
""") + "\n/no_think"
112  
113  
async def main():
    """Build and run the Bob voice pipeline.

    Wires a standalone WebSocket transport through optional wake-word gating
    and diarization into STT -> LLM (with tool calling) -> think-tag filter
    -> TTS, then runs the pipeline until the runner exits.

    Change: removed the dead ``original_process`` local (computed from
    ``llm.process_frame`` and never used anywhere).
    """
    # --- Services pointing at local stack ---
    stt = OpenAISTTService(
        api_key="not-needed",
        base_url=STT_BASE_URL,
        model="Systran/faster-whisper-large-v3",
    )

    llm = OpenAILLMService(
        api_key="not-needed",
        base_url=LLM_BASE_URL,
        model=LLM_MODEL,
    )

    if TTS_ENGINE == "fish":
        from fish_speech_tts import FishSpeechLocalService
        tts = FishSpeechLocalService(
            base_url=FISH_SPEECH_URL,
            reference_audio_path=FISH_REFERENCE_AUDIO or None,
            reference_text=FISH_REFERENCE_TEXT or None,
            output_sample_rate=24000,
        )
        logger.info(f"Using Fish Speech TTS at {FISH_SPEECH_URL}")
    else:
        tts = OpenAITTSService(
            api_key="not-needed",
            base_url=TTS_BASE_URL,
            model="kokoro",
            voice=TTS_VOICE,
            speed=1.05,
        )
        logger.info(f"Using Kokoro TTS at {TTS_BASE_URL}")

    # --- Transport: standalone WebSocket server ---
    transport = WebsocketServerTransport(
        host=BOT_HOST,
        port=BOT_PORT,
        params=WebsocketServerParams(
            audio_in_enabled=True,
            audio_in_sample_rate=16000,   # mic audio arrives as 16 kHz PCM
            audio_out_enabled=True,
            audio_out_sample_rate=24000,  # matches the TTS output rate above
            add_wav_header=False,
            vad_enabled=True,
            vad_analyzer=SileroVADAnalyzer(params=VADParams(
                confidence=0.85,
                start_secs=0.4,
                stop_secs=0.8,
                min_volume=0.6,
            )),
            vad_audio_passthrough=True,
            serializer=RawPCMSerializer(sample_rate=16000, num_channels=1),
        ),
    )

    # --- LLM context with tools ---
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    context = OpenAILLMContext(messages, tools=TOOL_DEFINITIONS)
    context_aggregator = llm.create_context_aggregator(context)

    # --- Context compaction (prevent context window overflow) ---
    from context_compactor import maybe_compact_context

    # Compaction is checked from the tool handler rather than by hooking LLM
    # internals (fragile). A one-element list is a mutable cell the closure
    # below can increment without `nonlocal`.
    _turn_count = [0]

    # --- Register tool call handlers ---
    async def _handle_tool(params):
        """Dispatch a single LLM tool call to its registered handler."""
        # Check context compaction every few tool calls.
        _turn_count[0] += 1
        if _turn_count[0] % 3 == 0:
            maybe_compact_context(context)

        handler = TOOL_HANDLERS.get(params.function_name)
        if handler:
            result = await handler(params.arguments)
            await params.result_callback(result)
        else:
            await params.result_callback(json.dumps({"error": f"Unknown function: {params.function_name}"}))

    # Every tool shares one dispatcher; it routes on params.function_name.
    for tool_name in TOOL_HANDLERS:
        llm.register_function(tool_name, _handle_tool)

    # --- Wake word gate (if enabled) ---
    wake_gate = None
    if WAKE_WORD_ENABLED:
        from wake_word_gate import WakeWordGateProcessor
        wake_gate = WakeWordGateProcessor(
            wyoming_host=WAKE_WORD_HOST,
            wyoming_port=WAKE_WORD_PORT,
            wake_word_name=WAKE_WORD_NAME,
            idle_timeout=WAKE_WORD_IDLE_TIMEOUT,
        )
        logger.info(f"Wake word gate enabled: '{WAKE_WORD_NAME}' via {WAKE_WORD_HOST}:{WAKE_WORD_PORT}")
    else:
        logger.info("Wake word gate disabled — always listening")

    # --- Diarization (if enabled — replaces basic speaker ID) ---
    diarizer = None
    if DIARIZATION_ENABLED:
        from diarization_processor import DiarizationProcessor
        diarizer = DiarizationProcessor(diarization_url=DIARIZATION_URL)
        logger.info(f"Diarization enabled via {DIARIZATION_URL}")
    else:
        logger.info("Diarization disabled")

    # --- Speaker identification (if enabled — fallback when diarization off) ---
    speaker_id = None
    if SPEAKER_ID_ENABLED and not DIARIZATION_ENABLED:
        from speaker_id import SpeakerIdentifier
        speaker_id = SpeakerIdentifier(enrollment_url=ENROLLMENT_URL)
        logger.info(f"Speaker identification enabled via {ENROLLMENT_URL}")
    else:
        logger.info("Speaker identification disabled")

    # --- Think tag filter (strips <think>...</think> before TTS) ---
    think_filter = ThinkTagFilter()

    # --- Pipeline (order matters: gate/diarize on raw audio before STT,
    # think filter between LLM and TTS) ---
    pipeline_stages = [transport.input()]
    if wake_gate:
        pipeline_stages.append(wake_gate)
    if diarizer:
        pipeline_stages.append(diarizer)
    pipeline_stages.append(stt)
    if speaker_id:
        pipeline_stages.append(speaker_id)

    # --- Fast-path handler (time, date — bypass LLM) ---
    fast_path = None
    if os.getenv("FAST_PATH_ENABLED", "true").lower() == "true":
        from fast_path import FastPathProcessor
        fast_path = FastPathProcessor()
        logger.info("Fast-path deterministic queries enabled")
        pipeline_stages.append(fast_path)

    pipeline_stages.extend([
        context_aggregator.user(),
        llm,
        think_filter,
        tts,
        transport.output(),
        context_aggregator.assistant(),
    ])

    pipeline = Pipeline(pipeline_stages)

    task = PipelineTask(
        pipeline,
        params=PipelineParams(
            allow_interruptions=True,
            enable_metrics=True,
            enable_usage_metrics=True,
            idle_timeout_secs=0,  # Disable idle timeout — Bob is a persistent server
        ),
    )

    @transport.event_handler("on_client_connected")
    async def on_client_connected(transport, client):
        logger.info("Client connected, waiting for user to speak")

    @transport.event_handler("on_client_disconnected")
    async def on_client_disconnected(transport, client):
        """Summarize the finished conversation; failures are non-fatal."""
        logger.info("Client disconnected, consolidating session")
        try:
            from session_consolidator import consolidate_session, get_conversation_messages
            msgs = get_conversation_messages(context)
            # Get current speaker from diarization or speaker_id.
            # NOTE(review): reads processor-private _current_speaker; an
            # AttributeError here is swallowed by the except below.
            current_speaker = "Unknown"
            if diarizer and diarizer._current_speaker:
                current_speaker = diarizer._current_speaker
            elif speaker_id and speaker_id._current_speaker:
                current_speaker = speaker_id._current_speaker
            await consolidate_session(msgs, current_speaker)
        except Exception as e:
            logger.warning(f"Session consolidation error: {e}")

    runner = PipelineRunner()
    await runner.run(task)
296  
297  
if __name__ == "__main__":
    # Script entry point: run the agent's event loop until interrupted.
    asyncio.run(main())