test_compress_focus.py
"""Tests for focus_topic flowing through the compressor.

Verifies that _generate_summary and compress accept and use the focus_topic
parameter correctly. Inspired by Claude Code's /compact <focus>.
"""

from unittest.mock import MagicMock, patch

from agent.context_compressor import ContextCompressor


def _make_compressor():
    """Create a ContextCompressor with minimal state for testing.

    ``__init__`` is bypassed via ``__new__`` so no constructor side effects
    run; only the attributes assigned below exist on the returned instance.
    """
    compressor = ContextCompressor.__new__(ContextCompressor)
    compressor.protect_first_n = 2
    compressor.protect_last_n = 5
    compressor.tail_token_budget = 20000
    compressor.context_length = 200000
    compressor.threshold_percent = 0.80
    compressor.threshold_tokens = 160000
    compressor.max_summary_tokens = 10000
    compressor.quiet_mode = True
    compressor.compression_count = 0
    compressor.last_prompt_tokens = 0
    compressor._previous_summary = None
    compressor._summary_failure_cooldown_until = 0.0
    compressor.summary_model = None
    compressor.model = "test-model"
    compressor.provider = "test"
    compressor.base_url = "http://localhost"
    compressor.api_key = "test-key"
    compressor.api_mode = "chat_completions"
    return compressor


def _make_capturing_llm(captured, content):
    """Return a fake ``call_llm`` that records its prompt and replies with *content*.

    The fake stores the ``messages`` keyword argument in
    ``captured["messages"]`` and returns a MagicMock whose
    ``choices[0].message.content`` is *content*.
    """

    def fake_call_llm(**kwargs):
        captured["messages"] = kwargs["messages"]
        resp = MagicMock()
        resp.choices = [MagicMock()]
        resp.choices[0].message.content = content
        return resp

    return fake_call_llm


def _make_tracking_generate(calls):
    """Return a stand-in for ``_generate_summary`` that logs each call's kwargs.

    Each invocation appends its keyword arguments to *calls*, so tests can
    assert both that the summariser was reached and what it received.
    """

    def tracking_generate(turns, **kwargs):
        calls.append(kwargs)
        return "## Goal\nTest."

    return tracking_generate


def _sample_messages():
    """Return the 9-message conversation shared by the compress() tests.

    Presumably long enough that, with protect_first_n=2 and protect_last_n=5,
    compress() has a middle section to summarise. The tests below assert that
    the summariser was actually invoked, so a change here surfaces as a
    failure rather than a silent pass.
    """
    return [
        {"role": "system", "content": "System prompt"},
        {"role": "user", "content": "first"},
        {"role": "assistant", "content": "reply1"},
        {"role": "user", "content": "second"},
        {"role": "assistant", "content": "reply2"},
        {"role": "user", "content": "third"},
        {"role": "assistant", "content": "reply3"},
        {"role": "user", "content": "fourth"},
        {"role": "assistant", "content": "reply4"},
    ]


def test_focus_topic_injected_into_summary_prompt():
    """When focus_topic is provided, the LLM prompt includes focus guidance."""
    compressor = _make_compressor()
    turns = [
        {"role": "user", "content": "Tell me about the database schema"},
        {"role": "assistant", "content": "The schema has tables: users, orders, products."},
    ]

    captured_prompt = {}
    fake_llm = _make_capturing_llm(captured_prompt, "## Goal\nUnderstand DB schema.")

    with patch("agent.context_compressor.call_llm", fake_llm):
        result = compressor._generate_summary(turns, focus_topic="database schema")

    assert result is not None
    prompt_text = captured_prompt["messages"][0]["content"]
    assert 'FOCUS TOPIC: "database schema"' in prompt_text
    assert "PRIORITISE" in prompt_text
    assert "60-70%" in prompt_text


def test_no_focus_topic_no_injection():
    """Without focus_topic, the prompt doesn't contain focus guidance."""
    compressor = _make_compressor()
    turns = [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi"},
    ]

    captured_prompt = {}
    fake_llm = _make_capturing_llm(captured_prompt, "## Goal\nGreeting.")

    with patch("agent.context_compressor.call_llm", fake_llm):
        result = compressor._generate_summary(turns)

    # Previously assigned but never checked; mirror the focus-topic test.
    assert result is not None
    prompt_text = captured_prompt["messages"][0]["content"]
    assert "FOCUS TOPIC" not in prompt_text


def test_compress_passes_focus_to_generate_summary():
    """compress() passes focus_topic through to _generate_summary."""
    compressor = _make_compressor()

    # Record every call so we can tell "not called" apart from "called
    # without the expected kwarg".
    calls = []
    compressor._generate_summary = _make_tracking_generate(calls)

    compressor.compress(
        _sample_messages(), current_tokens=100000, focus_topic="authentication flow"
    )

    assert calls, "compress() never invoked _generate_summary"
    assert calls[-1].get("focus_topic") == "authentication flow"


def test_compress_none_focus_by_default():
    """compress() passes None focus_topic by default."""
    compressor = _make_compressor()

    calls = []
    compressor._generate_summary = _make_tracking_generate(calls)

    compressor.compress(_sample_messages(), current_tokens=100000)

    # Without this guard, the `is None` check below would pass vacuously if
    # compress() skipped summarisation entirely.
    assert calls, "compress() never invoked _generate_summary"
    assert calls[-1].get("focus_topic") is None