/ tests / test_trajectory_compressor.py
test_trajectory_compressor.py
  1  """Tests for trajectory_compressor.py — config, metrics, and compression logic."""
  2  
  3  import importlib
  4  import json
  5  import os
  6  import sys
  7  from types import SimpleNamespace
  8  from unittest.mock import AsyncMock, patch, MagicMock
  9  
 10  import pytest
 11  
 12  from trajectory_compressor import (
 13      CompressionConfig,
 14      TrajectoryMetrics,
 15      AggregateMetrics,
 16      TrajectoryCompressor,
 17  )
 18  
 19  
 20  def test_import_loads_env_from_hermes_home(tmp_path, monkeypatch):
 21      home = tmp_path / ".hermes"
 22      home.mkdir()
 23      (home / ".env").write_text("OPENROUTER_API_KEY=from-hermes-home\n", encoding="utf-8")
 24  
 25      monkeypatch.setenv("HERMES_HOME", str(home))
 26      monkeypatch.delenv("OPENROUTER_API_KEY", raising=False)
 27  
 28      sys.modules.pop("trajectory_compressor", None)
 29      importlib.import_module("trajectory_compressor")
 30  
 31      assert os.getenv("OPENROUTER_API_KEY") == "from-hermes-home"
 32  
 33  
 34  def test_generate_summary_kimi_omits_temperature():
 35      """Kimi models should have temperature omitted — server manages it."""
 36      config = CompressionConfig(
 37          summarization_model="kimi-for-coding",
 38          temperature=0.3,
 39          summary_target_tokens=100,
 40          max_retries=1,
 41      )
 42      compressor = TrajectoryCompressor.__new__(TrajectoryCompressor)
 43      compressor.config = config
 44      compressor.logger = MagicMock()
 45      compressor._use_call_llm = False
 46      compressor.client = MagicMock()
 47      compressor.client.chat.completions.create.return_value = SimpleNamespace(
 48          choices=[SimpleNamespace(message=SimpleNamespace(content="[CONTEXT SUMMARY]: summary"))]
 49      )
 50  
 51      metrics = TrajectoryMetrics()
 52      result = compressor._generate_summary("tool output", metrics)
 53  
 54      assert result.startswith("[CONTEXT SUMMARY]:")
 55      assert "temperature" not in compressor.client.chat.completions.create.call_args.kwargs
 56  
 57  
 58  def test_generate_summary_public_moonshot_kimi_k2_5_omits_temperature():
 59      """kimi-k2.5 on the public Moonshot API should not get a forced temperature."""
 60      config = CompressionConfig(
 61          summarization_model="kimi-k2.5",
 62          base_url="https://api.moonshot.ai/v1",
 63          temperature=0.3,
 64          summary_target_tokens=100,
 65          max_retries=1,
 66      )
 67      compressor = TrajectoryCompressor.__new__(TrajectoryCompressor)
 68      compressor.config = config
 69      compressor.logger = MagicMock()
 70      compressor._use_call_llm = False
 71      compressor.client = MagicMock()
 72      compressor.client.chat.completions.create.return_value = SimpleNamespace(
 73          choices=[SimpleNamespace(message=SimpleNamespace(content="[CONTEXT SUMMARY]: summary"))]
 74      )
 75  
 76      metrics = TrajectoryMetrics()
 77      result = compressor._generate_summary("tool output", metrics)
 78  
 79      assert result.startswith("[CONTEXT SUMMARY]:")
 80      assert "temperature" not in compressor.client.chat.completions.create.call_args.kwargs
 81  
 82  
 83  def test_generate_summary_public_moonshot_cn_kimi_k2_5_omits_temperature():
 84      """kimi-k2.5 on api.moonshot.cn should not get a forced temperature."""
 85      config = CompressionConfig(
 86          summarization_model="kimi-k2.5",
 87          base_url="https://api.moonshot.cn/v1",
 88          temperature=0.3,
 89          summary_target_tokens=100,
 90          max_retries=1,
 91      )
 92      compressor = TrajectoryCompressor.__new__(TrajectoryCompressor)
 93      compressor.config = config
 94      compressor.logger = MagicMock()
 95      compressor._use_call_llm = False
 96      compressor.client = MagicMock()
 97      compressor.client.chat.completions.create.return_value = SimpleNamespace(
 98          choices=[SimpleNamespace(message=SimpleNamespace(content="[CONTEXT SUMMARY]: summary"))]
 99      )
