/ tests / agent / test_context_engine.py
test_context_engine.py
  1  """Tests for the ContextEngine ABC and plugin slot."""
  2  
  3  import json
  4  import pytest
  5  from typing import Any, Dict, List
  6  
  7  from agent.context_engine import ContextEngine
  8  from agent.context_compressor import ContextCompressor
  9  
 10  
 11  # ---------------------------------------------------------------------------
 12  # A minimal concrete engine for testing the ABC
 13  # ---------------------------------------------------------------------------
 14  
 15  class StubEngine(ContextEngine):
 16      """Minimal engine that satisfies the ABC without doing real work."""
 17  
 18      def __init__(self, context_length=200000, threshold_pct=0.50):
 19          self.context_length = context_length
 20          self.threshold_tokens = int(context_length * threshold_pct)
 21          self._compress_called = False
 22          self._tools_called = []
 23  
 24      @property
 25      def name(self) -> str:
 26          return "stub"
 27  
 28      def update_from_response(self, usage: Dict[str, Any]) -> None:
 29          self.last_prompt_tokens = usage.get("prompt_tokens", 0)
 30          self.last_completion_tokens = usage.get("completion_tokens", 0)
 31          self.last_total_tokens = usage.get("total_tokens", 0)
 32  
 33      def should_compress(self, prompt_tokens: int = None) -> bool:
 34          tokens = prompt_tokens if prompt_tokens is not None else self.last_prompt_tokens
 35          return tokens >= self.threshold_tokens
 36  
 37      def compress(self, messages: List[Dict[str, Any]], current_tokens: int = None) -> List[Dict[str, Any]]:
 38          self._compress_called = True
 39          self.compression_count += 1
 40          # Trivial: just return as-is
 41          return messages
 42  
 43      def get_tool_schemas(self) -> List[Dict[str, Any]]:
 44          return [
 45              {
 46                  "name": "stub_search",
 47                  "description": "Search the stub engine",
 48                  "parameters": {"type": "object", "properties": {}},
 49              }
 50          ]
 51  
 52      def handle_tool_call(self, name: str, args: Dict[str, Any]) -> str:
 53          self._tools_called.append(name)
 54          return json.dumps({"ok": True, "tool": name})
 55  
 56  
 57  # ---------------------------------------------------------------------------
 58  # ABC contract tests
 59  # ---------------------------------------------------------------------------
 60  
 61  class TestContextEngineABC:
 62      """Verify the ABC enforces the required interface."""
 63  
 64      def test_cannot_instantiate_abc_directly(self):
 65          with pytest.raises(TypeError):
 66              ContextEngine()
 67  
 68      def test_missing_methods_raises(self):
 69          """A subclass missing required methods cannot be instantiated."""
 70          class Incomplete(ContextEngine):
 71              @property
 72              def name(self):
 73                  return "incomplete"
 74          with pytest.raises(TypeError):
 75              Incomplete()
 76  
 77      def test_stub_engine_satisfies_abc(self):
 78          engine = StubEngine()
 79          assert isinstance(engine, ContextEngine)
 80          assert engine.name == "stub"
 81  
 82      def test_compressor_is_context_engine(self):
 83          c = ContextCompressor(model="test", quiet_mode=True, config_context_length=200000)
 84          assert isinstance(c, ContextEngine)
 85          assert c.name == "compressor"
 86  
 87  
 88  # ---------------------------------------------------------------------------
 89  # Default method behavior
 90  # ---------------------------------------------------------------------------
 91  
 92  class TestDefaults:
 93      """Verify ABC default implementations work correctly."""
 94  
 95      def test_default_tool_schemas_empty(self):
 96          engine = StubEngine()
 97          # StubEngine overrides this, so test the base via super
 98          assert ContextEngine.get_tool_schemas(engine) == []
 99  
