# test_context_engine.py
"""Tests for the ContextEngine ABC and plugin slot."""

import json
from typing import Any, Dict, List, Optional

import pytest

from agent.context_engine import ContextEngine
from agent.context_compressor import ContextCompressor


# ---------------------------------------------------------------------------
# A minimal concrete engine for testing the ABC
# ---------------------------------------------------------------------------

class StubEngine(ContextEngine):
    """Minimal engine that satisfies the ABC without doing real work.

    It records which hooks were invoked (``_compress_called``,
    ``_tools_called``) so tests can assert on the calls.
    """

    def __init__(self, context_length=200000, threshold_pct=0.50):
        """Set the context window size and derive the compression threshold.

        Args:
            context_length: Size of the context window, in tokens.
            threshold_pct: Fraction of ``context_length`` at which
                ``should_compress`` starts returning True.
        """
        self.context_length = context_length
        # int() truncates, so the threshold is always a whole token count.
        self.threshold_tokens = int(context_length * threshold_pct)
        self._compress_called = False
        self._tools_called = []

    @property
    def name(self) -> str:
        """Identifier reported for this engine."""
        return "stub"

    def update_from_response(self, usage: Dict[str, Any]) -> None:
        """Copy token counts out of *usage*; absent keys are stored as 0."""
        self.last_prompt_tokens = usage.get("prompt_tokens", 0)
        self.last_completion_tokens = usage.get("completion_tokens", 0)
        self.last_total_tokens = usage.get("total_tokens", 0)

    def should_compress(self, prompt_tokens: Optional[int] = None) -> bool:
        """Return True once the token count reaches ``threshold_tokens``.

        Args:
            prompt_tokens: Count to test. When None, falls back to
                ``self.last_prompt_tokens``.
        """
        # NOTE(review): last_prompt_tokens is not assigned in __init__, so the
        # fallback relies on the base class supplying an initial value --
        # presumably a class-level default; confirm in agent.context_engine.
        tokens = prompt_tokens if prompt_tokens is not None else self.last_prompt_tokens
        return tokens >= self.threshold_tokens

    def compress(
        self,
        messages: List[Dict[str, Any]],
        current_tokens: Optional[int] = None,
    ) -> List[Dict[str, Any]]:
        """Record the call, bump ``compression_count`` and return *messages* as-is.

        ``current_tokens`` is accepted for interface compatibility and ignored.
        """
        self._compress_called = True
        # NOTE(review): compression_count is not assigned in __init__ either;
        # same reliance on a base-class initial value as last_prompt_tokens.
        self.compression_count += 1
        # Trivial: just return as-is
        return messages

    def get_tool_schemas(self) -> List[Dict[str, Any]]:
        """Return a single fixed schema describing the ``stub_search`` tool."""
        return [
            {
                "name": "stub_search",
                "description": "Search the stub engine",
                "parameters": {"type": "object", "properties": {}},
            }
        ]

    def handle_tool_call(self, name: str, args: Dict[str, Any]) -> str:
        """Record *name* and return a JSON string acknowledging the call.

        ``args`` is ignored by this stub.
        """
        self._tools_called.append(name)
        return json.dumps({"ok": True, "tool": name})


# ---------------------------------------------------------------------------
# ABC contract tests
# ---------------------------------------------------------------------------

class TestContextEngineABC:
    """Verify the ABC enforces the required interface."""

    def test_cannot_instantiate_abc_directly(self):
        """Instantiating the abstract base itself raises TypeError."""
        with pytest.raises(TypeError):
            ContextEngine()

    def test_missing_methods_raises(self):
        """A subclass missing required methods cannot be instantiated."""
        class Incomplete(ContextEngine):
            @property
            def name(self):
                return "incomplete"

        with pytest.raises(TypeError):
            Incomplete()

    def test_stub_engine_satisfies_abc(self):
        """StubEngine instantiates and reports its name."""
        engine = StubEngine()
        assert isinstance(engine, ContextEngine)
        assert engine.name == "stub"

    def test_compressor_is_context_engine(self):
        """ContextCompressor is a ContextEngine named "compressor"."""
        c = ContextCompressor(model="test", quiet_mode=True, config_context_length=200000)
        assert isinstance(c, ContextEngine)
        assert c.name == "compressor"


# ---------------------------------------------------------------------------
# Default method behavior
# ---------------------------------------------------------------------------

class TestDefaults:
    """Verify ABC default implementations work correctly."""

    def test_default_tool_schemas_empty(self):
        """The base implementation exposes no tool schemas."""
        engine = StubEngine()
        # StubEngine overrides this, so call the base implementation explicitly
        assert ContextEngine.get_tool_schemas(engine) == []

    def test_default_handle_tool_call_returns_error(self):
        """The base implementation answers with JSON containing an "error" key."""
        engine = StubEngine()
        # Called through the base class because StubEngine overrides it.
        result = ContextEngine.handle_tool_call(engine, "unknown", {})
        data = json.loads(result)
        assert "error" in data

    def test_default_get_status(self):
        """get_status() reports token counts, limits and a usage percentage."""
        engine = StubEngine()
        engine.last_prompt_tokens = 50000
        status = engine.get_status()
        assert status["last_prompt_tokens"] == 50000
        assert status["context_length"] == 200000
        assert status["threshold_tokens"] == 100000
        assert 0 < status["usage_percent"] <= 100

    def test_on_session_reset(self):
        """on_session_reset() zeroes the prompt-token and compression counters."""
        engine = StubEngine()
        engine.last_prompt_tokens = 999
        engine.compression_count = 3
        engine.on_session_reset()
        assert engine.last_prompt_tokens == 0
        assert engine.compression_count == 0

    def test_should_compress_preflight_default_false(self):
        """The default preflight check returns False for an empty message list."""
        engine = StubEngine()
        assert engine.should_compress_preflight([]) is False


