bot.py
1 """Bob voice agent — Pipecat pipeline wiring STT → LLM → TTS.""" 2 3 import asyncio 4 import json 5 import os 6 import re 7 import sys 8 9 from loguru import logger 10 11 from pipecat.audio.vad.silero import SileroVADAnalyzer 12 from pipecat.audio.vad.vad_analyzer import VADParams 13 from pipecat.frames.frames import Frame, LLMMessagesFrame, TextFrame 14 from pipecat.processors.frame_processor import FrameDirection, FrameProcessor 15 from pipecat.pipeline.pipeline import Pipeline 16 from pipecat.pipeline.runner import PipelineRunner 17 from pipecat.pipeline.task import PipelineParams, PipelineTask 18 from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext 19 from pipecat.services.openai.llm import OpenAILLMService 20 from pipecat.services.openai.stt import OpenAISTTService 21 from pipecat.services.openai import tts as openai_tts 22 from pipecat.services.openai.tts import OpenAITTSService 23 24 # Patch VALID_VOICES to accept any voice name (Kokoro has 67+ voices) 25 class _AnyVoiceDict(dict): 26 def __contains__(self, key): return True 27 def __getitem__(self, key): return key 28 29 openai_tts.VALID_VOICES = _AnyVoiceDict() 30 from pipecat.transports.websocket.server import ( 31 WebsocketServerTransport, 32 WebsocketServerParams, 33 ) 34 from raw_pcm_serializer import RawPCMSerializer 35 from tools import TOOL_DEFINITIONS, TOOL_HANDLERS 36 37 logger.remove(0) 38 logger.add(sys.stderr, level="DEBUG") 39 40 41 class ThinkTagFilter(FrameProcessor): 42 """Strips <think>...</think> blocks from LLM output before TTS. 43 Passes ALL non-TextFrame frames through unchanged.""" 44 45 def __init__(self, **kwargs): 46 super().__init__(**kwargs) 47 self._in_think = False 48 49 async def process_frame(self, frame: Frame, direction: FrameDirection): 50 await super().process_frame(frame, direction) 51 52 if not isinstance(frame, TextFrame): 53 await self.push_frame(frame, direction) 54 return 55 56 text = frame.text 57 if "<think>" in text: 58 self._in_think = True 59 text = text[:text.index("<think>")] 60 if "</think>" in text: 61 text = text[text.index("</think>") + len("</think>"):] 62 self._in_think = False 63 if self._in_think: 64 return # swallow think content 65 text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL) 66 if text.strip(): # has non-whitespace content 67 await self.push_frame(TextFrame(text=text), direction) 68 elif text: # whitespace-only (spaces between words) — preserve 69 await self.push_frame(TextFrame(text=text), direction) 70 71 # --- Configuration from environment --- 72 LLM_BASE_URL = os.getenv("LLM_BASE_URL", "http://host.docker.internal:8000/v1") 73 LLM_MODEL = os.getenv("LLM_MODEL", "Qwen/Qwen3-32B-AWQ") 74 STT_BASE_URL = os.getenv("STT_BASE_URL", "http://host.docker.internal:10300/v1") 75 TTS_BASE_URL = os.getenv("TTS_BASE_URL", "http://host.docker.internal:10400/v1") 76 TTS_VOICE = os.getenv("TTS_VOICE", "bf_emma") 77 TTS_ENGINE = os.getenv("TTS_ENGINE", "kokoro") # "kokoro" or "fish" 78 WAKE_WORD_ENABLED = os.getenv("WAKE_WORD_ENABLED", "true").lower() == "true" 79 WAKE_WORD_HOST = os.getenv("WAKE_WORD_HOST", "127.0.0.1") 80 WAKE_WORD_PORT = int(os.getenv("WAKE_WORD_PORT", "10500")) 81 WAKE_WORD_NAME = os.getenv("WAKE_WORD_NAME", "hey_bob") 82 WAKE_WORD_IDLE_TIMEOUT = float(os.getenv("WAKE_WORD_IDLE_TIMEOUT", "15.0")) 83 FISH_SPEECH_URL = os.getenv("FISH_SPEECH_URL", "http://host.docker.internal:10600") 84 FISH_REFERENCE_AUDIO = os.getenv("FISH_REFERENCE_AUDIO", "") 85 FISH_REFERENCE_TEXT = os.getenv("FISH_REFERENCE_TEXT", "") 86 
SPEAKER_ID_ENABLED = os.getenv("SPEAKER_ID_ENABLED", "false").lower() == "true"  # Replaced by diarization
ENROLLMENT_URL = os.getenv("ENROLLMENT_URL", "http://127.0.0.1:10800")
DIARIZATION_ENABLED = os.getenv("DIARIZATION_ENABLED", "true").lower() == "true"
DIARIZATION_URL = os.getenv("DIARIZATION_URL", "ws://127.0.0.1:7007")
BOT_HOST = os.getenv("BOT_HOST", "0.0.0.0")
BOT_PORT = int(os.getenv("BOT_PORT", "8765"))

SYSTEM_PROMPT = os.getenv("SYSTEM_PROMPT", """\
You are Bob, the Hunt family's AI assistant in Tampa, FL. You speak naturally and concisely.
Keep responses brief — one to three sentences unless asked for more detail.

You have these tools:
- get_weather: current weather and forecast
- get_home_state / control_home_device: smart home status and control
- query_knowledge: family knowledge graph (SPARQL)
- execute_code: Python REPL for anything else — system diagnostics, calculations, data lookups, API queries

When a question can't be answered with the specific tools, use execute_code to write Python that investigates. You have access to Docker, Prometheus, HomeAssistant, and NATS APIs from the REPL.

IMPORTANT RULES:
- Do NOT use <think> tags or any internal reasoning. Respond directly.
- When you call a tool, wait for the result and give a natural spoken response based on the data.
- Keep responses conversational — you are being spoken aloud by text-to-speech.
- NEVER end with offers like "let me know if there's anything else." Just answer and stop.
- If you recognize the speaker (e.g. [Cam]), personalize your response.\
""") + "\n/no_think"


async def main():
    # --- Services pointing at local stack ---
    stt = OpenAISTTService(
        api_key="not-needed",
        base_url=STT_BASE_URL,
        model="Systran/faster-whisper-large-v3",
    )

    llm = OpenAILLMService(
        api_key="not-needed",
        base_url=LLM_BASE_URL,
        model=LLM_MODEL,
    )

    if TTS_ENGINE == "fish":
        from fish_speech_tts import FishSpeechLocalService
        tts = FishSpeechLocalService(
            base_url=FISH_SPEECH_URL,
            reference_audio_path=FISH_REFERENCE_AUDIO or None,
            reference_text=FISH_REFERENCE_TEXT or None,
            output_sample_rate=24000,
        )
        logger.info(f"Using Fish Speech TTS at {FISH_SPEECH_URL}")
    else:
        tts = OpenAITTSService(
            api_key="not-needed",
            base_url=TTS_BASE_URL,
            model="kokoro",
            voice=TTS_VOICE,
            speed=1.05,
        )
        logger.info(f"Using Kokoro TTS at {TTS_BASE_URL}")

    # --- Transport: standalone WebSocket server ---
    transport = WebsocketServerTransport(
        host=BOT_HOST,
        port=BOT_PORT,
        params=WebsocketServerParams(
            audio_in_enabled=True,
            audio_in_sample_rate=16000,
            audio_out_enabled=True,
            audio_out_sample_rate=24000,
            add_wav_header=False,
            vad_enabled=True,
            vad_analyzer=SileroVADAnalyzer(params=VADParams(
                confidence=0.85,
                start_secs=0.4,
                stop_secs=0.8,
                min_volume=0.6,
            )),
            vad_audio_passthrough=True,
            serializer=RawPCMSerializer(sample_rate=16000, num_channels=1),
        ),
    )

    # --- LLM context with tools ---
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    context = OpenAILLMContext(messages, tools=TOOL_DEFINITIONS)
    context_aggregator = llm.create_context_aggregator(context)

    # --- Context compaction (prevent context window overflow) ---
    from context_compactor import maybe_compact_context
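    # Assumption (from the module name and the comment above): maybe_compact_context
    # trims or summarizes the oldest turns in context.messages once the history
    # nears the model's context window, so a long-running session never overflows it.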
    # Compaction is checked periodically from the tool handler below, since
    # hooking the LLM service's internals to run it before each generation is fragile.
    _turn_count = [0]

    # --- Register tool call handlers ---
    async def _handle_tool(params):
        # Check context compaction every few tool calls
        _turn_count[0] += 1
        if _turn_count[0] % 3 == 0:
            maybe_compact_context(context)

        handler = TOOL_HANDLERS.get(params.function_name)
        if handler:
            result = await handler(params.arguments)
            await params.result_callback(result)
        else:
            await params.result_callback(json.dumps({"error": f"Unknown function: {params.function_name}"}))

    for tool_name in TOOL_HANDLERS:
        llm.register_function(tool_name, _handle_tool)

    # --- Wake word gate (if enabled) ---
    wake_gate = None
    if WAKE_WORD_ENABLED:
        from wake_word_gate import WakeWordGateProcessor
        wake_gate = WakeWordGateProcessor(
            wyoming_host=WAKE_WORD_HOST,
            wyoming_port=WAKE_WORD_PORT,
            wake_word_name=WAKE_WORD_NAME,
            idle_timeout=WAKE_WORD_IDLE_TIMEOUT,
        )
        logger.info(f"Wake word gate enabled: '{WAKE_WORD_NAME}' via {WAKE_WORD_HOST}:{WAKE_WORD_PORT}")
    else:
        logger.info("Wake word gate disabled — always listening")

    # --- Diarization (if enabled — replaces basic speaker ID) ---
    diarizer = None
    if DIARIZATION_ENABLED:
        from diarization_processor import DiarizationProcessor
        diarizer = DiarizationProcessor(diarization_url=DIARIZATION_URL)
        logger.info(f"Diarization enabled via {DIARIZATION_URL}")
    else:
        logger.info("Diarization disabled")

    # --- Speaker identification (if enabled — fallback when diarization off) ---
    speaker_id = None
    if SPEAKER_ID_ENABLED and not DIARIZATION_ENABLED:
        from speaker_id import SpeakerIdentifier
        speaker_id = SpeakerIdentifier(enrollment_url=ENROLLMENT_URL)
        logger.info(f"Speaker identification enabled via {ENROLLMENT_URL}")
    else:
        logger.info("Speaker identification disabled")

    # --- Think tag filter (strips <think>...</think> before TTS) ---
    think_filter = ThinkTagFilter()

    # --- Pipeline ---
    pipeline_stages = [transport.input()]
    if wake_gate:
        pipeline_stages.append(wake_gate)
    if diarizer:
        pipeline_stages.append(diarizer)
    pipeline_stages.append(stt)
    if speaker_id:
        pipeline_stages.append(speaker_id)

    # --- Fast-path handler (time, date — bypass LLM) ---
    fast_path = None
    if os.getenv("FAST_PATH_ENABLED", "true").lower() == "true":
        from fast_path import FastPathProcessor
        fast_path = FastPathProcessor()
        logger.info("Fast-path deterministic queries enabled")
        pipeline_stages.append(fast_path)

    pipeline_stages.extend([
        context_aggregator.user(),
        llm,
        think_filter,
        tts,
        transport.output(),
        context_aggregator.assistant(),
    ])

    pipeline = Pipeline(pipeline_stages)

    task = PipelineTask(
        pipeline,
        params=PipelineParams(
            allow_interruptions=True,
            enable_metrics=True,
            enable_usage_metrics=True,
            idle_timeout_secs=0,  # Disable idle timeout — Bob is a persistent server
        ),
    )

    @transport.event_handler("on_client_connected")
    async def on_client_connected(transport, client):
        logger.info("Client connected, waiting for user to speak")
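    # On disconnect the finished conversation is handed to session_consolidator;
    # the assumption (from the module name) is that consolidate_session distills
    # the transcript into long-term memory, tagged with the identified speaker.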
disconnected, consolidating session") 281 try: 282 from session_consolidator import consolidate_session, get_conversation_messages 283 msgs = get_conversation_messages(context) 284 # Get current speaker from diarization or speaker_id 285 current_speaker = "Unknown" 286 if diarizer and diarizer._current_speaker: 287 current_speaker = diarizer._current_speaker 288 elif speaker_id and speaker_id._current_speaker: 289 current_speaker = speaker_id._current_speaker 290 await consolidate_session(msgs, current_speaker) 291 except Exception as e: 292 logger.warning(f"Session consolidation error: {e}") 293 294 runner = PipelineRunner() 295 await runner.run(task) 296 297 298 if __name__ == "__main__": 299 asyncio.run(main())