100      def test_default_handle_tool_call_returns_error(self):
101          engine = StubEngine()
102          result = ContextEngine.handle_tool_call(engine, "unknown", {})
103          data = json.loads(result)
104          assert "error" in data
105  
106      def test_default_get_status(self):
107          engine = StubEngine()
108          engine.last_prompt_tokens = 50000
109          status = engine.get_status()
110          assert status["last_prompt_tokens"] == 50000
111          assert status["context_length"] == 200000
112          assert status["threshold_tokens"] == 100000
113          assert 0 < status["usage_percent"] <= 100
114  
115      def test_on_session_reset(self):
116          engine = StubEngine()
117          engine.last_prompt_tokens = 999
118          engine.compression_count = 3
119          engine.on_session_reset()
120          assert engine.last_prompt_tokens == 0
121          assert engine.compression_count == 0
122  
123      def test_should_compress_preflight_default_false(self):
124          engine = StubEngine()
125          assert engine.should_compress_preflight([]) is False
126  
127  
128  # ---------------------------------------------------------------------------
129  # StubEngine behavior
130  # ---------------------------------------------------------------------------
131  
class TestStubEngine:
    """Behavioural checks for the StubEngine test double itself."""

    def test_should_compress(self):
        """Explicit token counts are compared against the 50% threshold."""
        engine = StubEngine(context_length=100000, threshold_pct=0.50)
        assert not engine.should_compress(40000)
        # The comparison is inclusive: exactly at the threshold compresses.
        assert engine.should_compress(50000)
        assert engine.should_compress(60000)

    def test_should_compress_falls_back_to_last_prompt_tokens(self):
        """With no argument, the last recorded prompt token count is used."""
        engine = StubEngine(context_length=100000, threshold_pct=0.50)
        engine.update_from_response({"prompt_tokens": 40000})
        assert not engine.should_compress()
        engine.update_from_response({"prompt_tokens": 50000})
        assert engine.should_compress()

    def test_compress_tracks_count(self):
        """``compress`` returns the messages untouched and bumps the counter."""
        engine = StubEngine()
        msgs = [{"role": "user", "content": "hello"}]
        result = engine.compress(msgs)
        assert result == msgs
        assert engine._compress_called
        assert engine.compression_count == 1

    def test_tool_schemas(self):
        """Exactly one schema, ``stub_search``, is advertised."""
        engine = StubEngine()
        schemas = engine.get_tool_schemas()
        assert len(schemas) == 1
        assert schemas[0]["name"] == "stub_search"

    def test_handle_tool_call(self):
        """Tool calls are recorded and answered with an ``ok`` JSON payload."""
        engine = StubEngine()
        result = engine.handle_tool_call("stub_search", {})
        assert json.loads(result)["ok"] is True
        assert "stub_search" in engine._tools_called

    def test_update_from_response(self):
        """All three usage counters are copied from the usage dict."""
        engine = StubEngine()
        engine.update_from_response({"prompt_tokens": 1000, "completion_tokens": 200, "total_tokens": 1200})
        assert engine.last_prompt_tokens == 1000
        assert engine.last_completion_tokens == 200
        # Previously unchecked even though the usage dict supplies it.
        assert engine.last_total_tokens == 1200
165  
166  
167  # ---------------------------------------------------------------------------
168  # ContextCompressor session reset via ABC
169  # ---------------------------------------------------------------------------
170  
class TestCompressorSessionReset:
    """Check the state ContextCompressor.on_session_reset() is expected to clear."""

    def test_reset_clears_state(self):
        """Dirty several fields, reset, and confirm each one is back to its cleared value."""
        compressor = ContextCompressor(
            model="test",
            quiet_mode=True,
            config_context_length=200000,
        )

        # Dirty the state that the reset is expected to clear.
        compressor.last_prompt_tokens = 50000
        compressor.compression_count = 3
        compressor._previous_summary = "some old summary"
        compressor._context_probed = True
        compressor._context_probe_persistable = True

        compressor.on_session_reset()

        # Token and compression counters return to zero.
        zeroed_counters = (
            "last_prompt_tokens",
            "last_completion_tokens",
            "last_total_tokens",
            "compression_count",
        )
        for counter in zeroed_counters:
            assert getattr(compressor, counter) == 0

        # Probe flags are switched off and the summary is dropped.
        assert compressor._context_probed is False
        assert compressor._context_probe_persistable is False
        assert compressor._previous_summary is None
190          assert c._previous_summary is None
191  
192  
193  # ---------------------------------------------------------------------------
194  # Plugin slot (PluginManager integration)
195  # ---------------------------------------------------------------------------
196  
class TestPluginContextEngineSlot:
    """Test register_context_engine on PluginContext and the module-level getter."""

    @staticmethod
    def _make_context(plugin_name):
        """Return a fresh ``(PluginManager, PluginContext)`` pair for *plugin_name*.

        The import is kept local, as in the individual tests before, so the
        rest of this module does not import ``hermes_cli.plugins`` at load time.
        """
        from hermes_cli.plugins import PluginManager, PluginContext, PluginManifest

        mgr = PluginManager()
        return mgr, PluginContext(PluginManifest(name=plugin_name), mgr)

    def test_register_engine(self):
        """A valid engine is stored on the manager."""
        mgr, ctx = self._make_context("test-lcm")

        engine = StubEngine()
        ctx.register_context_engine(engine)

        assert mgr._context_engine is engine
        assert mgr._context_engine.name == "stub"

    def test_reject_second_engine(self):
        """Once an engine is registered, a second registration is ignored."""
        mgr, ctx = self._make_context("test-lcm")

        engine1 = StubEngine()
        engine2 = StubEngine()
        ctx.register_context_engine(engine1)
        ctx.register_context_engine(engine2)  # should be rejected

        assert mgr._context_engine is engine1

    def test_reject_non_engine(self):
        """Objects that are not ContextEngine instances are not registered."""
        mgr, ctx = self._make_context("test-bad")

        ctx.register_context_engine("not an engine")
        assert mgr._context_engine is None

    def test_get_plugin_context_engine(self, monkeypatch):
        """The module-level getter reflects the active manager's engine."""
        import hermes_cli.plugins as plugins_mod

        # Inject a test manager; monkeypatch restores the previous module
        # global after the test, even when an assertion fails.
        mgr = plugins_mod.PluginManager()
        monkeypatch.setattr(plugins_mod, "_plugin_manager", mgr)

        assert plugins_mod.get_plugin_context_engine() is None

        engine = StubEngine()
        mgr._context_engine = engine
        assert plugins_mod.get_plugin_context_engine() is engine