/ src / solace_agent_mesh / common / a2a / artifact.py
artifact.py
  1  """
  2  Helpers for creating and consuming A2A Artifact objects.
  3  """
  4  
  5  import logging
  6  import uuid
  7  import base64
  8  from datetime import datetime, timezone
  9  from typing import Any, List, Optional, TYPE_CHECKING
 10  from urllib.parse import urlparse, parse_qs
 11  
 12  from .types import ContentPart
 13  from a2a.types import (
 14      Artifact,
 15      DataPart,
 16      FilePart,
 17      FileWithBytes,
 18      FileWithUri,
 19      Part,
 20      TextPart,
 21  )
 22  from .. import a2a
 23  from ..constants import ARTIFACT_TAG_USER_UPLOADED
 24  from ..utils.mime_helpers import resolve_mime_type
 25  
 26  if TYPE_CHECKING:
 27      from google.adk.artifacts import BaseArtifactService
 28  
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
 30  
 31  # --- Creation Helpers ---
 32  
 33  
 34  def create_text_artifact(
 35      name: str,
 36      text: str,
 37      description: str = "",
 38      artifact_id: Optional[str] = None,
 39  ) -> Artifact:
 40      """
 41      Creates a new Artifact object containing only a single TextPart.
 42  
 43      Args:
 44          name: The human-readable name of the artifact.
 45          text: The text content of the artifact.
 46          description: An optional description of the artifact.
 47          artifact_id: The artifact ID. If None, a new UUID is generated.
 48  
 49      Returns:
 50          A new `Artifact` object.
 51      """
 52      text_part = TextPart(text=text)
 53      return Artifact(
 54          artifact_id=artifact_id or str(uuid.uuid4().hex),
 55          parts=[Part(root=text_part)],
 56          name=name,
 57          description=description,
 58      )
 59  
 60  
 61  def create_data_artifact(
 62      name: str,
 63      data: dict[str, Any],
 64      description: str = "",
 65      artifact_id: Optional[str] = None,
 66  ) -> Artifact:
 67      """
 68      Creates a new Artifact object containing only a single DataPart.
 69  
 70      Args:
 71          name: The human-readable name of the artifact.
 72          data: The structured data content of the artifact.
 73          description: An optional description of the artifact.
 74          artifact_id: The artifact ID. If None, a new UUID is generated.
 75  
 76      Returns:
 77          A new `Artifact` object.
 78      """
 79      data_part = DataPart(data=data)
 80      return Artifact(
 81          artifact_id=artifact_id or str(uuid.uuid4().hex),
 82          parts=[Part(root=data_part)],
 83          name=name,
 84          description=description,
 85      )
 86  
 87  
 88  def update_artifact_parts(artifact: Artifact, new_parts: List[ContentPart]) -> Artifact:
 89      """Returns a new Artifact with its parts replaced."""
 90      wrapped_parts = [Part(root=p) for p in new_parts]
 91      return artifact.model_copy(update={"parts": wrapped_parts})
 92  
 93  
 94  async def prepare_file_part_for_publishing(
 95      part: FilePart,
 96      mode: str,
 97      artifact_service: "BaseArtifactService",
 98      user_id: str,
 99      session_id: str,
100      target_agent_name: str,
101      log_identifier: str,
102  ) -> Optional[FilePart]:
103      """
104      Prepares a FilePart for publishing based on the artifact handling mode.
105  
106      - 'ignore': Returns None.
107      - 'embed': Ensures the part contains bytes, resolving a URI if necessary.
108      - 'reference': Ensures the part contains a URI, saving bytes if necessary.
109      - 'passthrough': Returns the part as-is.
110  
111      Args:
112          part: The input FilePart, which may contain raw bytes or a URI.
113          mode: The artifact handling mode ('ignore', 'embed', 'reference', 'passthrough').
114          artifact_service: The ADK artifact service instance.
115          user_id: The user ID for the artifact context.
116          session_id: The session ID for the artifact context.
117          target_agent_name: The name of the agent the artifact will be associated with.
118          log_identifier: The logging identifier for log messages.
119  
120      Returns:
121          The processed FilePart, or None if ignored.
122      """
123      log_id = f"{log_identifier}[PrepareFilePart]"
124  
125      if mode == "ignore":
126          log.debug("%s Mode is 'ignore', filtering out FilePart.", log_id)
127          return None
128  
129      if mode == "passthrough":
130          log.debug("%s Mode is 'passthrough', returning original FilePart.", log_id)
131          return part
132  
133      if mode == "embed":
134          if isinstance(part.file, FileWithUri):
135              log.debug("%s Mode is 'embed', resolving URI for FilePart.", log_id)
136              return await resolve_file_part_uri(part, artifact_service, log_identifier)
137          return part  # It's already bytes, so it's embedded.
138  
139      if mode == "reference":
140          if isinstance(part.file, FileWithBytes):
141              if not artifact_service:
142                  log.warning(
143                      "%s Mode is 'reference' but no artifact_service is configured. Ignoring FilePart '%s'.",
144                      log_id,
145                      part.file.name,
146                  )
147                  return None
148  
149              try:
150                  filename = part.file.name or f"upload-{uuid.uuid4().hex}"
151                  content_bytes = base64.b64decode(part.file.bytes)
152                  mime_type = resolve_mime_type(filename, part.file.mime_type)
153  
154                  # Create a concise and accurate metadata dictionary.
155                  metadata_to_save = {
156                      "source": log_identifier,
157                      "description": "This artifact was uploaded via the gateway",
158                  }
159  
160                  # Call the helper with the new, simpler metadata.
161                  from ...agent.utils.artifact_helpers import save_artifact_with_metadata
162  
163                  save_result = await save_artifact_with_metadata(
164                      artifact_service=artifact_service,
165                      app_name=target_agent_name,
166                      user_id=user_id,
167                      session_id=session_id,
168                      filename=filename,
169                      content_bytes=content_bytes,
170                      mime_type=mime_type,
171                      metadata_dict=metadata_to_save,
172                      timestamp=datetime.now(timezone.utc),
173                      tags=[ARTIFACT_TAG_USER_UPLOADED],
174                  )
175  
176                  if save_result["status"] == "success":
177                      saved_version = save_result.get("data_version")
178                      from ...agent.utils.artifact_helpers import format_artifact_uri
179  
180                      artifact_uri = format_artifact_uri(
181                          app_name=target_agent_name,
182                          user_id=user_id,
183                          session_id=session_id,
184                          filename=filename,
185                          version=saved_version,
186                      )
187                      ref_part = a2a.create_file_part_from_uri(
188                          uri=artifact_uri,
189                          name=filename,
190                          mime_type=mime_type,
191                          metadata=part.metadata,
192                      )
193                      log.info(
194                          "%s Converted embedded file '%s' to reference: %s",
195                          log_id,
196                          filename,
197                          artifact_uri,
198                      )
199                      return ref_part
200                  else:
201                      log.error(
202                          "%s Failed to save artifact via helper: %s. Skipping FilePart.",
203                          log_id,
204                          save_result.get("message"),
205                      )
206                      return None
207  
208              except Exception as e:
209                  log.exception(
210                      "%s Failed to save artifact for reference mode: %s. Skipping FilePart.",
211                      log_id,
212                      e,
213                  )
214                  return None
215          return part  # It's already a reference (URI)
216  
217      # Default case if mode is unrecognized
218      log.warning(
219          "%s Unrecognized artifact_handling_mode '%s'. Ignoring FilePart.", log_id, mode
220      )
221      return None
222  
223  
async def resolve_file_part_uri(
    part: FilePart, artifact_service: "BaseArtifactService", log_identifier: str
) -> FilePart:
    """
    Resolves an ``artifact://`` URI within a FilePart into embedded bytes.

    Expected URI form:
    ``artifact://<app_name>/<user_id>/<session_id>/<filename>[?version=<int>]``.
    If ``version`` is absent, the loader is called with ``version=None``.

    On success, ``part.file`` is replaced IN PLACE with a FileWithBytes. It
    holds the base64-encoded content and keeps the original mime type and
    name. The same ``part`` object is then returned.

    In every other case the part is returned unmodified. These cases are: the
    file is not an ``artifact://`` URI, no artifact service is configured, the
    URI is malformed, the load fails or yields no bytes, or an exception is
    raised. Each case except the first is logged.

    Args:
        part: The FilePart to resolve.
        artifact_service: The ADK artifact service instance.
        log_identifier: The logging identifier for log messages.

    Returns:
        The same ``part`` object, carrying embedded bytes if resolution succeeded.
    """
    if not (
        isinstance(part.file, FileWithUri)
        and part.file.uri
        and part.file.uri.startswith("artifact://")
    ):
        return part

    if not artifact_service:
        log.warning(
            "%s Cannot resolve artifact URI, artifact_service is not configured.",
            log_identifier,
        )
        return part

    uri = part.file.uri
    log_id_prefix = f"{log_identifier}[ResolveURI]"
    try:
        log.info("%s Found artifact URI to resolve: %s", log_id_prefix, uri)
        parsed_uri = urlparse(uri)
        app_name = parsed_uri.netloc
        # NOTE(review): path segments are used verbatim (no percent-decoding);
        # confirm this matches how artifact URIs are built.
        path_parts = parsed_uri.path.strip("/").split("/")

        if not app_name or len(path_parts) != 3:
            raise ValueError(
                "Invalid URI structure. Expected artifact://app_name/user_id/session_id/filename"
            )

        user_id, session_id, filename = path_parts
        version_str = parse_qs(parsed_uri.query).get("version", [None])[0]
        version = int(version_str) if version_str else None

        # Imported lazily (presumably to avoid an import cycle with the agent package).
        from ...agent.utils.artifact_helpers import load_artifact_content_or_metadata

        loaded_artifact = await load_artifact_content_or_metadata(
            artifact_service=artifact_service,
            app_name=app_name,
            user_id=user_id,
            session_id=session_id,
            filename=filename,
            version=version,
            return_raw_bytes=True,
        )

        if loaded_artifact.get("status") != "success":
            log.error(
                "%s Failed to resolve artifact URI '%s': %s",
                log_id_prefix,
                uri,
                loaded_artifact.get("message"),
            )
            return part

        content_bytes = loaded_artifact.get("raw_bytes")
        if content_bytes is None:
            # A "success" result without raw bytes would otherwise fail inside
            # b64encode with an opaque TypeError traceback.
            log.error(
                "%s Artifact URI '%s' loaded successfully but returned no raw bytes.",
                log_id_prefix,
                uri,
            )
            return part

        # The right-hand side is evaluated before assignment, so mime_type and
        # name are still read from the original URI-based file.
        part.file = FileWithBytes(
            bytes=base64.b64encode(content_bytes).decode("utf-8"),
            mime_type=part.file.mime_type,
            name=part.file.name,
        )
        log.info(
            "%s Successfully resolved and embedded artifact: %s",
            log_id_prefix,
            uri,
        )
    except Exception as e:
        log.exception("%s Error resolving artifact URI '%s': %s", log_id_prefix, uri, e)
    return part
