/ services / coordinator / test_classification.py
test_classification.py
  1  """Classification test suite — 25 labeled requests spanning all tiers.
  2  
  3  Run against a live coordinator:
  4    python test_classification.py [--nats-url nats://127.0.0.1:4222]
  5  """
  6  
  7  import asyncio
  8  import json
  9  import os
 10  import sys
 11  import time
 12  
 13  import nats
 14  
 15  NATS_URL = os.getenv("NATS_URL", "nats://127.0.0.1:4222")
 16  
 17  # 25 test cases: 5 deterministic, 6 simple, 7 moderate, 7 complex
 18  TEST_CASES = [
 19      # ── Deterministic (5) — handled by regex, no LLM ──
 20      ("What time is it?", "deterministic"),
 21      ("What's the date today?", "deterministic"),
 22      ("Hey Bob, good morning", "deterministic"),
 23      ("Hello Bob", "deterministic"),
 24      ("Good evening Bob", "deterministic"),
 25  
 26      # ── Simple (6) — lightweight model, no tools ──
 27      ("What's the capital of France?", "simple"),
 28      ("How many ounces in a pound?", "simple"),
 29      ("Tell me a joke", "simple"),
 30      ("What's the meaning of life?", "simple"),
 31      ("How do you say hello in Spanish?", "simple"),
 32      ("What year did World War 2 end?", "simple"),
 33  
 34      # ── Moderate (7) — Qwen3-32B, no tools ──
 35      ("What's the weather like?", "moderate"),
 36      ("What's the weather forecast for tomorrow?", "moderate"),
 37      ("Explain how a transformer model works", "moderate"),
 38      ("Compare the pros and cons of solar panels versus wind energy", "moderate"),
 39      ("Summarize the key differences between TCP and UDP", "moderate"),
 40      ("What are the health benefits of intermittent fasting?", "moderate"),
 41      ("Write a haiku about the ocean", "moderate"),
 42  
 43      # ── Complex (7) — Qwen3-32B with tools ──
 44      ("Turn off the living room lights", "complex"),
 45      ("Run a health check on the system", "complex"),
 46      ("What's the temperature inside the house?", "complex"),
 47      ("Check the system status and tell me about GPU temperatures", "complex"),
 48      ("Is anyone home right now?", "complex"),
 49      ("Update the knowledge base", "complex"),
 50      ("Give me my morning briefing", "complex"),
 51  ]
 52  
 53  
 54  async def run_tests():
 55      nc = await nats.connect(NATS_URL)
 56      results = []
 57      correct = 0
 58      total = len(TEST_CASES)
 59  
 60      for text, expected_tier in TEST_CASES:
 61          correlation_id = f"test-{time.time()}"
 62          response_future = asyncio.get_event_loop().create_future()
 63  
 64          async def on_response(msg):
 65              try:
 66                  data = json.loads(msg.data.decode())
 67                  if data.get("correlation_id") == correlation_id and not response_future.done():
 68                      response_future.set_result(data)
 69              except Exception:
 70                  pass
 71  
 72          sub = await nc.subscribe("bob.coordinator.response", cb=on_response)
 73  
 74          await nc.publish(
 75              "bob.coordinator.request",
 76              json.dumps({"text": text, "correlation_id": correlation_id, "context": []}).encode(),
 77          )
 78  
 79          try:
 80              # Complex requests may dispatch agents — allow up to 90s
 81              timeout = 90.0 if expected_tier == "complex" else 15.0
 82              response = await asyncio.wait_for(response_future, timeout=timeout)
 83              actual_tier = response.get("tier", "unknown")
 84              latency = response.get("latency_ms", 0)
 85              match = actual_tier == expected_tier
 86              if match:
 87                  correct += 1
 88              status = "PASS" if match else "FAIL"
 89              results.append((text, expected_tier, actual_tier, latency, status))
 90              print(f"  [{status}] \"{text[:50]}\" → expected={expected_tier}, got={actual_tier} ({latency}ms)")
 91          except asyncio.TimeoutError:
 92              results.append((text, expected_tier, "TIMEOUT", 0, "FAIL"))
 93              print(f"  [FAIL] \"{text[:50]}\" → TIMEOUT")
 94          finally:
 95              await sub.unsubscribe()
 96  
 97      accuracy = correct / total * 100
 98      print(f"\n{'='*60}")
 99      print(f"Results: {correct}/{total} correct ({accuracy:.1f}%)")
100      print(f"Target: >90% ({int(total * 0.9)}/{total})")
101      print(f"{'PASS' if accuracy >= 90 else 'FAIL'}: {'meets' if accuracy >= 90 else 'does not meet'} acceptance criteria")
102  
103      # Latency report
104      latencies = [r[3] for r in results if r[3] > 0]
105      if latencies:
106          print(f"\nLatency: min={min(latencies)}ms, max={max(latencies)}ms, avg={sum(latencies)/len(latencies):.0f}ms")
107  
108      await nc.close()
109      return accuracy >= 90
110  
111  
if __name__ == "__main__":
    # Local import: argparse is only needed when run as a script.
    import argparse

    parser = argparse.ArgumentParser(
        description="Run the coordinator classification test suite against a live coordinator.",
    )
    parser.add_argument(
        "--nats-url",
        default=NATS_URL,
        help="NATS server URL (default: $NATS_URL or nats://127.0.0.1:4222)",
    )
    args = parser.parse_args()
    # Rebind the module-level constant; run_tests() reads it at call time.
    NATS_URL = args.nats_url
    success = asyncio.run(run_tests())
    # Non-zero exit status lets CI treat a missed accuracy target as failure.
    sys.exit(0 if success else 1)