/ tests / agent / test_compress_focus.py
test_compress_focus.py
  1  """Tests for focus_topic flowing through the compressor.
  2  
  3  Verifies that _generate_summary and compress accept and use the focus_topic
  4  parameter correctly.  Inspired by Claude Code's /compact <focus>.
  5  """
  6  
  7  from unittest.mock import MagicMock, patch
  8  
  9  from agent.context_compressor import ContextCompressor
 10  
 11  
 12  def _make_compressor():
 13      """Create a ContextCompressor with minimal state for testing."""
 14      compressor = ContextCompressor.__new__(ContextCompressor)
 15      compressor.protect_first_n = 2
 16      compressor.protect_last_n = 5
 17      compressor.tail_token_budget = 20000
 18      compressor.context_length = 200000
 19      compressor.threshold_percent = 0.80
 20      compressor.threshold_tokens = 160000
 21      compressor.max_summary_tokens = 10000
 22      compressor.quiet_mode = True
 23      compressor.compression_count = 0
 24      compressor.last_prompt_tokens = 0
 25      compressor._previous_summary = None
 26      compressor._summary_failure_cooldown_until = 0.0
 27      compressor.summary_model = None
 28      compressor.model = "test-model"
 29      compressor.provider = "test"
 30      compressor.base_url = "http://localhost"
 31      compressor.api_key = "test-key"
 32      compressor.api_mode = "chat_completions"
 33      return compressor
 34  
 35  
 36  def test_focus_topic_injected_into_summary_prompt():
 37      """When focus_topic is provided, the LLM prompt includes focus guidance."""
 38      compressor = _make_compressor()
 39      turns = [
 40          {"role": "user", "content": "Tell me about the database schema"},
 41          {"role": "assistant", "content": "The schema has tables: users, orders, products."},
 42      ]
 43  
 44      captured_prompt = {}
 45  
 46      def mock_call_llm(**kwargs):
 47          captured_prompt["messages"] = kwargs["messages"]
 48          resp = MagicMock()
 49          resp.choices = [MagicMock()]
 50          resp.choices[0].message.content = "## Goal\nUnderstand DB schema."
 51          return resp
 52  
 53      with patch("agent.context_compressor.call_llm", mock_call_llm):
 54          result = compressor._generate_summary(turns, focus_topic="database schema")
 55  
 56      assert result is not None
 57      prompt_text = captured_prompt["messages"][0]["content"]
 58      assert 'FOCUS TOPIC: "database schema"' in prompt_text
 59      assert "PRIORITISE" in prompt_text
 60      assert "60-70%" in prompt_text
 61  
 62  
 63  def test_no_focus_topic_no_injection():
 64      """Without focus_topic, the prompt doesn't contain focus guidance."""
 65      compressor = _make_compressor()
 66      turns = [
 67          {"role": "user", "content": "Hello"},
 68          {"role": "assistant", "content": "Hi"},
 69      ]
 70  
 71      captured_prompt = {}
 72  
 73      def mock_call_llm(**kwargs):
 74          captured_prompt["messages"] = kwargs["messages"]
 75          resp = MagicMock()
 76          resp.choices = [MagicMock()]
 77          resp.choices[0].message.content = "## Goal\nGreeting."
 78          return resp
 79  
 80      with patch("agent.context_compressor.call_llm", mock_call_llm):
 81          result = compressor._generate_summary(turns)
 82  
 83      prompt_text = captured_prompt["messages"][0]["content"]
 84      assert "FOCUS TOPIC" not in prompt_text
 85  
 86  
 87  def test_compress_passes_focus_to_generate_summary():
 88      """compress() passes focus_topic through to _generate_summary."""
 89      compressor = _make_compressor()
 90  
 91      # Track what _generate_summary receives
 92      received_kwargs = {}
 93      original_generate = compressor._generate_summary
 94  
 95      def tracking_generate(turns, **kwargs):
 96          received_kwargs.update(kwargs)
 97          return "## Goal\nTest."
 98  
 99      compressor._generate_summary = tracking_generate
100  
101      messages = [
102          {"role": "system", "content": "System prompt"},
103          {"role": "user", "content": "first"},
104          {"role": "assistant", "content": "reply1"},
105          {"role": "user", "content": "second"},
106          {"role": "assistant", "content": "reply2"},
107          {"role": "user", "content": "third"},
108          {"role": "assistant", "content": "reply3"},
109          {"role": "user", "content": "fourth"},
110          {"role": "assistant", "content": "reply4"},
111      ]
112  
113      compressor.compress(messages, current_tokens=100000, focus_topic="authentication flow")
114  
115      assert received_kwargs.get("focus_topic") == "authentication flow"
116  
117  
def test_compress_none_focus_by_default():
    """compress() passes None focus_topic by default."""
    compressor = _make_compressor()

    # Record the keyword arguments of every _generate_summary call. Without
    # proof that the stub was reached, the final check would pass vacuously:
    # an empty kwargs dict also gives .get("focus_topic") is None.
    calls = []

    def tracking_generate(turns, **kwargs):
        calls.append(kwargs)
        return "## Goal\nTest."

    # Shadow the bound method on this instance only; no LLM is contacted.
    compressor._generate_summary = tracking_generate

    messages = [
        {"role": "system", "content": "System prompt"},
        {"role": "user", "content": "first"},
        {"role": "assistant", "content": "reply1"},
        {"role": "user", "content": "second"},
        {"role": "assistant", "content": "reply2"},
        {"role": "user", "content": "third"},
        {"role": "assistant", "content": "reply3"},
        {"role": "user", "content": "fourth"},
        {"role": "assistant", "content": "reply4"},
    ]

    compressor.compress(messages, current_tokens=100000)

    assert calls, "compress() never called _generate_summary"
    assert all(call.get("focus_topic") is None for call in calls)
143  
144      assert received_kwargs.get("focus_topic") is None