/ restai / agent2 / types.py
types.py
  1  """Pure-data types for the agent2 runtime. No llamaindex imports."""
  2  from __future__ import annotations
  3  
  4  import base64
  5  from dataclasses import dataclass, field
  6  from typing import Literal, Union
  7  
# Allowed values for Message.role: only these two string literals.
MessageRole = Literal["user", "assistant"]
  9  
 10  
@dataclass
class TextBlock:
    """A content block holding a plain text string."""
    # The text carried by this block.
    text: str
 14  
 15  
@dataclass
class ToolUseBlock:
    """A content block describing a tool invocation.

    NOTE(review): presumably emitted when the assistant requests a tool call;
    the producer is not visible in this module — confirm there.
    """
    # Identifier of this invocation (presumably what
    # ToolResultBlock.tool_use_id refers back to — TODO confirm).
    id: str
    # Name of the tool.
    name: str
    # Arguments for the tool; the serialization helpers in this module pass
    # this dict through unchanged.
    input: dict
 21  
 22  
@dataclass
class ToolResultBlock:
    """A content block carrying the string result of a tool invocation."""
    # Identifier tying this result to a tool invocation (presumably a
    # ToolUseBlock.id — TODO confirm).
    tool_use_id: str
    # Result payload as a string.
    content: str
    # Flags the result as an error; defaults to False.
    is_error: bool = False
 28  
 29  
 30  @dataclass
 31  class ImageBlock:
 32      """A base64-encoded image attached to a user message.
 33  
 34      `data` is the raw base64 string (no data URL prefix). `mime_type` is the
 35      standard MIME (e.g. 'image/png', 'image/jpeg'). Use `from_data_url()` or
 36      `from_base64()` to construct one with auto-detection.
 37      """
 38      data: str
 39      mime_type: str
 40  
 41      @classmethod
 42      def from_data_url(cls, url: str) -> "ImageBlock":
 43          # data:image/png;base64,iVBORw0KG...
 44          if url.startswith("data:") and ";base64," in url:
 45              header, _, body = url.partition(";base64,")
 46              mime = header[len("data:"):] or "image/png"
 47              return cls(data=body, mime_type=mime)
 48          # Plain base64 — sniff
 49          return cls.from_base64(url)
 50  
 51      @classmethod
 52      def from_base64(cls, b64: str) -> "ImageBlock":
 53          return cls(data=b64, mime_type=detect_image_mime(b64))
 54  
 55  
 56  def detect_image_mime(b64_data: str) -> str:
 57      """Sniff the MIME type from the leading bytes of a base64-encoded image.
 58  
 59      Falls back to 'image/png' if the magic bytes don't match anything known.
 60      """
 61      try:
 62          head = base64.b64decode(b64_data[:64], validate=False)
 63      except Exception:
 64          return "image/png"
 65      if head.startswith(b"\x89PNG\r\n\x1a\n"):
 66          return "image/png"
 67      if head.startswith(b"\xff\xd8\xff"):
 68          return "image/jpeg"
 69      if head.startswith(b"GIF87a") or head.startswith(b"GIF89a"):
 70          return "image/gif"
 71      if head[:4] == b"RIFF" and head[8:12] == b"WEBP":
 72          return "image/webp"
 73      return "image/png"
 74  
 75  
# Union of the four block classes handled by block_to_dict / block_from_dict.
ContentBlock = Union[TextBlock, ToolUseBlock, ToolResultBlock, ImageBlock]
 77  
 78  
 79  @dataclass
 80  class Message:
 81      role: MessageRole
 82      content: list
 83  
 84      def text_content(self) -> str:
 85          return "\n".join(
 86              block.text for block in self.content if isinstance(block, TextBlock)
 87          ).strip()
 88  
 89  
@dataclass
class AgentSession:
    """Mutable session container; every field starts empty or zero.

    The list and dict defaults use ``default_factory``, so each instance gets
    its own fresh objects.
    """
    # presumably the Message history of the session — TODO confirm element type
    messages: list = field(default_factory=list)
    # presumably the number of turns run so far; starts at 0 — confirm with caller
    turn_count: int = 0
    # free-form dict; what gets stored here is not visible in this module
    state: dict = field(default_factory=dict)
 95  
 96  
