/ tests / test_agent2_types.py
test_agent2_types.py
  1  """Tests for restai.agent2.types — block types, serialization, image MIME detection."""
  2  import base64
  3  
  4  from restai.agent2.types import (
  5      ImageBlock,
  6      Message,
  7      TextBlock,
  8      ToolResultBlock,
  9      ToolUseBlock,
 10      block_from_dict,
 11      block_to_dict,
 12      detect_image_mime,
 13      message_from_dict,
 14      message_to_dict,
 15      user_text_message,
 16  )
 17  
 18  
 19  # ---------- construction ----------
 20  
 21  
 22  def test_text_block_construction():
 23      b = TextBlock(text="hello")
 24      assert b.text == "hello"
 25  
 26  
 27  def test_tool_use_block_construction():
 28      b = ToolUseBlock(id="t1", name="search", input={"q": "test"})
 29      assert b.id == "t1"
 30      assert b.name == "search"
 31      assert b.input == {"q": "test"}
 32  
 33  
 34  def test_tool_result_block_construction():
 35      b = ToolResultBlock(tool_use_id="t1", content="result text")
 36      assert b.tool_use_id == "t1"
 37      assert b.content == "result text"
 38      assert b.is_error is False
 39  
 40  
 41  def test_tool_result_block_error():
 42      b = ToolResultBlock(tool_use_id="t1", content="oops", is_error=True)
 43      assert b.is_error is True
 44  
 45  
 46  def test_image_block_construction():
 47      b = ImageBlock(data="abc123", mime_type="image/png")
 48      assert b.data == "abc123"
 49      assert b.mime_type == "image/png"
 50  
 51  
 52  # ---------- Message.text_content ----------
 53  
 54  
 55  def test_message_text_content_joins_text_blocks():
 56      msg = Message(
 57          role="user",
 58          content=[TextBlock(text="hello"), TextBlock(text="world")],
 59      )
 60      assert msg.text_content() == "hello\nworld"
 61  
 62  
 63  def test_message_text_content_ignores_non_text_blocks():
 64      msg = Message(
 65          role="assistant",
 66          content=[
 67              TextBlock(text="start"),
 68              ToolUseBlock(id="t1", name="x", input={}),
 69              TextBlock(text="end"),
 70          ],
 71      )
 72      assert msg.text_content() == "start\nend"
 73  
 74  
 75  def test_message_text_content_empty():
 76      msg = Message(role="user", content=[])
 77      assert msg.text_content() == ""
 78  
 79  
 80  # ---------- block_to_dict / block_from_dict round-trip ----------
 81  
 82  
 83  def test_text_block_roundtrip():
 84      original = TextBlock(text="hello world")
 85      d = block_to_dict(original)
 86      assert d["type"] == "text"
 87      restored = block_from_dict(d)
 88      assert isinstance(restored, TextBlock)
 89      assert restored.text == original.text
 90  
 91  
 92  def test_tool_use_block_roundtrip():
 93      original = ToolUseBlock(id="abc", name="search", input={"q": "test"})
 94      d = block_to_dict(original)
 95      assert d["type"] == "tool_use"
 96      restored = block_from_dict(d)
 97      assert isinstance(restored, ToolUseBlock)
 98      assert restored.id == original.id
 99      assert restored.name == original.name
100      assert restored.input == original.input
101  
102  
def test_tool_result_block_roundtrip():
    """ToolResultBlock -> dict -> ToolResultBlock preserves every field.

    Both values of ``is_error`` are exercised. With only ``True`` covered, a
    serializer or deserializer that mishandles the default ``False`` path
    would go unnoticed.
    """
    for flag in (True, False):
        original = ToolResultBlock(tool_use_id="abc", content="done", is_error=flag)
        d = block_to_dict(original)
        assert d["type"] == "tool_result"
        restored = block_from_dict(d)
        assert isinstance(restored, ToolResultBlock)
        assert restored.tool_use_id == original.tool_use_id
        assert restored.content == original.content
        assert restored.is_error == original.is_error
