/ tests / test_agent2_compression.py
test_agent2_compression.py
  1  """Tests for restai.agent2.compression — token counting, turn boundaries, splitting, truncation."""
  2  import asyncio
  3  
  4  from restai.agent2.compression import (
  5      _approx_char_count,
  6      compress_session,
  7      count_session_tokens,
  8      find_user_turn_boundaries,
  9      hard_truncate,
 10      split_for_compression,
 11  )
 12  from restai.agent2.types import (
 13      AgentSession,
 14      Message,
 15      TextBlock,
 16      ToolResultBlock,
 17      ToolUseBlock,
 18      user_text_message,
 19  )
 20  
 21  
 22  def _user(text: str) -> Message:
 23      return user_text_message(text)
 24  
 25  
 26  def _assistant(text: str) -> Message:
 27      return Message(role="assistant", content=[TextBlock(text=text)])
 28  
 29  
 30  def _assistant_tool_use() -> Message:
 31      return Message(
 32          role="assistant",
 33          content=[ToolUseBlock(id="t1", name="search", input={"q": "test"})],
 34      )
 35  
 36  
 37  def _user_tool_result() -> Message:
 38      return Message(
 39          role="user",
 40          content=[ToolResultBlock(tool_use_id="t1", content="result")],
 41      )
 42  
 43  
 44  # ---------- count_session_tokens ----------
 45  
 46  
 47  def test_count_session_tokens_nonempty():
 48      msgs = [_user("Hello, how are you?"), _assistant("I am fine, thank you.")]
 49      tokens = count_session_tokens(msgs)
 50      assert tokens > 0
 51  
 52  
 53  def test_count_session_tokens_empty():
 54      assert count_session_tokens([]) == 0
 55  
 56  
 57  # ---------- find_user_turn_boundaries ----------
 58  
 59  
 60  def test_find_user_turn_boundaries_basic():
 61      msgs = [
 62          _user("q1"),           # 0 - turn boundary
 63          _assistant("a1"),      # 1
 64          _user("q2"),           # 2 - turn boundary
 65          _assistant("a2"),      # 3
 66          _user("q3"),           # 4 - turn boundary
 67      ]
 68      boundaries = find_user_turn_boundaries(msgs)
 69      assert boundaries == [0, 2, 4]
 70  
 71  
 72  def test_find_user_turn_boundaries_skips_tool_result_user():
 73      msgs = [
 74          _user("q1"),           # 0 - turn boundary
 75          _assistant_tool_use(), # 1
 76          _user_tool_result(),   # 2 - NOT a turn boundary (contains ToolResultBlock)
 77          _assistant("a1"),      # 3
 78          _user("q2"),           # 4 - turn boundary
 79      ]
 80      boundaries = find_user_turn_boundaries(msgs)
 81      assert boundaries == [0, 4]
 82  
 83  
 84  def test_find_user_turn_boundaries_empty():
 85      assert find_user_turn_boundaries([]) == []
 86  
 87  
 88  # ---------- split_for_compression ----------
 89  
 90  
 91  def test_split_for_compression_enough_turns():
 92      msgs = [
 93          _user("q1"), _assistant("a1"),
 94          _user("q2"), _assistant("a2"),
 95          _user("q3"), _assistant("a3"),
 96          _user("q4"), _assistant("a4"),
 97          _user("q5"), _assistant("a5"),
 98      ]
 99      to_compress, to_keep = split_for_compression(msgs, keep_n_turns=3)
100      assert len(to_compress) > 0
101      assert len(to_keep) > 0
102      # The kept slice should start at the 3rd-from-last turn boundary
103      assert to_keep[0].role == "user"
104      assert to_compress + to_keep == msgs
105  
106  
def test_split_for_compression_not_enough_turns():
    """One turn is fewer than keep_n_turns, so nothing is compressed and all is kept."""
    conversation = [_user("q1"), _assistant("a1")]
    compressible, kept = split_for_compression(conversation, keep_n_turns=3)
    assert compressible == []
    assert kept == conversation
112  
113  
def test_split_for_compression_exact_keep_n():
    """Exactly keep_n_turns turns leaves nothing to compress.

    Compression needs strictly more turns than keep_n_turns.
    """
    conversation = []
    for turn in range(1, 4):
        conversation.extend([_user(f"q{turn}"), _assistant(f"a{turn}")])
    compressible, kept = split_for_compression(conversation, keep_n_turns=3)
    assert compressible == []
    assert kept == conversation
124  
125  
126  # ---------- hard_truncate ----------
127  
128  
def test_hard_truncate_drops_oldest():
    """hard_truncate drops messages until the session fits the token budget.

    The two oldest turns are large (1000-character texts). The budget covers
    only the short final turn plus some slack, so something must be dropped.
    """
    msgs = [
        _user("A" * 1000), _assistant("B" * 1000),
        _user("C" * 1000), _assistant("D" * 1000),
        _user("short"), _assistant("reply"),
    ]
    total_before = count_session_tokens(msgs)
    # Budget: enough for the last turn, but well below the total.
    budget = count_session_tokens([_user("short"), _assistant("reply")]) + 50
    # Check the premise explicitly, so that a too-generous budget fails here
    # with a clear message rather than on the length assertion below.
    assert budget < total_before, (
        f"budget {budget} must be below the untruncated total {total_before}"
    )
    truncated = hard_truncate(msgs, budget, keep_n_turns=1)
    assert len(truncated) < len(msgs)
    total_after = count_session_tokens(truncated)
    assert total_after <= budget
142  
143  
def test_hard_truncate_empty():
    """Truncating an empty session yields an empty session."""
    result = hard_truncate([], 100, 3)
    assert result == []
146  
147  
148  # ---------- _approx_char_count ----------
149  
150  
def test_approx_char_count_text():
    """The character approximation covers at least every character of the text blocks."""
    first, second = "hello", "world"
    approx = _approx_char_count([_user(first), _assistant(second)])
    assert approx >= len(first) + len(second)
155  
156  
def test_approx_char_count_tool_result():
    """Tool-result content is included in the character approximation."""
    payload = "result text"
    block = ToolResultBlock(tool_use_id="t1", content=payload)
    tool_msg = Message(role="user", content=[block])
    assert _approx_char_count([tool_msg]) >= len(payload)
164  
165  
166  # ---------- compress_session ----------
167  
168  
def test_compress_session_returns_false_no_context_window():
    """compress_session reports False when context_window is None."""
    session = AgentSession(messages=[_user("hello")])
    pending = compress_session(
        session,
        provider=None,
        config=None,
        context_window=None,
    )
    outcome = asyncio.run(pending)
    assert outcome is False
180  
181  
def test_compress_session_returns_false_zero_context_window():
    """compress_session reports False when context_window is 0."""
    session = AgentSession(messages=[_user("hello")])
    pending = compress_session(
        session,
        provider=None,
        config=None,
        context_window=0,
    )
    outcome = asyncio.run(pending)
    assert outcome is False
193  
194  
def test_compress_session_returns_false_under_budget():
    """compress_session reports False for a tiny session in a 100000-token window."""
    session = AgentSession(messages=[_user("hello"), _assistant("hi")])
    pending = compress_session(
        session,
        provider=None,
        config=None,
        context_window=100000,
    )
    outcome = asyncio.run(pending)
    assert outcome is False