"""Tests for restai.agent2.compression — token counting, turn boundaries, splitting, truncation."""
import asyncio

from restai.agent2.compression import (
    _approx_char_count,
    compress_session,
    count_session_tokens,
    find_user_turn_boundaries,
    hard_truncate,
    split_for_compression,
)
from restai.agent2.types import (
    AgentSession,
    Message,
    TextBlock,
    ToolResultBlock,
    ToolUseBlock,
    user_text_message,
)


def _user(text: str) -> Message:
    """Build a plain user text message via ``user_text_message``."""
    return user_text_message(text)


def _assistant(text: str) -> Message:
    """Build an assistant message holding a single ``TextBlock``."""
    return Message(role="assistant", content=[TextBlock(text=text)])


def _assistant_tool_use() -> Message:
    """Build an assistant message issuing one ``search`` tool call with id ``t1``."""
    return Message(
        role="assistant",
        content=[ToolUseBlock(id="t1", name="search", input={"q": "test"})],
    )


def _user_tool_result() -> Message:
    """Build the user-role message carrying the result of tool call ``t1``."""
    return Message(
        role="user",
        content=[ToolResultBlock(tool_use_id="t1", content="result")],
    )


def _compress(messages: list[Message], context_window: int | None) -> bool:
    """Run ``compress_session`` synchronously with no provider/config.

    Shared by the compress_session tests, which only exercise the early-exit
    paths that are reached before the provider or config would be needed.
    """
    session = AgentSession(messages=messages)
    return asyncio.run(
        compress_session(
            session,
            provider=None,
            config=None,
            context_window=context_window,
        )
    )


# ---------- count_session_tokens ----------


def test_count_session_tokens_nonempty():
    msgs = [_user("Hello, how are you?"), _assistant("I am fine, thank you.")]
    tokens = count_session_tokens(msgs)
    assert tokens > 0


def test_count_session_tokens_empty():
    assert count_session_tokens([]) == 0


# ---------- find_user_turn_boundaries ----------


def test_find_user_turn_boundaries_basic():
    msgs = [
        _user("q1"),       # 0 - turn boundary
        _assistant("a1"),  # 1
        _user("q2"),       # 2 - turn boundary
        _assistant("a2"),  # 3
        _user("q3"),       # 4 - turn boundary
    ]
    boundaries = find_user_turn_boundaries(msgs)
    assert boundaries == [0, 2, 4]


def test_find_user_turn_boundaries_skips_tool_result_user():
    msgs = [
        _user("q1"),             # 0 - turn boundary
        _assistant_tool_use(),   # 1
        _user_tool_result(),     # 2 - NOT a turn boundary (contains ToolResultBlock)
        _assistant("a1"),        # 3
        _user("q2"),             # 4 - turn boundary
    ]
    boundaries = find_user_turn_boundaries(msgs)
    assert boundaries == [0, 4]


def test_find_user_turn_boundaries_empty():
    assert find_user_turn_boundaries([]) == []


# ---------- split_for_compression ----------


def test_split_for_compression_enough_turns():
    msgs = [
        _user("q1"), _assistant("a1"),
        _user("q2"), _assistant("a2"),
        _user("q3"), _assistant("a3"),
        _user("q4"), _assistant("a4"),
        _user("q5"), _assistant("a5"),
    ]
    to_compress, to_keep = split_for_compression(msgs, keep_n_turns=3)
    assert len(to_compress) > 0
    assert len(to_keep) > 0
    # Turn boundaries are [0, 2, 4, 6, 8]; keeping 3 turns means the kept
    # slice starts at the 3rd-from-last boundary, index 4 ("q3").
    assert to_keep[0].role == "user"
    assert to_keep == msgs[4:]
    assert to_compress + to_keep == msgs


def test_split_for_compression_not_enough_turns():
    msgs = [_user("q1"), _assistant("a1")]
    to_compress, to_keep = split_for_compression(msgs, keep_n_turns=3)
    assert to_compress == []
    assert to_keep == msgs


def test_split_for_compression_exact_keep_n():
    msgs = [
        _user("q1"), _assistant("a1"),
        _user("q2"), _assistant("a2"),
        _user("q3"), _assistant("a3"),
    ]
    to_compress, to_keep = split_for_compression(msgs, keep_n_turns=3)
    # 3 turns, keep_n=3 → not enough to compress (need strictly more)
    assert to_compress == []
    assert to_keep == msgs


# ---------- hard_truncate ----------


def test_hard_truncate_drops_oldest():
    msgs = [
        _user("A" * 1000), _assistant("B" * 1000),
        _user("C" * 1000), _assistant("D" * 1000),
        _user("short"), _assistant("reply"),
    ]
    # Set a budget well below the total but enough for the last turn.
    budget = count_session_tokens([_user("short"), _assistant("reply")]) + 50
    truncated = hard_truncate(msgs, budget, keep_n_turns=1)
    assert len(truncated) < len(msgs)
    assert count_session_tokens(truncated) <= budget
    # The newest turn (keep_n_turns=1) must survive truncation.
    assert truncated[-1].role == "assistant"
    assert truncated[-1].content[0].text == "reply"


def test_hard_truncate_empty():
    assert hard_truncate([], 100, 3) == []


# ---------- _approx_char_count ----------


def test_approx_char_count_text():
    msgs = [_user("hello"), _assistant("world")]
    count = _approx_char_count(msgs)
    assert count >= len("hello") + len("world")


def test_approx_char_count_tool_result():
    msg = Message(
        role="user",
        content=[ToolResultBlock(tool_use_id="t1", content="result text")],
    )
    count = _approx_char_count([msg])
    assert count >= len("result text")


# ---------- compress_session ----------


def test_compress_session_returns_false_no_context_window():
    assert _compress([_user("hello")], context_window=None) is False


def test_compress_session_returns_false_zero_context_window():
    assert _compress([_user("hello")], context_window=0) is False


def test_compress_session_returns_false_under_budget():
    result = _compress([_user("hello"), _assistant("hi")], context_window=100000)
    assert result is False