/ cli / commands / task_cmd / common.py
common.py
  1  """
  2  Common utilities shared between task send and task run commands.
  3  """
  4  import asyncio
  5  import base64
  6  import mimetypes
  7  import sys
  8  import uuid
  9  from pathlib import Path
 10  from typing import Optional, List, Dict, Any, Callable
 11  
 12  import click
 13  import httpx
 14  
 15  from cli.utils import error_exit
 16  
 17  
 18  async def fetch_available_agents(url: str, token: Optional[str] = None) -> List[Dict[str, Any]]:
 19      """Fetch available agents from the gateway."""
 20      headers = {}
 21      if token:
 22          headers["Authorization"] = f"Bearer {token}"
 23  
 24      async with httpx.AsyncClient(timeout=10.0) as client:
 25          response = await client.get(f"{url}/api/v1/agentCards", headers=headers)
 26          response.raise_for_status()
 27          return response.json()
 28  
 29  
 30  def get_agent_name_from_cards(agent_cards: List[Dict[str, Any]], preferred_name: str) -> Optional[str]:
 31      """
 32      Find a matching agent name from available cards.
 33      Tries exact match first, then case-insensitive, then partial match.
 34      Returns the exact name if found, or None if not found.
 35      """
 36      preferred_lower = preferred_name.lower()
 37  
 38      # Try exact match first
 39      for card in agent_cards:
 40          name = card.get("name", "")
 41          if name == preferred_name:
 42              return name
 43  
 44      # Try case-insensitive exact match
 45      for card in agent_cards:
 46          name = card.get("name", "")
 47          if name.lower() == preferred_lower:
 48              return name
 49  
 50      # Try partial match (name contains preferred, or preferred contains name)
 51      for card in agent_cards:
 52          name = card.get("name", "")
 53          name_lower = name.lower()
 54          if preferred_lower in name_lower or name_lower in preferred_lower:
 55              return name
 56  
 57      return None
 58  
 59  
 60  def get_mime_type(file_path: Path) -> str:
 61      """Determine MIME type for a file."""
 62      mime_type, _ = mimetypes.guess_type(str(file_path))
 63      return mime_type or "application/octet-stream"
 64  
 65  
 66  def read_file_as_base64(file_path: Path) -> str:
 67      """Read a file and return its content as base64."""
 68      with open(file_path, "rb") as f:
 69          return base64.b64encode(f.read()).decode("utf-8")
 70  
 71  
 72  def build_file_parts(file_paths: List[str]) -> List[dict]:
 73      """Build FilePart objects for the given file paths."""
 74      parts = []
 75      for file_path_str in file_paths:
 76          file_path = Path(file_path_str).resolve()
 77          if not file_path.exists():
 78              error_exit(f"File not found: {file_path}")
 79          if not file_path.is_file():
 80              error_exit(f"Not a file: {file_path}")
 81  
 82          mime_type = get_mime_type(file_path)
 83          base64_content = read_file_as_base64(file_path)
 84  
 85          parts.append({
 86              "kind": "file",
 87              "file": {
 88                  "bytes": base64_content,
 89                  "name": file_path.name,
 90                  "mimeType": mime_type,
 91              },
 92          })
 93  
 94      return parts
 95  
 96  
 97  async def download_stim_file(
 98      url: str, task_id: str, output_dir: Path, headers: dict
 99  ):