307  
308  
309  # --- Consumption Helpers ---
310  
311  
def get_artifact_id(artifact: Artifact) -> str:
    """Return the ``artifact_id`` attribute of *artifact*."""
    identifier = artifact.artifact_id
    return identifier
315  
316  
def get_artifact_name(artifact: Artifact) -> Optional[str]:
    """Return the ``name`` attribute of *artifact* (may be None)."""
    display_name = artifact.name
    return display_name
320  
321  
def get_parts_from_artifact(artifact: Artifact) -> List[ContentPart]:
    """
    Unwrap every `Part` envelope of an artifact into its raw content object.

    Args:
        artifact: The `Artifact` whose parts should be unwrapped.

    Returns:
        A new list of the `root` objects (TextPart, DataPart, FilePart, ...),
        in the artifact's original part order.
    """
    unwrapped: List[ContentPart] = [wrapper.root for wrapper in artifact.parts]
    return unwrapped
333  
334  
def is_text_only_artifact(artifact: Artifact) -> bool:
    """
    Report whether every part of an artifact is a TextPart.

    Args:
        artifact: The Artifact to inspect.

    Returns:
        True when the artifact has at least one part and all of its parts
        wrap a TextPart; False for an artifact with no parts or with any
        non-text part.
    """
    wrappers = artifact.parts
    # An artifact with no parts is deliberately not considered "text only".
    if not wrappers:
        return False
    return all(isinstance(wrapper.root, TextPart) for wrapper in wrappers)
353  
354  
def get_text_content_from_artifact(artifact: Artifact) -> List[str]:
    """
    Collect the text of every TextPart in an artifact.

    Non-text parts are skipped.

    Args:
        artifact: The Artifact to read text from.

    Returns:
        The `text` values of all TextParts, in part order. The list is empty
        when the artifact contains no TextParts.
    """
    return [
        wrapper.root.text
        for wrapper in artifact.parts
        if isinstance(wrapper.root, TextPart)
    ]