# ---------------------------------------------------------------------------
# StubEngine behavior
# ---------------------------------------------------------------------------

class TestStubEngine:
    """Exercise the StubEngine overrides directly."""

    def test_should_compress(self):
        """Compression triggers at or above the threshold, not below it."""
        engine = StubEngine(context_length=100000, threshold_pct=0.50)
        assert not engine.should_compress(40000)
        assert engine.should_compress(50000)  # exactly at the threshold
        assert engine.should_compress(60000)

    def test_compress_tracks_count(self):
        """compress() returns the messages unchanged and counts the call."""
        engine = StubEngine()
        msgs = [{"role": "user", "content": "hello"}]
        result = engine.compress(msgs)
        assert result == msgs
        assert engine._compress_called
        assert engine.compression_count == 1

    def test_tool_schemas(self):
        """Exactly one schema, for ``stub_search``, is exposed."""
        engine = StubEngine()
        schemas = engine.get_tool_schemas()
        assert len(schemas) == 1
        assert schemas[0]["name"] == "stub_search"

    def test_handle_tool_call(self):
        """A tool call is acknowledged and its name recorded."""
        engine = StubEngine()
        result = engine.handle_tool_call("stub_search", {})
        assert json.loads(result)["ok"] is True
        assert "stub_search" in engine._tools_called

    def test_update_from_response(self):
        """All three token counters are copied from the usage dict."""
        engine = StubEngine()
        engine.update_from_response(
            {"prompt_tokens": 1000, "completion_tokens": 200, "total_tokens": 1200}
        )
        assert engine.last_prompt_tokens == 1000
        assert engine.last_completion_tokens == 200
        assert engine.last_total_tokens == 1200


# ---------------------------------------------------------------------------
# ContextCompressor session reset via ABC
# ---------------------------------------------------------------------------

class TestCompressorSessionReset:
    """Verify ContextCompressor.on_session_reset() clears all state."""

    def test_reset_clears_state(self):
        """Counters, probe flags and the previous summary are all reset."""
        c = ContextCompressor(model="test", quiet_mode=True, config_context_length=200000)
        c.last_prompt_tokens = 50000
        c.compression_count = 3
        c._previous_summary = "some old summary"
        c._context_probed = True
        c._context_probe_persistable = True

        c.on_session_reset()

        assert c.last_prompt_tokens == 0
        assert c.last_completion_tokens == 0
        assert c.last_total_tokens == 0
        assert c.compression_count == 0
        assert c._context_probed is False
        assert c._context_probe_persistable is False
        assert c._previous_summary is None


# ---------------------------------------------------------------------------
# Plugin slot (PluginManager integration)
# ---------------------------------------------------------------------------

class TestPluginContextEngineSlot:
    """Test register_context_engine on PluginContext."""

    def test_register_engine(self):
        """A registered engine is stored on the manager."""
        from hermes_cli.plugins import PluginManager, PluginContext, PluginManifest
        mgr = PluginManager()
        manifest = PluginManifest(name="test-lcm")
        ctx = PluginContext(manifest, mgr)

        engine = StubEngine()
        ctx.register_context_engine(engine)

        assert mgr._context_engine is engine
        assert mgr._context_engine.name == "stub"

    def test_reject_second_engine(self):
        """A second registration leaves the first engine in place."""
        from hermes_cli.plugins import PluginManager, PluginContext, PluginManifest
        mgr = PluginManager()
        manifest = PluginManifest(name="test-lcm")
        ctx = PluginContext(manifest, mgr)

        engine1 = StubEngine()
        engine2 = StubEngine()
        ctx.register_context_engine(engine1)
        # Should be rejected; the call is expected to return without raising.
        ctx.register_context_engine(engine2)

        assert mgr._context_engine is engine1

    def test_reject_non_engine(self):
        """Registering something that is not a ContextEngine stores nothing."""
        from hermes_cli.plugins import PluginManager, PluginContext, PluginManifest
        mgr = PluginManager()
        manifest = PluginManifest(name="test-bad")
        ctx = PluginContext(manifest, mgr)

        ctx.register_context_engine("not an engine")
        assert mgr._context_engine is None

    def test_get_plugin_context_engine(self):
        """get_plugin_context_engine() returns the module-level manager's engine."""
        from hermes_cli.plugins import PluginManager, get_plugin_context_engine
        import hermes_cli.plugins as plugins_mod

        # Inject a test manager; the original is restored even if an assert fails.
        old_mgr = plugins_mod._plugin_manager
        try:
            mgr = PluginManager()
            plugins_mod._plugin_manager = mgr

            assert get_plugin_context_engine() is None

            engine = StubEngine()
            mgr._context_engine = engine
            assert get_plugin_context_engine() is engine
        finally:
            plugins_mod._plugin_manager = old_mgr