test_classification.py
"""Classification test suite — 25 labeled requests spanning all tiers.

Run against a live coordinator:
    python test_classification.py [--nats-url nats://127.0.0.1:4222]

Each request is published on ``bob.coordinator.request`` and the reply on
``bob.coordinator.response`` carrying the same ``correlation_id`` must report
the expected ``tier``. The process exits 0 when at least ``PASS_PERCENT`` of
the requests are classified correctly, and 1 otherwise.
"""

import argparse
import asyncio
import json
import os
import sys
import time

import nats

NATS_URL = os.getenv("NATS_URL", "nats://127.0.0.1:4222")

REQUEST_SUBJECT = "bob.coordinator.request"
RESPONSE_SUBJECT = "bob.coordinator.response"

# Minimum share (inclusive) of correctly classified requests for the run to pass.
PASS_PERCENT = 90

# Complex requests may dispatch agents — allow up to 90s; everything else 15s.
COMPLEX_TIMEOUT_S = 90.0
DEFAULT_TIMEOUT_S = 15.0

# 25 test cases: 5 deterministic, 6 simple, 7 moderate, 7 complex
TEST_CASES = [
    # ── Deterministic (5) — handled by regex, no LLM ──
    ("What time is it?", "deterministic"),
    ("What's the date today?", "deterministic"),
    ("Hey Bob, good morning", "deterministic"),
    ("Hello Bob", "deterministic"),
    ("Good evening Bob", "deterministic"),

    # ── Simple (6) — lightweight model, no tools ──
    ("What's the capital of France?", "simple"),
    ("How many ounces in a pound?", "simple"),
    ("Tell me a joke", "simple"),
    ("What's the meaning of life?", "simple"),
    ("How do you say hello in Spanish?", "simple"),
    ("What year did World War 2 end?", "simple"),

    # ── Moderate (7) — Qwen3-32B, no tools ──
    ("What's the weather like?", "moderate"),
    ("What's the weather forecast for tomorrow?", "moderate"),
    ("Explain how a transformer model works", "moderate"),
    ("Compare the pros and cons of solar panels versus wind energy", "moderate"),
    ("Summarize the key differences between TCP and UDP", "moderate"),
    ("What are the health benefits of intermittent fasting?", "moderate"),
    ("Write a haiku about the ocean", "moderate"),

    # ── Complex (7) — Qwen3-32B with tools ──
    ("Turn off the living room lights", "complex"),
    ("Run a health check on the system", "complex"),
    ("What's the temperature inside the house?", "complex"),
    ("Check the system status and tell me about GPU temperatures", "complex"),
    ("Is anyone home right now?", "complex"),
    ("Update the knowledge base", "complex"),
    ("Give me my morning briefing", "complex"),
]


async def _request_classification(nc, text, correlation_id, timeout):
    """Publish one request and wait for the reply tagged with *correlation_id*.

    Args:
        nc: An open NATS connection.
        text: The user utterance to classify.
        correlation_id: Identifier echoed back by the coordinator; replies
            carrying any other id are ignored.
        timeout: Seconds to wait for the matching reply.

    Returns:
        The decoded reply payload (a dict).

    Raises:
        asyncio.TimeoutError: No matching reply arrived within *timeout*.
    """
    response_future = asyncio.get_running_loop().create_future()

    async def on_response(msg):
        try:
            data = json.loads(msg.data.decode())
        except (UnicodeDecodeError, json.JSONDecodeError):
            # Not a JSON reply; the response subject may carry other traffic.
            return
        if (
            isinstance(data, dict)
            and data.get("correlation_id") == correlation_id
            and not response_future.done()
        ):
            response_future.set_result(data)

    # Subscribe before publishing so a fast reply cannot slip past us.
    sub = await nc.subscribe(RESPONSE_SUBJECT, cb=on_response)
    try:
        await nc.publish(
            REQUEST_SUBJECT,
            json.dumps({"text": text, "correlation_id": correlation_id, "context": []}).encode(),
        )
        return await asyncio.wait_for(response_future, timeout=timeout)
    finally:
        await sub.unsubscribe()


def _required_passes(total, percent=PASS_PERCENT):
    """Return the smallest number of passes that reaches *percent* of *total*."""
    # Integer ceiling division: exact at the boundary, unlike float math
    # (e.g. 90% of 25 needs 23 passes, not int(22.5) == 22).
    return -(-total * percent // 100)


def _print_summary(results, correct, total):
    """Print accuracy, the acceptance verdict and latency stats; return the verdict.

    *results* holds ``(text, expected, actual, latency_ms_or_None, status)``
    tuples; entries whose latency is ``None`` (timeouts, missing field) are
    left out of the latency statistics.
    """
    required = _required_passes(total)
    accuracy = correct / total * 100 if total else 0.0
    passed = total > 0 and correct >= required

    print(f"\n{'='*60}")
    print(f"Results: {correct}/{total} correct ({accuracy:.1f}%)")
    print(f"Target: >={PASS_PERCENT}% ({required}/{total})")
    print(f"{'PASS' if passed else 'FAIL'}: {'meets' if passed else 'does not meet'} acceptance criteria")

    # Latency report — a genuine 0 ms reply still counts; only unknowns are skipped.
    latencies = [r[3] for r in results if r[3] is not None]
    if latencies:
        print(f"\nLatency: min={min(latencies)}ms, max={max(latencies)}ms, avg={sum(latencies)/len(latencies):.0f}ms")

    return passed


async def run_tests(nats_url=None):
    """Send every entry of ``TEST_CASES`` to the coordinator and score the tiers.

    Args:
        nats_url: NATS server to connect to; defaults to ``NATS_URL``.

    Returns:
        True when at least ``PASS_PERCENT`` of the cases were classified into
        their expected tier, False otherwise.
    """
    nc = await nats.connect(nats_url or NATS_URL)
    results = []
    correct = 0
    total = len(TEST_CASES)

    try:
        for index, (text, expected_tier) in enumerate(TEST_CASES):
            # The index keeps ids unique even when time.time() repeats.
            correlation_id = f"test-{time.time()}-{index}"
            timeout = COMPLEX_TIMEOUT_S if expected_tier == "complex" else DEFAULT_TIMEOUT_S

            try:
                response = await _request_classification(nc, text, correlation_id, timeout)
            except asyncio.TimeoutError:
                results.append((text, expected_tier, "TIMEOUT", None, "FAIL"))
                print(f" [FAIL] \"{text[:50]}\" → TIMEOUT")
                continue

            actual_tier = response.get("tier", "unknown")
            latency = response.get("latency_ms")
            match = actual_tier == expected_tier
            if match:
                correct += 1
            status = "PASS" if match else "FAIL"
            results.append((text, expected_tier, actual_tier, latency, status))
            latency_text = f"{latency}ms" if latency is not None else "n/a"
            print(f" [{status}] \"{text[:50]}\" → expected={expected_tier}, got={actual_tier} ({latency_text})")

        return _print_summary(results, correct, total)
    finally:
        # Always release the connection, even if a request raised mid-run.
        await nc.close()


def _parse_args(argv=None):
    """Parse command-line options (currently only ``--nats-url``)."""
    parser = argparse.ArgumentParser(
        description="Run the classification test suite against a live coordinator."
    )
    parser.add_argument(
        "--nats-url",
        default=NATS_URL,
        help="NATS server URL (default: $NATS_URL or nats://127.0.0.1:4222)",
    )
    return parser.parse_args(argv)


if __name__ == "__main__":
    args = _parse_args()
    success = asyncio.run(run_tests(args.nats_url))
    sys.exit(0 if success else 1)