# coordinator.py
"""Bob Coordinator Agent — classifies requests, routes to model tiers, dispatches agents.

Subscribes to bob.coordinator.request on NATS, classifies user text via
Qwen3.5-9B (lightweight classifier on GPU 2), routes to appropriate model tier,
and publishes responses to bob.coordinator.response. Also handles demand-driven
agent dispatch and result aggregation.

Model tiers:
  - deterministic: regex/lookup, no LLM (time, date, greetings)
  - simple: Qwen3.5-9B direct response (trivia, casual conversation)
  - moderate: Qwen3-32B with only the read-only weather tool
    (weather, explanations, analysis)
  - complex: Qwen3-32B with full tool set (HA control, agent dispatch)

Each incoming request is handled in its own asyncio task, so a slow complex
request (for example an agent dispatch waiting AGENT_DISPATCH_TIMEOUT seconds)
does not hold up unrelated requests.
"""

import asyncio
import json
import os
import re
import subprocess
import sys
import time
import uuid
from datetime import datetime, timezone
from zoneinfo import ZoneInfo

import httpx
import nats
from nats.js.api import (
    DeliverPolicy,
    DiscardPolicy,
    RetentionPolicy,
    StorageType,
    StreamConfig,
)
from prometheus_client import Counter, Gauge, Histogram, start_http_server

# ── Configuration ────────────────────────────────────────────────────

LOCAL_TZ = ZoneInfo(os.getenv("TIMEZONE", "America/New_York"))

NATS_URL = os.getenv("NATS_URL", "nats://127.0.0.1:4222")
CLASSIFIER_URL = os.getenv("CLASSIFIER_URL", "http://127.0.0.1:8001/v1")
CLASSIFIER_MODEL = os.getenv("CLASSIFIER_MODEL", "QuantTrio/Qwen3.5-9B-AWQ")
PRIMARY_LLM_URL = os.getenv("PRIMARY_LLM_URL", "http://127.0.0.1:8000/v1")
PRIMARY_LLM_MODEL = os.getenv("PRIMARY_LLM_MODEL", "Qwen/Qwen3-32B-AWQ")
HA_URL = os.getenv("HA_URL", "http://127.0.0.1:8123")
HA_TOKEN = os.getenv("HA_TOKEN", "")
OXIGRAPH_URL = os.getenv("OXIGRAPH_URL", "http://127.0.0.1:7878")
REPL_URL = os.getenv("REPL_URL", "http://127.0.0.1:10900")
METRICS_PORT = int(os.getenv("METRICS_PORT", "8002"))
AGENT_DISPATCH_TIMEOUT = int(os.getenv("AGENT_DISPATCH_TIMEOUT", "60"))
CONFIDENCE_THRESHOLD = float(os.getenv("CONFIDENCE_THRESHOLD", "0.7"))
CLASSIFIER_HEALTH_INTERVAL = int(os.getenv("CLASSIFIER_HEALTH_INTERVAL", "5"))
MAX_TOOL_ROUNDS = int(os.getenv("MAX_TOOL_ROUNDS", "5"))

NEO4J_URL = os.getenv("NEO4J_URL", "http://127.0.0.1:7474")
NEO4J_USER = os.getenv("NEO4J_USER", "neo4j")
NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD", "")  # set via sops env file

# Residential proxy inspected by the get_proxy_status tool (defaults match the
# previously hard-coded values).
PROXY_SSH_TARGET = os.getenv("PROXY_SSH_TARGET", "nuclide@nuclide-amd.lan")
PROXY_ADDR = os.getenv("PROXY_ADDR", "nuclide-amd.lan:3128")
PROXY_USER = os.getenv("PROXY_USER", "glean")
PROXY_EXPECTED_IP = os.getenv("PROXY_EXPECTED_IP", "47.205.28.88")

STREAM_NAME = "BOB_COORDINATOR"

# Reply used when no model can answer (unavailable mode or an internal error).
UNAVAILABLE_MESSAGE = "I'm having trouble thinking right now. Try again in a minute."

# ── Prometheus Metrics ───────────────────────────────────────────────

classifications_total = Counter(
    "coordinator_classifications_total", "Total classifications", ["tier"]
)
classification_latency = Histogram(
    "coordinator_classification_latency_seconds", "Classification latency"
)
routing_latency = Histogram(
    "coordinator_routing_latency_seconds", "Routing latency by tier", ["tier"]
)
fallback_activations = Counter(
    "coordinator_fallback_activations_total", "Fallback activations"
)
agent_dispatches = Counter(
    "coordinator_agent_dispatches_total", "Agent dispatches", ["agent", "outcome"]
)
agent_dispatch_timeouts = Counter(
    "coordinator_agent_dispatch_timeouts_total", "Agent dispatch timeouts", ["agent"]
)
requests_total = Counter("coordinator_requests_total", "Total requests")
errors_total = Counter("coordinator_errors_total", "Errors", ["type"])
coordinator_mode = Gauge("coordinator_mode", "0=normal, 1=degraded, 2=unavailable")

# Gauge encoding of CoordinatorState.mode (matches the gauge help text above).
_MODE_GAUGE_VALUES = {"normal": 0, "degraded": 1, "unavailable": 2}

# ── Deterministic Patterns ───────────────────────────────────────────


def _current_time_reply(_match: re.Match) -> str:
    """Answer a "what time is it" question from the local clock (no LLM)."""
    now = datetime.now(LOCAL_TZ)
    # Strip the zero pad portably instead of relying on glibc's "%-I".
    return f"It's {now.strftime('%I:%M %p').lstrip('0')}."


def _current_date_reply(_match: re.Match) -> str:
    """Answer a "what's the date" question from the local clock (no LLM)."""
    now = datetime.now(LOCAL_TZ)
    # now.day avoids the glibc-only "%-d" directive.
    return f"Today is {now.strftime('%A, %B')} {now.day}."