100  
101      metrics = TrajectoryMetrics()
102      result = compressor._generate_summary("tool output", metrics)
103  
104      assert result.startswith("[CONTEXT SUMMARY]:")
105      assert "temperature" not in compressor.client.chat.completions.create.call_args.kwargs
106  
107  
108  # ---------------------------------------------------------------------------
109  # CompressionConfig
110  # ---------------------------------------------------------------------------
111  
112  
class TestCompressionConfig:
    """Default values of ``CompressionConfig`` and overrides read by ``from_yaml``."""

    def test_defaults(self):
        """A bare ``CompressionConfig()`` exposes the expected defaults."""
        cfg = CompressionConfig()
        expected_numbers = {
            "target_max_tokens": 15250,
            "summary_target_tokens": 750,
            "protect_last_n_turns": 4,
        }
        for field_name, expected in expected_numbers.items():
            assert getattr(cfg, field_name) == expected
        assert cfg.skip_under_target is True

    def test_from_yaml(self, tmp_path):
        """Values from every YAML section land on the matching config attribute."""
        yaml_content = """\
tokenizer:
  name: custom-tokenizer
  trust_remote_code: false
compression:
  target_max_tokens: 10000
  summary_target_tokens: 500
protected_turns:
  first_system: true
  first_human: false
  last_n_turns: 6
summarization:
  model: gpt-4
  temperature: 0.5
  max_retries: 5
output:
  add_summary_notice: false
  output_suffix: _short
processing:
  num_workers: 8
  max_concurrent_requests: 100
  skip_under_target: false
  save_over_limit: false
metrics:
  enabled: false
  per_trajectory: false
  output_file: my_metrics.json
"""
        path = tmp_path / "config.yaml"
        path.write_text(yaml_content)
        cfg = CompressionConfig.from_yaml(str(path))

        # Attributes compared by value.
        expected_values = {
            "tokenizer_name": "custom-tokenizer",
            "target_max_tokens": 10000,
            "summary_target_tokens": 500,
            "protect_last_n_turns": 6,
            "summarization_model": "gpt-4",
            "temperature": 0.5,
            "max_retries": 5,
            "output_suffix": "_short",
            "num_workers": 8,
            "max_concurrent_requests": 100,
            "metrics_output_file": "my_metrics.json",
        }
        for field_name, expected in expected_values.items():
            assert getattr(cfg, field_name) == expected

        # Flags that must be the ``False`` singleton, not merely falsy.
        switched_off = (
            "trust_remote_code",
            "protect_first_human",
            "add_summary_notice",
            "skip_under_target",
            "save_over_limit",
            "metrics_enabled",
        )
        for field_name in switched_off:
            assert getattr(cfg, field_name) is False

    def test_from_yaml_partial(self, tmp_path):
        """A file with one section overrides only that section's values."""
        path = tmp_path / "config.yaml"
        path.write_text("compression:\n  target_max_tokens: 8000\n")
        cfg = CompressionConfig.from_yaml(str(path))
        assert cfg.target_max_tokens == 8000
        # Settings from sections absent in the file stay at their defaults.
        assert cfg.protect_last_n_turns == 4
        assert cfg.num_workers == 4

    def test_from_yaml_empty(self, tmp_path):
        """An empty mapping yields the default ``target_max_tokens``."""
        path = tmp_path / "config.yaml"
        path.write_text("{}\n")
        cfg = CompressionConfig.from_yaml(str(path))
        assert cfg.target_max_tokens == 15250  # nothing overridden
186  
187  
188  # ---------------------------------------------------------------------------
189  # TrajectoryMetrics
190  # ---------------------------------------------------------------------------
191  
192  
class TestTrajectoryMetrics:
    """Serialisation of ``TrajectoryMetrics`` through ``to_dict``."""

    def test_to_dict(self):
        """Assigned fields appear in the dict; the region start stays at -1."""
        metrics = TrajectoryMetrics()
        assigned = {
            "original_tokens": 10000,
            "compressed_tokens": 5000,
            "tokens_saved": 5000,
            "compression_ratio": 0.5,
            "original_turns": 20,
            "compressed_turns": 10,
            "turns_removed": 10,
            "was_compressed": True,
        }
        for field_name, value in assigned.items():
            setattr(metrics, field_name, value)

        as_dict = metrics.to_dict()

        assert as_dict["original_tokens"] == 10000
        assert as_dict["compressed_tokens"] == 5000
        assert as_dict["compression_ratio"] == 0.5
        assert as_dict["was_compressed"] is True
        # The compression region was never assigned in this test.
        assert as_dict["compression_region"]["start_idx"] == -1

    def test_default_values(self):
        """A fresh instance reports zero tokens and both flags as ``False``."""
        as_dict = TrajectoryMetrics().to_dict()
        assert as_dict["original_tokens"] == 0
        assert as_dict["was_compressed"] is False
        assert as_dict["skipped_under_target"] is False