112  
113  
def test_image_block_roundtrip():
    """ImageBlock -> dict -> ImageBlock preserves payload and MIME type."""
    source = ImageBlock(data="iVBOR", mime_type="image/png")
    serialized = block_to_dict(source)
    assert serialized["type"] == "image"

    rebuilt = block_from_dict(serialized)
    assert isinstance(rebuilt, ImageBlock)
    assert (rebuilt.data, rebuilt.mime_type) == (source.data, source.mime_type)
122  
123  
124  # ---------- message_to_dict / message_from_dict round-trip ----------
125  
126  
def test_message_roundtrip():
    """Message -> dict -> Message keeps role, block order, types and payloads.

    The tool-use block's ``id`` and ``input`` are checked as well as its name.
    A dropped id or payload would break pairing with a later tool result.
    """
    original = Message(
        role="assistant",
        content=[
            TextBlock(text="thinking..."),
            ToolUseBlock(id="t1", name="calc", input={"expr": "1+1"}),
        ],
    )
    d = message_to_dict(original)
    assert d["role"] == "assistant"
    assert len(d["content"]) == 2

    restored = message_from_dict(d)
    assert restored.role == original.role
    assert len(restored.content) == 2
    text_block, tool_block = restored.content
    assert isinstance(text_block, TextBlock)
    assert isinstance(tool_block, ToolUseBlock)
    assert text_block.text == "thinking..."
    assert tool_block.id == "t1"
    assert tool_block.name == "calc"
    assert tool_block.input == {"expr": "1+1"}
145  
146  
147  # ---------- detect_image_mime ----------
148  
149  
def test_detect_image_mime_png():
    """The 8-byte PNG signature is detected as image/png."""
    payload = b"\x89PNG\r\n\x1a\n" + bytes(56)
    encoded = base64.b64encode(payload).decode("ascii")
    assert detect_image_mime(encoded) == "image/png"
154  
155  
def test_detect_image_mime_jpeg():
    """A leading FF D8 FF marker is detected as image/jpeg."""
    payload = b"\xff\xd8\xff" + bytes(61)
    encoded = base64.b64encode(payload).decode("ascii")
    assert detect_image_mime(encoded) == "image/jpeg"
160  
161  
def test_detect_image_mime_gif():
    """The ASCII 'GIF89a' header is detected as image/gif."""
    payload = b"GIF89a" + bytes(58)
    encoded = base64.b64encode(payload).decode("ascii")
    assert detect_image_mime(encoded) == "image/gif"
166  
167  
def test_detect_image_mime_unknown_fallback():
    """Bytes matching no known signature fall back to image/png."""
    payload = b"\x00\x01\x02\x03" + bytes(60)
    encoded = base64.b64encode(payload).decode("ascii")
    assert detect_image_mime(encoded) == "image/png"
172  
173  
174  # ---------- ImageBlock.from_data_url ----------
175  
176  
def test_image_block_from_data_url_with_prefix():
    """A full data: URL is split into its MIME type and base64 payload."""
    encoded = base64.b64encode(b"\xff\xd8\xff" + bytes(61)).decode("ascii")
    block = ImageBlock.from_data_url("data:image/jpeg;base64," + encoded)
    assert block.mime_type == "image/jpeg"
    # The stored payload must exclude the "data:...;base64," prefix.
    assert block.data == encoded
185  
def test_image_block_from_data_url_plain_base64():
    """Bare base64 (no data: prefix) is accepted and its MIME type sniffed."""
    encoded = base64.b64encode(b"\x89PNG\r\n\x1a\n" + bytes(56)).decode("ascii")
    block = ImageBlock.from_data_url(encoded)
    assert block.mime_type == "image/png"
    assert block.data == encoded
192  
193  
194  # ---------- user_text_message helper ----------
195  
196  
def test_user_text_message():
    """user_text_message builds a user message holding one TextBlock."""
    message = user_text_message("hi there")
    assert message.role == "user"
    assert len(message.content) == 1
    (only_block,) = message.content
    assert isinstance(only_block, TextBlock)
    assert only_block.text == "hi there"