# (compiled pattern, handler(match) -> reply). The first pattern that matches wins.
DETERMINISTIC_PATTERNS = [
    (
        re.compile(
            r"\b(?:what\s+time\s+is\s+it|what(?:['’]s|\s+is)\s+the\s+(?:current\s+)?time"
            r"|current\s+time|tell\s+me\s+the\s+time)\b"
            # Leave "... in Tokyo" / "... at the game" questions to the models,
            # since the local clock is the wrong answer for them.
            r"(?!\s+(?:in|at|for|of)\b)",
            re.I,
        ),
        _current_time_reply,
    ),
    (
        re.compile(
            r"\b(?:what(?:['’]s|\s+is)\s+the\s+date|what\s+day\s+is\s+it|today['’]?s?\s+date)\b"
            # "what's the date of the Super Bowl" is not a question about today.
            r"(?!\s+(?:of|for|in|on)\b)",
            re.I,
        ),
        _current_date_reply,
    ),
    (
        re.compile(
            r"^(?:(?:hi|hey|hello|good\s+(?:morning|afternoon|evening))\s*,?\s*bob"
            r"|(?:hi|hey|hello)\s+bob\s*,?\s*(?:good\s+(?:morning|afternoon|evening))?)"
            r"[\s!.?]*$",
            re.I,
        ),
        lambda _match: "Hey! What can I help you with?",
    ),
]

# ── Classification Prompt ────────────────────────────────────────────

CLASSIFICATION_PROMPT = """\
You are a request classifier. Classify the user's message into ONE tier.

SIMPLE — Can be answered from general knowledge alone. No APIs, no tools, no internet needed.
Examples: trivia, jokes, unit conversions, translations, casual chat, "what's the capital of France".

MODERATE — Needs a powerful model for reasoning or detailed analysis, but does NOT need to control devices or run system commands.
Examples: weather forecasts, detailed explanations, in-depth comparisons, technical summaries, creative writing, math problems, "explain transformers", "compare TCP vs UDP".

COMPLEX — Needs to interact with external systems: smart home devices, system diagnostics, databases, agent dispatch, memory recall, news/current events, proxy status, or operational commands.
Examples: "turn off lights", "run a health check", "who's home", "check GPU temps", "update knowledge base", "give me my morning briefing", "run the evening summary", "do you remember", "what did we talk about", "what did Cam say", "what's in the news", "any headlines today", "what's happening in the world", "is the proxy working", "proxy status", "check the proxy".

KEY RULES:
1. If the request does NOT require controlling a device, querying a database, or running a system command, it is NOT complex.
2. Requests that trigger named agents or operational routines (briefings, health checks, knowledge updates) ARE complex.
3. Requests about past conversations, memories, or "do you remember" ARE complex (requires memory database lookup).
4. If in doubt between simple and moderate, prefer moderate for anything requiring multi-paragraph explanation.

Respond ONLY with this JSON (no other text):
{"tier": "<simple|moderate|complex>", "confidence": <0.0-1.0>, "reasoning": "<10 words max>"}"""

_VALID_TIERS = ("simple", "moderate", "complex")

# ── Tool Definitions ─────────────────────────────────────────────────

# Agents that dispatch_agent may trigger; also used to validate model output.
DISPATCHABLE_AGENTS = (
    "home_keeper",
    "system_sentinel",
    "knowledge_gardener",
    "morning_coordinator",
    "evening_coordinator",
)

_WEATHER_TOOL = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Get current weather for a location.",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {"type": "string", "description": "City and state/country"}
            },
            "required": ["location"],
        },
    },
}

# Read-only tools for the moderate tier. The classifier prompt sends weather
# forecasts to MODERATE, so that tier must be able to fetch real weather.
MODERATE_TOOLS = [_WEATHER_TOOL]