217  
218  
219  # ---------------------------------------------------------------------------
220  # AggregateMetrics
221  # ---------------------------------------------------------------------------
222  
223  
class TestAggregateMetrics:
    """Accumulation and reporting behaviour of ``AggregateMetrics``."""

    @staticmethod
    def _metrics(**fields):
        """Return a ``TrajectoryMetrics`` with *fields* assigned in the order given."""
        metrics = TrajectoryMetrics()
        for field_name, value in fields.items():
            setattr(metrics, field_name, value)
        return metrics

    def test_empty_to_dict(self):
        """With nothing added, the report shows zero trajectories and a ratio of 1.0."""
        report = AggregateMetrics().to_dict()
        assert report["summary"]["total_trajectories"] == 0
        assert report["averages"]["avg_compression_ratio"] == 1.0
        assert report["averages"]["avg_tokens_saved_per_compressed"] == 0

    def test_add_compressed_trajectory(self):
        """One compressed trajectory bumps the totals and records one ratio."""
        totals = AggregateMetrics()
        totals.add_trajectory_metrics(
            self._metrics(
                original_tokens=20000,
                compressed_tokens=10000,
                tokens_saved=10000,
                compression_ratio=0.5,
                original_turns=30,
                compressed_turns=15,
                turns_removed=15,
                was_compressed=True,
            )
        )
        assert totals.total_trajectories == 1
        assert totals.trajectories_compressed == 1
        assert totals.total_tokens_saved == 10000
        assert len(totals.compression_ratios) == 1

    def test_add_skipped_trajectory(self):
        """A skipped-under-target trajectory counts as skipped, not compressed."""
        totals = AggregateMetrics()
        totals.add_trajectory_metrics(
            self._metrics(
                original_tokens=5000,
                compressed_tokens=5000,
                skipped_under_target=True,
            )
        )
        assert totals.trajectories_skipped_under_target == 1
        assert totals.trajectories_compressed == 0

    def test_add_over_limit_trajectory(self):
        """A trajectory flagged ``still_over_limit`` is counted as such."""
        totals = AggregateMetrics()
        totals.add_trajectory_metrics(
            self._metrics(
                original_tokens=20000,
                compressed_tokens=16000,
                still_over_limit=True,
                was_compressed=True,
                compression_ratio=0.8,
            )
        )
        assert totals.trajectories_still_over_limit == 1

    def test_multiple_trajectories_aggregation(self):
        """Three identical compressed trajectories sum and average as expected."""
        totals = AggregateMetrics()
        for _ in range(3):
            totals.add_trajectory_metrics(
                self._metrics(
                    original_tokens=10000,
                    compressed_tokens=5000,
                    tokens_saved=5000,
                    turns_removed=5,
                    was_compressed=True,
                    compression_ratio=0.5,
                )
            )
        report = totals.to_dict()
        assert report["summary"]["total_trajectories"] == 3
        assert report["summary"]["trajectories_compressed"] == 3
        assert report["tokens"]["total_saved"] == 15000
        assert report["averages"]["avg_compression_ratio"] == 0.5

    def test_to_dict_no_division_by_zero(self):
        """``to_dict`` on an empty aggregate must not raise ZeroDivisionError."""
        report = AggregateMetrics().to_dict()
        assert report["summarization"]["success_rate"] == 1.0
        assert report["tokens"]["overall_compression_ratio"] == 0.0