@dataclass
class AgentEvent:
    """An event record: a kind tag, optional message, turn number and payload.

    NOTE(review): presumably yielded by the agent loop while it runs; the
    producer is not visible in this module — confirm there.
    """
    # Event kind; restricted to these four string literals.
    type: Literal["assistant", "tool_result", "final", "text_delta"]
    # Message associated with the event, or None (the default).
    message: Union[Message, None] = None
    # Turn number for the event; defaults to 0.
    turn: int = 0
    # Extra event payload; a fresh empty dict per instance by default.
    data: dict = field(default_factory=dict)
103  
104  
def user_text_message(text: str) -> Message:
    """Wrap ``text`` in a "user" message containing a single TextBlock."""
    only_block = TextBlock(text=text)
    return Message(role="user", content=[only_block])
107  
108  
109  # ---------- JSON serialization (for memory persistence) ----------
110  
111  
def block_to_dict(block: ContentBlock) -> dict:
    """Convert a content block into a plain dict tagged with a "type" key.

    The tag is "text", "tool_use", "tool_result" or "image"; the remaining
    keys mirror the block's fields.

    Raises:
        TypeError: if ``block`` is not one of the four known block classes.
    """
    if isinstance(block, TextBlock):
        payload = {"type": "text", "text": block.text}
    elif isinstance(block, ToolUseBlock):
        payload = {
            "type": "tool_use",
            "id": block.id,
            "name": block.name,
            "input": block.input,
        }
    elif isinstance(block, ToolResultBlock):
        payload = {
            "type": "tool_result",
            "tool_use_id": block.tool_use_id,
            "content": block.content,
            "is_error": block.is_error,
        }
    elif isinstance(block, ImageBlock):
        payload = {
            "type": "image",
            "data": block.data,
            "mime_type": block.mime_type,
        }
    else:
        raise TypeError(f"Unknown block type: {type(block)}")
    return payload
127  
128  
def block_from_dict(d: dict) -> ContentBlock:
    """Rebuild a content block from the dict form written by block_to_dict.

    Dispatches on ``d["type"]``. Optional keys fall back to defaults (empty
    text/content/data, empty input dict, ``is_error`` False, MIME
    'image/png'); the identifying keys of tool blocks are required.

    Raises:
        KeyError: if a tool_use dict lacks "id"/"name" or a tool_result dict
            lacks "tool_use_id".
        ValueError: if the "type" tag is missing or not recognized.
    """
    kind = d.get("type")
    if kind == "text":
        block = TextBlock(text=d.get("text", ""))
    elif kind == "tool_use":
        use_id = d["id"]
        tool_name = d["name"]
        tool_input = d.get("input", {})
        block = ToolUseBlock(id=use_id, name=tool_name, input=tool_input)
    elif kind == "tool_result":
        ref_id = d["tool_use_id"]
        body = d.get("content", "")
        failed = bool(d.get("is_error", False))
        block = ToolResultBlock(tool_use_id=ref_id, content=body, is_error=failed)
    elif kind == "image":
        payload = d.get("data", "")
        mime = d.get("mime_type", "image/png")
        block = ImageBlock(data=payload, mime_type=mime)
    else:
        raise ValueError(f"Unknown block dict: {d!r}")
    return block
144  
145  
def message_to_dict(msg: Message) -> dict:
    """Convert a message into a dict with "role" and "content" keys.

    "content" is a list holding the block_to_dict() form of each block, in
    the original order.
    """
    role = msg.role
    serialized = []
    for blk in msg.content:
        serialized.append(block_to_dict(blk))
    return {"role": role, "content": serialized}
148  
149  
def message_from_dict(d: dict) -> Message:
    """Rebuild a Message from the dict form written by message_to_dict.

    "role" is required (KeyError when absent); a missing "content" key is
    treated as an empty list. Each content entry goes through
    block_from_dict().
    """
    role = d["role"]
    blocks = []
    for raw in d.get("content", []):
        blocks.append(block_from_dict(raw))
    return Message(role=role, content=blocks)