# test_agent2_types.py
"""Tests for restai.agent2.types — block types, serialization, image MIME detection."""
import base64

from restai.agent2.types import (
    ImageBlock,
    Message,
    TextBlock,
    ToolResultBlock,
    ToolUseBlock,
    block_from_dict,
    block_to_dict,
    detect_image_mime,
    message_from_dict,
    message_to_dict,
    user_text_message,
)

# Leading "magic" bytes used to build minimal fake image payloads.
_PNG_MAGIC = b"\x89PNG\r\n\x1a\n"
_JPEG_MAGIC = b"\xff\xd8\xff"
_GIF_MAGIC = b"GIF89a"


def _b64_image(magic: bytes, size: int = 64) -> str:
    """Return base64 text for *magic* zero-padded to *size* bytes.

    Every image fixture in this module is a 64-byte buffer whose only
    meaningful content is its leading magic bytes; the rest is ``\\x00``.
    """
    return base64.b64encode(magic.ljust(size, b"\x00")).decode()


# ---------- construction ----------


def test_text_block_construction():
    b = TextBlock(text="hello")
    assert b.text == "hello"


def test_tool_use_block_construction():
    b = ToolUseBlock(id="t1", name="search", input={"q": "test"})
    assert b.id == "t1"
    assert b.name == "search"
    assert b.input == {"q": "test"}


def test_tool_result_block_construction():
    b = ToolResultBlock(tool_use_id="t1", content="result text")
    assert b.tool_use_id == "t1"
    assert b.content == "result text"
    # is_error must default to False when not supplied.
    assert b.is_error is False


def test_tool_result_block_error():
    b = ToolResultBlock(tool_use_id="t1", content="oops", is_error=True)
    assert b.is_error is True


def test_image_block_construction():
    b = ImageBlock(data="abc123", mime_type="image/png")
    assert b.data == "abc123"
    assert b.mime_type == "image/png"


# ---------- Message.text_content ----------


def test_message_text_content_joins_text_blocks():
    msg = Message(
        role="user",
        content=[TextBlock(text="hello"), TextBlock(text="world")],
    )
    assert msg.text_content() == "hello\nworld"


def test_message_text_content_ignores_non_text_blocks():
    msg = Message(
        role="assistant",
        content=[
            TextBlock(text="start"),
            ToolUseBlock(id="t1", name="x", input={}),
            TextBlock(text="end"),
        ],
    )
    assert msg.text_content() == "start\nend"


def test_message_text_content_empty():
    msg = Message(role="user", content=[])
    assert msg.text_content() == ""


# ---------- block_to_dict / block_from_dict round-trip ----------


def test_text_block_roundtrip():
    original = TextBlock(text="hello world")
    d = block_to_dict(original)
    assert d["type"] == "text"
    restored = block_from_dict(d)
    assert isinstance(restored, TextBlock)
    assert restored.text == original.text


def test_tool_use_block_roundtrip():
    original = ToolUseBlock(id="abc", name="search", input={"q": "test"})
    d = block_to_dict(original)
    assert d["type"] == "tool_use"
    restored = block_from_dict(d)
    assert isinstance(restored, ToolUseBlock)
    assert restored.id == original.id
    assert restored.name == original.name
    assert restored.input == original.input


def test_tool_result_block_roundtrip():
    original = ToolResultBlock(tool_use_id="abc", content="done", is_error=True)
    d = block_to_dict(original)
    assert d["type"] == "tool_result"
    restored = block_from_dict(d)
    assert isinstance(restored, ToolResultBlock)
    assert restored.tool_use_id == original.tool_use_id
    assert restored.content == original.content
    assert restored.is_error == original.is_error


def test_image_block_roundtrip():
    original = ImageBlock(data="iVBOR", mime_type="image/png")
    d = block_to_dict(original)
    assert d["type"] == "image"
    restored = block_from_dict(d)
    assert isinstance(restored, ImageBlock)
    assert restored.data == original.data
    assert restored.mime_type == original.mime_type


# ---------- message_to_dict / message_from_dict round-trip ----------


def test_message_roundtrip():
    original = Message(
        role="assistant",
        content=[
            TextBlock(text="thinking..."),
            ToolUseBlock(id="t1", name="calc", input={"expr": "1+1"}),
        ],
    )
    d = message_to_dict(original)
    assert d["role"] == "assistant"
    assert len(d["content"]) == 2
    restored = message_from_dict(d)
    assert restored.role == original.role
    assert len(restored.content) == 2
    assert isinstance(restored.content[0], TextBlock)
    assert isinstance(restored.content[1], ToolUseBlock)
    assert restored.content[0].text == "thinking..."
    assert restored.content[1].name == "calc"


# ---------- detect_image_mime ----------


def test_detect_image_mime_png():
    assert detect_image_mime(_b64_image(_PNG_MAGIC)) == "image/png"


def test_detect_image_mime_jpeg():
    assert detect_image_mime(_b64_image(_JPEG_MAGIC)) == "image/jpeg"


def test_detect_image_mime_gif():
    assert detect_image_mime(_b64_image(_GIF_MAGIC)) == "image/gif"


def test_detect_image_mime_unknown_fallback():
    # Bytes matching no known signature are expected to fall back to PNG.
    assert detect_image_mime(_b64_image(b"\x00\x01\x02\x03")) == "image/png"


# ---------- ImageBlock.from_data_url ----------


def test_image_block_from_data_url_with_prefix():
    b64 = _b64_image(_JPEG_MAGIC)
    url = f"data:image/jpeg;base64,{b64}"
    block = ImageBlock.from_data_url(url)
    assert block.mime_type == "image/jpeg"
    # The "data:...;base64," prefix must be stripped from the stored payload.
    assert block.data == b64


def test_image_block_from_data_url_plain_base64():
    # Without a data-URL prefix the MIME type is expected to come from the bytes.
    b64 = _b64_image(_PNG_MAGIC)
    block = ImageBlock.from_data_url(b64)
    assert block.mime_type == "image/png"
    assert block.data == b64


# ---------- user_text_message helper ----------


def test_user_text_message():
    msg = user_text_message("hi there")
    assert msg.role == "user"
    assert len(msg.content) == 1
    assert isinstance(msg.content[0], TextBlock)
    assert msg.content[0].text == "hi there"