293  
294  
295  # ---------------------------------------------------------------------------
296  # TrajectoryCompressor._find_protected_indices
297  # ---------------------------------------------------------------------------
298  
299  
def _make_compressor(config=None):
    """Build a ``TrajectoryCompressor`` whose tokenizer and summarizer setup is stubbed.

    Args:
        config: Optional ``CompressionConfig``; a default one is created when
            ``None`` is passed.

    Returns:
        The compressor, with ``tokenizer`` replaced by a mock whose ``encode``
        yields one token per four characters of input.
    """
    cfg = CompressionConfig() if config is None else config

    # Both initialisers are patched only while the constructor runs.
    with patch.object(TrajectoryCompressor, '_init_tokenizer'):
        with patch.object(TrajectoryCompressor, '_init_summarizer'):
            compressor = TrajectoryCompressor(cfg)

    def _encode(text):
        # Stand-in token counter for tests: 1 token per 4 chars.
        return [0] * (len(text) // 4)

    fake_tokenizer = MagicMock()
    fake_tokenizer.encode = _encode
    compressor.tokenizer = fake_tokenizer
    return compressor
311  
312  
class TestFindProtectedIndices:
    """Head/tail protection computed by ``_find_protected_indices``."""

    @staticmethod
    def _turns(*pairs):
        """Expand ``(role, text)`` pairs into ``{"from": ..., "value": ...}`` turns."""
        return [{"from": role, "value": text} for role, text in pairs]

    def test_basic_trajectory(self):
        """Default config protects the first four and last four of ten turns."""
        compressor = _make_compressor()
        turns = self._turns(
            ("system", "You are an agent."),
            ("human", "Do something."),
            ("gpt", "I will use a tool."),
            ("tool", "Tool result."),
            ("gpt", "More work."),
            ("tool", "Another result."),
            ("gpt", "Work continues."),
            ("tool", "Result 3."),
            ("gpt", "Done."),
            ("human", "Thanks."),
        )
        protected, start, end = compressor._find_protected_indices(turns)
        # Head: first system (0), human (1), gpt (2) and tool (3).
        for idx in (0, 1, 2, 3):
            assert idx in protected
        # Tail: the last 4 turns (6, 7, 8, 9).
        for idx in (6, 7, 8, 9):
            assert idx in protected
        # Whatever is compressible lies between head and tail.
        assert start >= 4
        assert end <= 6

    def test_short_trajectory_all_protected(self):
        """A three-turn trajectory leaves nothing to compress."""
        compressor = _make_compressor()
        turns = self._turns(
            ("system", "sys"),
            ("human", "hi"),
            ("gpt", "hello"),
        )
        protected, start, end = compressor._find_protected_indices(turns)
        # Every turn is covered (first of each role + last 4 spans all three).
        assert len(protected) == 3
        assert start >= end  # empty compressible region

    def test_protect_last_n_zero(self):
        """With ``protect_last_n_turns = 0`` the final turn is not protected."""
        cfg = CompressionConfig()
        cfg.protect_last_n_turns = 0
        compressor = _make_compressor(cfg)
        turns = self._turns(
            ("system", "sys"),
            ("human", "q"),
            ("gpt", "a"),
            ("tool", "r"),
            ("gpt", "b"),
            ("tool", "r2"),
            ("gpt", "c"),
            ("tool", "r3"),
        )
        protected, _, _ = compressor._find_protected_indices(turns)
        # Only the first occurrences are protected; there is no tail protection.
        for idx in (0, 1, 2, 3):
            assert idx in protected
        assert 7 not in protected

    def test_no_system_turn(self):
        """Without a system turn, index 0 (the first human turn) is protected."""
        compressor = _make_compressor()
        turns = self._turns(
            ("human", "hi"),
            ("gpt", "hello"),
            ("tool", "data"),
            ("gpt", "result"),
            ("human", "thanks"),
        )
        protected, _, _ = compressor._find_protected_indices(turns)
        assert 0 in protected  # first human

    def test_disable_protect_first_system(self):
        """Turning off ``protect_first_system`` leaves the system turn unprotected."""
        cfg = CompressionConfig()
        cfg.protect_first_system = False
        compressor = _make_compressor(cfg)
        turns = self._turns(
            ("system", "sys"),
            ("human", "q"),
            ("gpt", "a"),
            ("tool", "r"),
            ("gpt", "b"),
            ("tool", "r2"),
            ("gpt", "c"),
            ("tool", "r3"),
        )
        protected, _, _ = compressor._find_protected_indices(turns)
        assert 0 not in protected  # system not protected
405  
406  
407  # ---------------------------------------------------------------------------
408  # TrajectoryCompressor._extract_turn_content_for_summary
409  # ---------------------------------------------------------------------------
410  
411  
class TestExtractTurnContent:
    """Text assembled by ``_extract_turn_content_for_summary`` for a turn range."""

    def test_basic_extraction(self):
        """Turns 0 and 1 are rendered with headers; turn 2 is left out."""
        compressor = _make_compressor()
        turns = [
            {"from": role, "value": text}
            for role, text in (
                ("gpt", "I will search."),
                ("tool", "Search result: found it."),
                ("gpt", "Great, done."),
            )
        ]
        rendered = compressor._extract_turn_content_for_summary(turns, 0, 2)
        expected_fragments = (
            "[Turn 0 - GPT]",
            "I will search.",
            "[Turn 1 - TOOL]",
            "Search result: found it.",
        )
        for fragment in expected_fragments:
            assert fragment in rendered
        # The end index is exclusive, so turn 2 must be absent.
        assert "[Turn 2" not in rendered

    def test_long_content_truncated(self):
        """A 5000-character turn is shortened and marked as truncated."""
        compressor = _make_compressor()
        oversized = [{"from": "tool", "value": "x" * 5000}]
        rendered = compressor._extract_turn_content_for_summary(oversized, 0, 1)
        assert "...[truncated]..." in rendered
        assert len(rendered) < 5000

    def test_empty_range(self):
        """An empty ``[0, 0)`` range produces an empty string."""
        compressor = _make_compressor()
        single_turn = [{"from": "gpt", "value": "hello"}]
        rendered = compressor._extract_turn_content_for_summary(single_turn, 0, 0)
        assert rendered == ""
442  
443  
444  # ---------------------------------------------------------------------------
445  # TrajectoryCompressor.count_tokens / count_trajectory_tokens
446  # ---------------------------------------------------------------------------
447  
448  
class TestTokenCounting:
    """Token counting helpers, using the 1-token-per-4-chars mock tokenizer."""

    def test_count_tokens_empty(self):
        """The empty string counts as zero tokens."""
        compressor = _make_compressor()
        token_count = compressor.count_tokens("")
        assert token_count == 0

    def test_count_tokens_basic(self):
        """Eight characters count as two tokens with the mock tokenizer."""
        compressor = _make_compressor()
        token_count = compressor.count_tokens("12345678")
        assert token_count == 2  # 8 chars // 4

    def test_count_trajectory_tokens(self):
        """The trajectory total is 5 for turns of 2 and 3 tokens."""
        compressor = _make_compressor()
        eight_chars = "12345678"        # 2 tokens
        twelve_chars = "1234567890ab"   # 3 tokens
        turns = [
            {"from": "system", "value": eight_chars},
            {"from": "human", "value": twelve_chars},
        ]
        assert compressor.count_trajectory_tokens(turns) == 5

    def test_count_turn_tokens(self):
        """Per-turn counts come back as a list in turn order."""
        compressor = _make_compressor()
        four_chars = "1234"        # 1 token
        eight_chars = "12345678"   # 2 tokens
        turns = [
            {"from": "system", "value": four_chars},
            {"from": "human", "value": eight_chars},
        ]
        per_turn = compressor.count_turn_tokens(turns)
        assert per_turn == [1, 2]

    def test_count_tokens_fallback_on_error(self):
        """If ``encode`` raises, ``count_tokens`` still returns 2 for eight chars."""
        compressor = _make_compressor()
        failing_encode = MagicMock(side_effect=Exception("fail"))
        compressor.tokenizer.encode = failing_encode
        # Expected to fall back to len(text) // 4.
        assert compressor.count_tokens("12345678") == 2
481  
482  
class TestGenerateSummary:
    """Summary generation when the model response carries ``content=None``."""

    @staticmethod
    def _none_content_response():
        """Return a chat-completion-shaped object whose message content is ``None``."""
        empty_message = SimpleNamespace(content=None)
        return SimpleNamespace(choices=[SimpleNamespace(message=empty_message)])

    def test_generate_summary_handles_none_content(self):
        """The sync path returns just the summary prefix for ``None`` content."""
        compressor = _make_compressor()
        sync_client = MagicMock()
        sync_client.chat.completions.create.return_value = self._none_content_response()
        compressor.client = sync_client

        produced = compressor._generate_summary("Turn content", TrajectoryMetrics())

        assert produced == "[CONTEXT SUMMARY]:"

    @pytest.mark.asyncio
    async def test_generate_summary_async_handles_none_content(self):
        """The async path returns just the summary prefix for ``None`` content."""
        compressor = _make_compressor()
        async_client = MagicMock()
        async_client.chat.completions.create = AsyncMock(
            return_value=self._none_content_response()
        )
        # The compressor obtains its async client through this accessor.
        compressor._get_async_client = MagicMock(return_value=async_client)

        produced = await compressor._generate_summary_async(
            "Turn content", TrajectoryMetrics()
        )

        assert produced == "[CONTEXT SUMMARY]:"