100      """Download the STIM file for the task."""
101      stim_url = f"{url}/api/v1/tasks/{task_id}"
102  
103      async with httpx.AsyncClient(timeout=30.0) as client:
104          response = await client.get(stim_url, headers=headers)
105          response.raise_for_status()
106  
107          stim_path = output_dir / f"{task_id}.stim"
108          with open(stim_path, "wb") as f:
109              f.write(response.content)
110  
111  
async def execute_task(
    message: str,
    url: str,
    agent: str,
    session_id: Optional[str],
    token: Optional[str],
    files: List[str],
    timeout: int,
    output_dir: Optional[Path],
    quiet: bool,
    no_stim: bool,
    debug: bool,
    session_hint: str = "",
) -> int:
    """
    Core task execution: send a task, stream SSE response, save outputs.

    Shared by both 'sam task send' and 'sam task run'.

    The flow is:

    1. POST a JSON-RPC ``message/stream`` request to the gateway and read the
       task ID from the reply.
    2. Subscribe to the task's SSE stream, recording every event and echoing
       newly assembled text (unless ``quiet``).
    3. Save the recorded events and the response text (``response.txt``,
       UTF-8) into the output directory.
    4. Download artifacts and, unless ``no_stim``, the STIM file. Failures in
       this step are reported as warnings only.
    5. Print a summary and return the exit code.

    Args:
        message: The prompt text to send
        url: Base URL of the gateway (already stripped of trailing slash)
        agent: Resolved agent name
        session_id: Optional session ID (generates new UUID if None)
        token: Optional auth token
        files: List of file paths to attach
        timeout: SSE timeout in seconds
        output_dir: Output directory (auto-created from task ID if None)
        quiet: Suppress streaming output
        no_stim: Skip STIM file download
        debug: Enable debug output
        session_hint: Extra text appended after session ID in summary

    Returns:
        Exit code (0 for success, 1 for failure). A request that cannot be
        sent, a reply without a task ID, an SSE error, a task that ends in an
        error state, and an SSE timeout all count as failures.
    """
    from .sse_client import SSEClient
    from .message_assembler import MessageAssembler
    from .event_recorder import EventRecorder
    from .artifact_handler import ArtifactHandler

    def _debug(msg: str):
        if debug:
            click.echo(click.style(f"[DEBUG] {msg}", fg="yellow"), err=True)

    # Generate session ID if not provided
    effective_session_id = session_id or str(uuid.uuid4())

    # Build message parts
    parts = [{"kind": "text", "text": message}]

    if files:
        file_parts = build_file_parts(files)
        parts.extend(file_parts)
        if not quiet:
            click.echo(click.style(f"Attached {len(file_parts)} file(s)", fg="blue"))

    # Build JSON-RPC request payload
    payload = {
        "jsonrpc": "2.0",
        "id": f"req-{uuid.uuid4()}",
        "method": "message/stream",
        "params": {
            "message": {
                "role": "user",
                "parts": parts,
                "messageId": f"msg-{uuid.uuid4()}",
                "kind": "message",
                "contextId": effective_session_id,
                "metadata": {"agent_name": agent},
            }
        },
    }

    # Build headers
    headers = {"Content-Type": "application/json"}
    if token:
        headers["Authorization"] = f"Bearer {token}"

    if not quiet:
        click.echo(click.style(f"Sending task to {agent}...", fg="blue"))

    _debug(f"POST {url}/api/v1/message:stream")

    try:
        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.post(
                f"{url}/api/v1/message:stream",
                json=payload,
                headers=headers,
            )
            _debug(f"Response status: {response.status_code}")
            response.raise_for_status()
    except httpx.ConnectError:
        click.echo(click.style(f"Failed to connect to {url}", fg="red"), err=True)
        return 1
    except httpx.HTTPStatusError as e:
        click.echo(
            click.style(f"HTTP error {e.response.status_code}: {e.response.text}", fg="red"),
            err=True,
        )
        return 1
    except httpx.RequestError as e:
        # Timeouts and other transport-level failures: report them cleanly
        # instead of letting a traceback escape from the CLI.
        click.echo(click.style(f"Request to {url} failed: {e}", fg="red"), err=True)
        return 1

    # Decode the body separately so a non-JSON reply gets its own message.
    try:
        result = response.json()
    except ValueError as e:
        click.echo(click.style(f"Invalid JSON in response: {e}", fg="red"), err=True)
        return 1

    # Extract task ID from response. Guard each level: the body may not be an
    # object, and "result" may be absent or null (e.g. a JSON-RPC error reply).
    task_result = result.get("result") if isinstance(result, dict) else None
    task_id = task_result.get("id") if isinstance(task_result, dict) else None

    if not task_id:
        click.echo(click.style(f"No task ID in response: {result}", fg="red"), err=True)
        return 1

    if not quiet:
        click.echo(click.style(f"Task ID: {task_id}", fg="blue"))
        click.echo()

    # Create output directory if not provided
    if output_dir is None:
        output_dir = Path(f"/tmp/sam-task-{task_id}")
    output_dir.mkdir(parents=True, exist_ok=True)

    # Initialize components
    assembler = MessageAssembler()
    recorder = EventRecorder(output_dir)
    sse_client = SSEClient(url, timeout, token, debug=debug)

    _debug(f"Subscribing to SSE events for task {task_id}")

    # Track response text for saving
    response_text_parts = []

    # Remember a timeout so the summary does not report success for a task
    # whose stream was cut short.
    timed_out = False

    # Subscribe to SSE and process events
    try:
        async for event in sse_client.subscribe(task_id):
            recorder.record_event(event.event_type, event.data)

            msg, new_text = assembler.process_event(event.event_type, event.data)

            if new_text:
                response_text_parts.append(new_text)
                if not quiet:
                    click.echo(new_text, nl=False)
                    sys.stdout.flush()

            if msg.is_complete:
                _debug("Task is complete")
                break

    except httpx.HTTPStatusError as e:
        click.echo(click.style(f"\nSSE connection error: {e}", fg="red"), err=True)
        return 1
    except asyncio.TimeoutError as e:
        # Deliberately no early return: whatever was received so far is still
        # saved below.
        click.echo(click.style(f"\nTimeout: {e}", fg="yellow"), err=True)
        timed_out = True
    except Exception as e:
        click.echo(click.style(f"\nSSE error: {e}", fg="red"), err=True)
        return 1

    # Ensure newline after streaming
    if not quiet and response_text_parts:
        click.echo()

    # Get final message state
    final_msg = assembler.get_message()

    # Save recorded events
    recorder.save()

    # Save response text. UTF-8 is explicit so the write does not depend on
    # the platform's default encoding.
    response_text = "".join(response_text_parts)
    response_path = output_dir / "response.txt"
    with open(response_path, "w", encoding="utf-8") as f:
        f.write(response_text)

    # Download artifacts (best effort: a failure only produces a warning)
    artifact_handler = ArtifactHandler(url, effective_session_id, output_dir, token)
    try:
        downloaded_artifacts = await artifact_handler.download_all_artifacts()
        if downloaded_artifacts and not quiet:
            click.echo()
            click.echo(click.style("Downloaded artifacts:", fg="green"))
            for artifact in downloaded_artifacts:
                click.echo(f"  - {artifact.filename} ({artifact.size} bytes)")
    except Exception as e:
        if not quiet:
            click.echo(click.style(f"Warning: Could not download artifacts: {e}", fg="yellow"))

    # Fetch STIM file (best effort: a failure only produces a warning)
    if not no_stim:
        try:
            await download_stim_file(url, task_id, output_dir, headers)
        except Exception as e:
            if not quiet:
                click.echo(click.style(f"Warning: Could not download STIM file: {e}", fg="yellow"))

    # Print summary
    click.echo()
    click.echo(click.style("---", fg="cyan"))

    if final_msg.is_error:
        click.echo(click.style("Task failed.", fg="red", bold=True))
        exit_code = 1
    elif timed_out:
        click.echo(click.style("Task timed out before completing.", fg="yellow", bold=True))
        exit_code = 1
    else:
        click.echo(click.style("Task completed successfully.", fg="green", bold=True))
        exit_code = 0

    click.echo(f"Session ID: {click.style(effective_session_id, fg='cyan')}{session_hint}")
    click.echo(f"Task ID: {task_id}")
    click.echo(f"Output directory: {output_dir}")
    click.echo(f"Events recorded: {recorder.get_event_count()}")

    return exit_code