# Full tool set for the complex tier.
COMPLEX_TOOLS = [
    _WEATHER_TOOL,
    {
        "type": "function",
        "function": {
            "name": "get_home_state",
            "description": "Get state of a smart home entity from Home Assistant.",
            "parameters": {
                "type": "object",
                "properties": {
                    "entity_id": {"type": "string", "description": "HA entity ID"}
                },
                "required": ["entity_id"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "control_home_device",
            "description": "Control a smart home device (turn on/off/toggle).",
            "parameters": {
                "type": "object",
                "properties": {
                    "entity_id": {"type": "string"},
                    "action": {"type": "string", "enum": ["turn_on", "turn_off", "toggle"]},
                },
                "required": ["entity_id", "action"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "query_knowledge",
            "description": "Query the family knowledge graph via SPARQL.",
            "parameters": {
                "type": "object",
                "properties": {
                    "query": {"type": "string", "description": "SPARQL query"}
                },
                "required": ["query"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "execute_code",
            "description": "Execute Python in a sandboxed REPL. Access to Docker, Prometheus, HA, Oxigraph, NATS APIs.",
            "parameters": {
                "type": "object",
                "properties": {
                    "code": {"type": "string"},
                    "timeout": {"type": "integer", "description": "Timeout seconds (max 60)"},
                },
                "required": ["code"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "get_news",
            "description": "Get the latest news headlines and weather alerts.",
            "parameters": {
                "type": "object",
                "properties": {
                    "category": {"type": "string", "description": "Filter: world, general, tech, local, or all"},
                },
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "recall_memory",
            "description": "Search Bob's memory for past conversations and facts about family members. Use when the user references prior interactions or asks 'do you remember'.",
            "parameters": {
                "type": "object",
                "properties": {
                    "query": {"type": "string", "description": "What to search for in memory"},
                },
                "required": ["query"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "dispatch_agent",
            "description": "Trigger an operational agent and wait for result. Agents: home_keeper (health checks), system_sentinel (deep monitoring), knowledge_gardener (knowledge maintenance), morning_coordinator (daily briefing), evening_coordinator (evening summary).",
            "parameters": {
                "type": "object",
                "properties": {
                    "agent": {
                        "type": "string",
                        "enum": list(DISPATCHABLE_AGENTS),
                    },
                },
                "required": ["agent"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "get_proxy_status",
            "description": "Check the residential HTTP proxy health — container status, functional test, auth failure count.",
            "parameters": {
                "type": "object",
                "properties": {},
            },
        },
    },
]

SYSTEM_PROMPT = """\
You are Bob, the Hunt family's AI assistant in Tampa, FL.
Keep responses brief — one to three sentences unless asked for more detail.
When you call a tool, wait for the result and respond naturally.
NEVER end with offers like "let me know if there's anything else."
Do NOT use <think> tags. Respond directly.
/no_think"""


class CoordinatorState:
    """Tracks health of the classifier and primary LLM, and derives the routing mode.

    The classifier flips to unhealthy after 3 consecutive failures and back
    after 3 consecutive passes. The primary LLM flips on every single
    pass/fail report.
    """

    def __init__(self):
        self.classifier_healthy = True
        self.primary_healthy = True
        self.classifier_fail_count = 0
        self.classifier_pass_count = 0
        self.start_time = time.time()
        # NATS connection, set by main() once connected; read by /health.
        self.nats_client = None

    @property
    def mode(self) -> str:
        """Return "unavailable" (both models down), "degraded" (classifier down) or "normal".

        A primary-only outage still reports "normal".
        """
        if not self.primary_healthy and not self.classifier_healthy:
            return "unavailable"
        if not self.classifier_healthy:
            return "degraded"
        return "normal"

    def classifier_failed(self):
        """Record a classifier failure; enter degraded mode on the 3rd consecutive one."""
        self.classifier_fail_count += 1
        self.classifier_pass_count = 0
        if self.classifier_fail_count >= 3:
            if self.classifier_healthy:
                print("Classifier unhealthy — entering DEGRADED mode")
                fallback_activations.inc()
            self.classifier_healthy = False

    def classifier_passed(self):
        """Record a classifier success; leave degraded mode after 3 consecutive ones."""
        self.classifier_pass_count += 1
        self.classifier_fail_count = 0
        if self.classifier_pass_count >= 3 and not self.classifier_healthy:
            print("Classifier recovered — entering NORMAL mode")
            self.classifier_healthy = True

    def primary_failed(self):
        """Mark the primary LLM unhealthy (logged only on the transition)."""
        if self.primary_healthy:
            print("Primary LLM unhealthy")
        self.primary_healthy = False

    def primary_passed(self):
        """Mark the primary LLM healthy (logged only on the transition)."""
        if not self.primary_healthy:
            print("Primary LLM recovered")
        self.primary_healthy = True


state = CoordinatorState()


# ── Helpers ──────────────────────────────────────────────────────────

def _strip_think(content) -> str:
    """Remove <think>…</think> reasoning from a model reply; None becomes ""."""
    text = re.sub(r"<think>.*?</think>", "", content or "", flags=re.DOTALL)
    # Some chat templates may open the think block in the prompt, leaving only
    # the closing tag in the reply; keep just the text after it.
    if "</think>" in text:
        text = text.rsplit("</think>", 1)[1]
    return text.strip()


# ── Tool Execution ───────────────────────────────────────────────────

async def execute_tool(client: httpx.AsyncClient, name: str, args: dict, js, nc) -> str:
    """Execute one model-requested tool call and return its result as a string.

    Errors are returned as a JSON {"error": ...} string rather than raised, so
    the model can see and react to them.
    """
    handlers = {
        "get_weather": lambda: _tool_weather(client, args),
        "get_home_state": lambda: _tool_ha_state(client, args),
        "control_home_device": lambda: _tool_ha_control(client, args),
        "query_knowledge": lambda: _tool_sparql(client, args),
        "execute_code": lambda: _tool_repl(client, args),
        "get_proxy_status": lambda: _tool_proxy_status(client, args),
        "get_news": lambda: _tool_get_news(nc, args),
        "recall_memory": lambda: _tool_recall_memory(client, args),
        "dispatch_agent": lambda: _tool_dispatch_agent(js, nc, args),
    }
    handler = handlers.get(name)
    if handler is None:
        return json.dumps({"error": f"Unknown tool: {name}"})
    if not isinstance(args, dict):
        return json.dumps({"error": f"Arguments for {name} must be a JSON object"})
    try:
        return await handler()
    except Exception as e:
        errors_total.labels(type="tool_execution").inc()
        return json.dumps({"error": str(e)})


async def _tool_weather(client: httpx.AsyncClient, args: dict) -> str:
    """Geocode a location with Open-Meteo, then return its current conditions as JSON."""
    location = str(args.get("location") or "Tampa, FL")

    async def geocode(name: str) -> list:
        # params= URL-encodes the user/model supplied location.
        geo = await client.get(
            "https://geocoding-api.open-meteo.com/v1/search",
            params={"name": name, "count": 1},
            timeout=10,
        )
        geo.raise_for_status()
        return geo.json().get("results") or []

    results = await geocode(location)
    if not results and "," in location:
        # Retry with only the place name ("Tampa, FL" -> "Tampa") in case the
        # geocoder does not understand the qualifier.
        results = await geocode(location.split(",", 1)[0].strip())
    if not results:
        return json.dumps({"error": f"Location not found: {location}"})

    place = results[0]
    wx = await client.get(
        "https://api.open-meteo.com/v1/forecast",
        params={
            "latitude": place["latitude"],
            "longitude": place["longitude"],
            "current": "temperature_2m,apparent_temperature,weather_code,wind_speed_10m,relative_humidity_2m",
            "temperature_unit": "fahrenheit",
            "wind_speed_unit": "mph",
        },
        timeout=10,
    )
    wx.raise_for_status()
    return json.dumps({"location": place.get("name", location), **wx.json().get("current", {})})


# Home Assistant entity IDs are "<domain>.<object_id>" in lowercase letters, digits
# and underscores. Validating them stops a model-supplied ID from reaching other
# HA API paths.
_ENTITY_ID_RE = re.compile(r"[a-z0-9_]+\.[a-z0-9_]+")
# Must match the enum declared for control_home_device. Anything else could call
# arbitrary HA services (e.g. homeassistant/restart).
_HA_ACTIONS = frozenset({"turn_on", "turn_off", "toggle"})


def _ha_entity_id(args: dict) -> str | None:
    """Return the normalized entity_id from *args*, or None if it is malformed."""
    entity_id = str(args.get("entity_id", "")).strip().lower()
    return entity_id if _ENTITY_ID_RE.fullmatch(entity_id) else None


async def _tool_ha_state(client: httpx.AsyncClient, args: dict) -> str:
    """Return Home Assistant's state JSON for one entity (truncated for context)."""
    entity_id = _ha_entity_id(args)
    if entity_id is None:
        return json.dumps({"error": f"Invalid entity_id: {args.get('entity_id')!r} (expected domain.object_id)"})
    r = await client.get(
        f"{HA_URL}/api/states/{entity_id}",
        headers={"Authorization": f"Bearer {HA_TOKEN}"},
        timeout=10,
    )
    return r.text[:2000]


async def _tool_ha_control(client: httpx.AsyncClient, args: dict) -> str:
    """Call <domain>/<action> in Home Assistant for one entity and report the outcome."""
    entity_id = _ha_entity_id(args)
    if entity_id is None:
        return json.dumps({"error": f"Invalid entity_id: {args.get('entity_id')!r} (expected domain.object_id)"})
    action = args.get("action", "toggle")
    if action not in _HA_ACTIONS:
        return json.dumps({"error": f"Unsupported action: {action!r}", "allowed": sorted(_HA_ACTIONS)})
    domain = entity_id.split(".", 1)[0]
    r = await client.post(
        f"{HA_URL}/api/services/{domain}/{action}",
        headers={"Authorization": f"Bearer {HA_TOKEN}", "Content-Type": "application/json"},
        json={"entity_id": entity_id},
        timeout=10,
    )
    if r.is_success:
        return json.dumps({"status": "ok", "action": action, "entity": entity_id})
    # Report failures instead of claiming success, so the model doesn't tell
    # the user a device changed when it didn't.
    return json.dumps({
        "status": "error",
        "http_status": r.status_code,
        "action": action,
        "entity": entity_id,
        "detail": r.text[:500],
    })


async def _tool_sparql(client: httpx.AsyncClient, args: dict) -> str:
    """Run a SPARQL query against Oxigraph and return the raw response text (truncated)."""
    query = args.get("query", "")
    if not isinstance(query, str) or not query.strip():
        return json.dumps({"error": "query must be a non-empty SPARQL string"})
    r = await client.post(
        f"{OXIGRAPH_URL}/query",
        content=query,
        headers={"Content-Type": "application/sparql-query"},
        timeout=10,
    )
    return r.text[:2000]  # Truncate for context


async def _tool_repl(client: httpx.AsyncClient, args: dict) -> str:
    """Execute code in the sandboxed REPL service and return its response text (truncated)."""
    code = args.get("code", "")
    try:
        timeout = int(args.get("timeout", 30))
    except (TypeError, ValueError):
        timeout = 30
    timeout = max(1, min(timeout, 60))  # clamp model-supplied value to 1..60 s
    r = await client.post(
        f"{REPL_URL}/execute",
        json={"code": code, "timeout": timeout},
        timeout=timeout + 5,
    )
    return r.text[:2000]


async def _tool_dispatch_agent(js, nc, args: dict) -> str:
    """Trigger an operational agent over JetStream and wait for its matching result."""
    agent_name = args.get("agent", "")
    if agent_name not in DISPATCHABLE_AGENTS:
        # Also keeps arbitrary model text out of the NATS subject names below.
        return json.dumps({"error": f"Unknown agent: {agent_name!r}", "available": list(DISPATCHABLE_AGENTS)})

    correlation_id = str(uuid.uuid4())
    result_future = asyncio.get_running_loop().create_future()

    async def result_handler(msg):
        try:
            data = json.loads(msg.data.decode())
        except ValueError:
            return  # not JSON; cannot be our result
        if not isinstance(data, dict):
            return
        # Match by run_id/correlation_id to avoid stealing another dispatch's result.
        rid = data.get("run_id", data.get("correlation_id", ""))
        if rid == correlation_id and not result_future.done():
            result_future.set_result(data)

    # Subscribe BEFORE publishing the trigger; otherwise a fast agent could
    # publish its result before we are listening.
    sub = await nc.subscribe(f"bob.agent.{agent_name}.result", cb=result_handler)
    try:
        payload = {
            "role": agent_name,
            "run_id": correlation_id,
            "triggered_at": datetime.now(timezone.utc).isoformat(),
            "source": "coordinator",
            "description": "On-demand dispatch via voice",
            "timeout_seconds": AGENT_DISPATCH_TIMEOUT,
        }
        await js.publish(f"bob.agent.{agent_name}.trigger", json.dumps(payload).encode())
        print(f"Dispatched {agent_name} (correlation={correlation_id})")
        agent_dispatches.labels(agent=agent_name, outcome="dispatched").inc()

        try:
            result = await asyncio.wait_for(result_future, timeout=AGENT_DISPATCH_TIMEOUT)
        except asyncio.TimeoutError:
            agent_dispatch_timeouts.labels(agent=agent_name).inc()
            agent_dispatches.labels(agent=agent_name, outcome="timeout").inc()
            return json.dumps({
                "agent": agent_name,
                "status": "timeout",
                "message": f"Agent {agent_name} did not respond within {AGENT_DISPATCH_TIMEOUT} seconds.",
            })

        agent_dispatches.labels(agent=agent_name, outcome="completed").inc()
        # Prefer the agent's own summary; fall back to a truncated dump of the result.
        summary = result.get("summary") or json.dumps(result)[:1000]
        return json.dumps({"agent": agent_name, "status": "completed", "summary": summary})
    finally:
        await sub.unsubscribe()


_SSH_OPTS = [
    "-o", "ConnectTimeout=3", "-o", "BatchMode=yes",
    # NOTE(review): host-key checking is disabled for this host; consider pinning its key.
    "-o", "StrictHostKeyChecking=no", "-F", "/dev/null",
    "-i", "/root/.ssh/id_ed25519",
]


async def _run_command(cmd: list[str], timeout: float) -> subprocess.CompletedProcess:
    """Run *cmd* in a worker thread so the blocking subprocess call doesn't stall the event loop."""
    return await asyncio.to_thread(
        subprocess.run, cmd, capture_output=True, text=True, timeout=timeout
    )


async def _tool_proxy_status(client: httpx.AsyncClient, args: dict) -> str:
    """Check Squid proxy health: container status and auth failures via SSH, plus an egress test."""
    results: dict = {}

    # Container status via SSH
    try:
        r = await _run_command(
            ["ssh", *_SSH_OPTS, PROXY_SSH_TARGET,
             "docker ps --filter name=squid-proxy --format '{{.Status}}'"],
            timeout=10,
        )
        if r.returncode == 0:
            results["container"] = r.stdout.strip() or "DOWN"
        else:
            # Don't report an SSH/docker failure as the container being down.
            results["container"] = f"check failed (exit {r.returncode}): {r.stderr.strip()[:200]}"
    except Exception as e:
        results["container"] = f"SSH failed: {e}"

    # Auth failures in the last hour. grep -c prints "0" AND exits 1 when nothing
    # matches, so "|| echo 0" would add a second line; "|| true" keeps one count.
    try:
        r = await _run_command(
            ["ssh", *_SSH_OPTS, PROXY_SSH_TARGET,
             "docker logs --since 1h squid-proxy 2>&1 | grep -c TCP_DENIED/407 || true"],
            timeout=10,
        )
        if r.returncode == 0 and r.stdout.strip():
            results["auth_failures_1h"] = int(r.stdout.split()[0])
        else:
            results["auth_failures_1h"] = "unknown"
    except Exception:
        results["auth_failures_1h"] = "unknown"

    # Functional test (curl through proxy)
    proxy_pass = os.getenv("PROXY_PASSWORD", "")
    if proxy_pass:
        try:
            r = await _run_command(
                ["curl", "-s", "--max-time", "10",
                 "--proxy", f"http://{PROXY_ADDR}",
                 # --proxy-user avoids URL-parsing problems with '@', '/' or '%'
                 # in the password. NOTE(review): the credential is still
                 # visible in the local process list while curl runs.
                 "--proxy-user", f"{PROXY_USER}:{proxy_pass}",
                 "https://api.ipify.org"],
                timeout=15,
            )
            results["egress_ip"] = r.stdout.strip() or "FAILED"
            results["functional"] = results["egress_ip"] == PROXY_EXPECTED_IP
        except Exception as e:
            results["functional"] = False
            results["egress_ip"] = f"test failed: {e}"
    else:
        results["functional"] = "unknown (no PROXY_PASSWORD)"

    return json.dumps(results)


async def _tool_get_news(nc, args: dict) -> str:
    """Return up to 5 cached headlines (optionally filtered by category) and 3 weather alerts."""
    category = args.get("category", "all")
    try:
        sub = await nc.jetstream().subscribe("bob.news.headlines", deliver_policy=DeliverPolicy.LAST)
        try:
            # next_msg() applies its own timeout (short by default in nats-py),
            # so pass the intended 5 s directly instead of wrapping it.
            msg = await sub.next_msg(timeout=5)
        finally:
            await sub.unsubscribe()  # always remove the ephemeral consumer
        data = json.loads(msg.data.decode())
        all_headlines = [h for h in data.get("headlines", []) if isinstance(h, dict)]
        headlines = all_headlines
        if category and category != "all":
            headlines = [h for h in all_headlines if h.get("category") == category]
        return json.dumps({
            "headlines": headlines[:5],
            "weather_alerts": data.get("weather_alerts", [])[:3],
            "total_available": len(all_headlines),
        })
    except Exception as e:
        # Best effort: news is optional context for the model.
        return json.dumps({"error": f"News not available: {e}"})


# Words that carry no search value in memory questions ("do you remember what…").
_MEMORY_STOPWORDS = frozenset(
    """
    a an and any anything about are as at be been bob but by can could did do does
    done for from he her him his how i if in is it its me my of on or our please
    recall remember said say she something talk talked talking tell that the their
    them they this to told us was we were what when where which who whom why will
    with would you your
    """.split()
)


def _memory_keyword(query: str) -> str:
    """Pick one search term from a natural-language memory query.

    The lookup matches a single substring, so the leading word of a question
    ("do", "what", ...) is useless. Prefer a capitalized word after the first
    (likely a name), then the longest remaining non-stopword. Fall back to the
    first word, or "" for an empty query.
    """
    words = re.findall(r"\w+", query)
    candidates = [
        (i, w) for i, w in enumerate(words)
        if len(w) > 1 and w.lower() not in _MEMORY_STOPWORDS
    ]
    for i, w in candidates:
        if i > 0 and w[0].isupper():
            return w
    if candidates:
        return max((w for _, w in candidates), key=len)
    return words[0] if words else ""


async def _tool_recall_memory(client: httpx.AsyncClient, args: dict) -> str:
    """Search Graphiti temporal memory in Neo4j for episodes and entities matching the query."""
    keyword = _memory_keyword(str(args.get("query", "")))

    # toLower on both sides makes the substring match case-insensitive.
    cypher_episodes = """
    MATCH (e:Episodic)
    WHERE toLower(e.content) CONTAINS toLower($keyword) OR toLower(e.name) CONTAINS toLower($keyword)
    RETURN e.name AS name, e.content AS content, e.created_at AS created_at
    ORDER BY e.created_at DESC LIMIT 5
    """
    cypher_entities = """
    MATCH (n:Entity)
    WHERE toLower(n.name) CONTAINS toLower($keyword) OR toLower(n.summary) CONTAINS toLower($keyword)
    RETURN n.name AS name, n.summary AS summary LIMIT 5
    """

    results = {"episodes": [], "entities": []}
    try:
        for cypher, key, fields in (
            (cypher_episodes, "episodes", ("name", "content", "created_at")),
            (cypher_entities, "entities", ("name", "summary")),
        ):
            r = await client.post(
                f"{NEO4J_URL}/db/neo4j/tx/commit",
                json={"statements": [{"statement": cypher, "parameters": {"keyword": keyword}}]},
                auth=(NEO4J_USER, NEO4J_PASSWORD),
                timeout=10,
            )
            r.raise_for_status()
            body = r.json()
            if body.get("errors"):
                # The transactional endpoint reports Cypher errors in the body.
                raise RuntimeError(body["errors"][0].get("message", body["errors"][0]))
            statement_results = body.get("results") or [{}]
            for row_data in statement_results[0].get("data", []):
                row = row_data.get("row", [])
                results[key].append({
                    field: "" if i >= len(row) or row[i] is None else str(row[i])[:300]
                    for i, field in enumerate(fields)
                })
    except Exception as e:
        return json.dumps({"error": f"Memory search failed: {e}"})

    if not results["episodes"] and not results["entities"]:
        return json.dumps({"message": "No memories found matching that query."})
    return json.dumps(results)


# ── Classification ───────────────────────────────────────────────────

def check_deterministic(text: str) -> str | None:
    """Return a canned reply if *text* matches a deterministic pattern, else None."""
    for pattern, handler in DETERMINISTIC_PATTERNS:
        match = pattern.search(text)
        if match:
            return handler(match)
    return None


def _normalize_classification(raw) -> dict:
    """Coerce parsed classifier output into {"tier", "confidence", "reasoning"}.

    An unknown or missing tier becomes "complex". Confidence is clamped to
    [0, 1], and a non-numeric or NaN value becomes 0.0.

    Raises:
        ValueError: if *raw* is not a JSON object.
    """
    if not isinstance(raw, dict):
        raise ValueError(f"classifier returned {type(raw).__name__}, expected an object")
    tier = str(raw.get("tier", "")).strip().lower()
    if tier not in _VALID_TIERS:
        tier = "complex"
    try:
        confidence = float(raw.get("confidence", 0.0))
    except (TypeError, ValueError):
        confidence = 0.0
    if not confidence >= 0.0:  # negative or NaN
        confidence = 0.0
    return {
        "tier": tier,
        "confidence": min(confidence, 1.0),
        "reasoning": str(raw.get("reasoning", "")),
    }


async def classify_request(client: httpx.AsyncClient, text: str) -> dict:
    """Classify user text into a tier using the lightweight model.

    Returns a dict with "tier", "confidence" and "reasoning". On any failure it
    records a classifier failure and returns a zero-confidence "complex" result.
    """
    with classification_latency.time():
        try:
            r = await client.post(
                f"{CLASSIFIER_URL}/chat/completions",
                json={
                    "model": CLASSIFIER_MODEL,
                    "messages": [
                        {"role": "system", "content": CLASSIFICATION_PROMPT},
                        {"role": "user", "content": text + "\n/no_think"},
                    ],
                    "max_tokens": 100,
                    "temperature": 0.1,
                    "response_format": {"type": "json_object"},
                },
                timeout=5.0,
            )
            r.raise_for_status()
            content = _strip_think(r.json()["choices"][0]["message"]["content"])
            # Tolerate a ```json ... ``` fence around the object.
            content = re.sub(r"^```(?:json)?\s*|\s*```$", "", content)
            result = _normalize_classification(json.loads(content))
            state.classifier_passed()
            return result
        except (httpx.HTTPError, ValueError, KeyError, IndexError, TypeError) as e:
            # ValueError also covers json.JSONDecodeError.
            print(f"Classification error: {e}", file=sys.stderr)
            state.classifier_failed()
            errors_total.labels(type="classification").inc()
            return {"tier": "complex", "confidence": 0.0, "reasoning": "classification failed — fallback"}


async def generate_simple_response(client: httpx.AsyncClient, text: str, context: list) -> str:
    """Answer directly with the lightweight model, escalating to the primary model on failure."""
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        *context,
        {"role": "user", "content": text + "\n/no_think"},
    ]
    try:
        r = await client.post(
            f"{CLASSIFIER_URL}/chat/completions",
            json={
                "model": CLASSIFIER_MODEL,
                "messages": messages,
                "max_tokens": 512,
                "temperature": 0.7,
            },
            timeout=10.0,
        )
        r.raise_for_status()
        # Strip think tags in case the model ignores /no_think.
        content = _strip_think(r.json()["choices"][0]["message"]["content"])
        if not content:
            raise ValueError("empty response")
        return content
    except Exception as e:
        print(f"Simple response error: {e}", file=sys.stderr)
        errors_total.labels(type="simple_response").inc()
        # Escalate to primary
        return await generate_primary_response(client, text, context, tools=False)


async def _primary_chat(client: httpx.AsyncClient, request: dict) -> dict:
    """POST one chat-completions request to the primary LLM and return choices[0].message."""
    r = await client.post(f"{PRIMARY_LLM_URL}/chat/completions", json=request, timeout=30.0)
    r.raise_for_status()
    return r.json()["choices"][0]["message"]


async def _run_tool_call(client: httpx.AsyncClient, tool_call: dict, js, nc) -> dict:
    """Execute one tool call from the model and return the "tool" message to append."""
    fn = tool_call.get("function") or {}
    name = fn.get("name", "")
    raw_args = fn.get("arguments") or {}
    try:
        args = json.loads(raw_args) if isinstance(raw_args, str) else raw_args
    except json.JSONDecodeError as e:
        # Hand malformed arguments back to the model instead of failing the request.
        result = json.dumps({"error": f"Could not parse arguments for {name}: {e}"})
    else:
        result = await execute_tool(client, name, args, js, nc)
    return {"role": "tool", "tool_call_id": tool_call.get("id", ""), "content": result}


async def generate_primary_response(
    client: httpx.AsyncClient,
    text: str,
    context: list,
    tools: bool | list = False,
    js=None,
    nc=None,
) -> str:
    """Generate a response from Qwen3-32B, optionally with tool calling.

    Args:
        tools: False for no tools, True for COMPLEX_TOOLS, or an explicit list
            of tool definitions (e.g. MODERATE_TOOLS).
        js, nc: JetStream context and NATS client, needed by NATS-backed tools.

    Returns the reply text, or UNAVAILABLE_MESSAGE if the primary LLM call fails.
    """
    if tools is True:
        tool_defs = COMPLEX_TOOLS
    else:
        tool_defs = list(tools) if tools else None

    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        *context,
        {"role": "user", "content": text},
    ]
    # `messages` is shared with the request, so appending to it extends the
    # conversation sent on the next round.
    request = {
        "model": PRIMARY_LLM_MODEL,
        "messages": messages,
        "max_tokens": 1024,
        "temperature": 0.7,
    }
    if tool_defs:
        request["tools"] = tool_defs
        request["tool_choice"] = "auto"

    try:
        message = await _primary_chat(client, request)

        # Handle tool calls iteratively (capped to prevent infinite loops).
        rounds = 0
        while message.get("tool_calls") and rounds < MAX_TOOL_ROUNDS:
            rounds += 1
            messages.append(message)
            for tool_call in message["tool_calls"]:
                messages.append(await _run_tool_call(client, tool_call, js, nc))
            message = await _primary_chat(client, request)

        if message.get("tool_calls"):
            # Tool budget exhausted: ask for a plain answer from what was gathered
            # (the pending tool-call message is not appended, as it has no results).
            request["tool_choice"] = "none"
            message = await _primary_chat(client, request)

        state.primary_passed()
        # content is null (not missing) on tool-call messages, hence _strip_think's None handling.
        return _strip_think(message.get("content"))

    except Exception as e:
        print(f"Primary LLM error: {e}", file=sys.stderr)
        state.primary_failed()
        errors_total.labels(type="primary_llm").inc()
        return UNAVAILABLE_MESSAGE


# ── Request Handler ──────────────────────────────────────────────────

async def _route(client: httpx.AsyncClient, text: str, context: list, js, nc) -> tuple[str, str]:
    """Pick a tier for *text* according to the current mode and produce the reply.

    Returns (tier, response_text).
    """
    mode = state.mode
    coordinator_mode.set(_MODE_GAUGE_VALUES[mode])

    if mode == "unavailable":
        return "unavailable", UNAVAILABLE_MESSAGE

    if mode == "degraded":
        # Skip classification, route all to primary with tools.
        tier = "complex"
        with routing_latency.labels(tier=tier).time():
            return tier, await generate_primary_response(client, text, context, tools=True, js=js, nc=nc)

    # Normal mode: deterministic → classify → route
    det_response = check_deterministic(text)
    if det_response:
        classifications_total.labels(tier="deterministic").inc()
        return "deterministic", det_response

    classification = await classify_request(client, text)
    tier = classification.get("tier", "complex")
    confidence = classification.get("confidence", 0.0)

    # Confidence-based escalation: a low-confidence label moves up one tier.
    if confidence < CONFIDENCE_THRESHOLD:
        escalated = {"simple": "moderate", "moderate": "complex"}.get(tier, tier)
        if escalated != tier:
            print(f"Escalated {tier} → {escalated} (confidence={confidence:.2f})")
            tier = escalated

    classifications_total.labels(tier=tier).inc()

    with routing_latency.labels(tier=tier).time():
        if tier == "simple":
            response_text = await generate_simple_response(client, text, context)
        elif tier == "moderate":
            response_text = await generate_primary_response(
                client, text, context, tools=MODERATE_TOOLS, js=js, nc=nc
            )
        else:  # complex
            response_text = await generate_primary_response(
                client, text, context, tools=True, js=js, nc=nc
            )
    return tier, response_text


async def handle_request(msg, client: httpx.AsyncClient, js, nc):
    """Handle one coordinator request from NATS and publish the response.

    The payload is a JSON object with "text", an optional "correlation_id"
    (generated if absent) and an optional "context" list of prior chat turns.
    Malformed payloads and empty text are dropped without a response. Every
    accepted request gets a response on bob.coordinator.response, even if
    routing fails unexpectedly.
    """
    requests_total.inc()
    t0 = time.time()

    try:
        data = json.loads(msg.data.decode())
    except ValueError:  # invalid JSON or invalid UTF-8
        errors_total.labels(type="parse").inc()
        return
    if not isinstance(data, dict):
        errors_total.labels(type="parse").inc()
        return

    text = data.get("text", "")
    if not isinstance(text, str) or not text.strip():
        return
    correlation_id = data.get("correlation_id") or str(uuid.uuid4())
    raw_context = data.get("context")  # Recent conversation turns
    context = [turn for turn in raw_context if isinstance(turn, dict)] if isinstance(raw_context, list) else []

    try:
        tier, response_text = await _route(client, text, context, js, nc)
    except Exception as e:
        # Never leave the requester waiting without an answer.
        print(f"Request handling error: {e!r}", file=sys.stderr)
        errors_total.labels(type="handler").inc()
        tier, response_text = "unknown", UNAVAILABLE_MESSAGE

    response_payload = {
        "correlation_id": correlation_id,
        "text": response_text,
        "tier": tier,
        "latency_ms": round((time.time() - t0) * 1000),
    }

    try:
        await nc.publish(
            "bob.coordinator.response",
            json.dumps(response_payload).encode(),
        )
    except Exception as e:
        print(f"Failed to publish response: {e}", file=sys.stderr)
        errors_total.labels(type="publish").inc()


# ── Health Check Background Task ─────────────────────────────────────

async def _probe_health(client: httpx.AsyncClient, base_url: str) -> bool:
    """Return True if GET <base_url without /v1>/health answers 200 within 3 s."""
    try:
        r = await client.get(f"{base_url.removesuffix('/v1')}/health", timeout=3.0)
    except Exception:
        return False
    return r.status_code == 200


async def health_check_loop(client: httpx.AsyncClient):
    """Every CLASSIFIER_HEALTH_INTERVAL seconds, probe both model servers and update the mode gauge."""
    while True:
        if await _probe_health(client, CLASSIFIER_URL):
            state.classifier_passed()
        else:
            state.classifier_failed()

        if await _probe_health(client, PRIMARY_LLM_URL):
            state.primary_passed()
        else:
            state.primary_failed()

        coordinator_mode.set(_MODE_GAUGE_VALUES[state.mode])
        await asyncio.sleep(CLASSIFIER_HEALTH_INTERVAL)


# ── Health Endpoint ──────────────────────────────────────────────────

async def serve_health(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
    """Minimal HTTP responder: GET /health returns a JSON status; anything else is a 404.

    Prometheus metrics are served separately by prometheus_client on METRICS_PORT.
    """
    try:
        try:
            data = await asyncio.wait_for(reader.read(1024), timeout=5.0)
        except asyncio.TimeoutError:
            data = b""  # idle client; treat as an unknown request
        request_line = data.decode("latin-1").split("\r\n", 1)[0]

        if request_line.startswith("GET /health"):
            nc = state.nats_client
            body = json.dumps({
                "status": "ok" if state.mode == "normal" else state.mode,
                "uptime_seconds": round(time.time() - state.start_time),
                "classifier": "connected" if state.classifier_healthy else "disconnected",
                "primary_llm": "connected" if state.primary_healthy else "disconnected",
                "nats": "connected" if nc is not None and nc.is_connected else "disconnected",
                "mode": state.mode,
            }).encode()
            head = (
                "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n"
                f"Content-Length: {len(body)}\r\nConnection: close\r\n\r\n"
            )
            writer.write(head.encode() + body)
        else:
            writer.write(b"HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\nConnection: close\r\n\r\n")
        await writer.drain()
    except OSError:
        pass  # client disconnected mid-response
    finally:
        writer.close()
        try:
            await writer.wait_closed()
        except OSError:
            pass


# ── Main ─────────────────────────────────────────────────────────────

async def _ensure_stream(js) -> None:
    """Create the coordinator JetStream stream unless one already covers bob.coordinator.>."""
    try:
        await js.find_stream_name_by_subject("bob.coordinator.>")
    except Exception:
        pass  # not found (or lookup failed): try to create it below
    else:
        return

    try:
        # Build StreamConfig explicitly. add_stream(**kwargs) parses kwargs as a
        # server response, where max_age is in nanoseconds, so max_age=86400
        # would mean ~86 µs instead of one day.
        await js.add_stream(StreamConfig(
            name=STREAM_NAME,
            subjects=["bob.coordinator.>"],
            retention=RetentionPolicy.LIMITS,
            max_msgs=5000,
            max_age=86400,  # seconds (1 day)
            storage=StorageType.FILE,
            discard=DiscardPolicy.OLD,
        ))
        print(f"Created stream {STREAM_NAME}")
    except Exception as e:
        print(f"Stream setup: {e}")


async def main():
    """Start metrics/health servers, connect to NATS, subscribe, and run until cancelled."""
    print("Bob Coordinator starting...")
    print(f"  Classifier: {CLASSIFIER_URL} ({CLASSIFIER_MODEL})")
    print(f"  Primary LLM: {PRIMARY_LLM_URL} ({PRIMARY_LLM_MODEL})")
    print(f"  NATS: {NATS_URL}")
    print(f"  Metrics: :{METRICS_PORT}")

    # Start Prometheus metrics server
    start_http_server(METRICS_PORT)
    print(f"Prometheus metrics on :{METRICS_PORT}/metrics")

    # HTTP client for LLM calls
    client = httpx.AsyncClient()

    # Connect to NATS
    nc = await nats.connect(NATS_URL)
    state.nats_client = nc
    js = nc.jetstream()

    await _ensure_stream(js)

    # nats-py runs a subscription's callback for one message at a time, so each
    # request gets its own task; otherwise a long agent dispatch would block
    # every other request. Strong references keep the tasks from being
    # garbage-collected mid-flight.
    in_flight: set[asyncio.Task] = set()

    async def on_request(msg):
        task = asyncio.create_task(handle_request(msg, client, js, nc))
        in_flight.add(task)
        task.add_done_callback(in_flight.discard)

    await nc.subscribe("bob.coordinator.request", cb=on_request)
    print("Subscribed to bob.coordinator.request")

    # Start health check loop (reference kept so it is not garbage-collected).
    health_task = asyncio.create_task(health_check_loop(client))

    # Start health endpoint (separate from Prometheus port)
    health_server = await asyncio.start_server(serve_health, "0.0.0.0", METRICS_PORT + 1)
    print(f"Health endpoint on :{METRICS_PORT + 1}/health")

    # Run warm-up classification to ensure model is ready
    print("Running warm-up classification...")
    warmup = await classify_request(client, "Hello Bob")
    print(f"Warm-up result: {warmup}")

    print("Coordinator running.")
    try:
        await asyncio.Event().wait()  # run until cancelled
    except asyncio.CancelledError:
        pass
    finally:
        health_task.cancel()
        health_server.close()
        for task in list(in_flight):
            task.cancel()
        await nc.close()
        await client.aclose()
        print("Coordinator stopped.")


if __name__ == "__main__":
    asyncio.run(main())