common.py
"""
Common utilities shared between task send and task run commands.
"""
import asyncio
import base64
import mimetypes
import sys
import uuid
from pathlib import Path
from typing import Optional, List, Dict, Any, Callable

import click
import httpx

from cli.utils import error_exit


# Characters that must never appear in a single path component derived from
# server-supplied data (path separators on POSIX/Windows, and NUL).
_UNSAFE_PATH_CHARS = str.maketrans({"/": "_", "\\": "_", "\0": "_"})


def _safe_path_component(value: Any) -> str:
    """Make a server-supplied identifier safe to use as one path component.

    Path separators and NUL bytes are replaced with ``_`` so a malformed or
    hostile ID cannot make us write outside the intended directory. IDs that
    contain none of those characters are returned unchanged.
    """
    return str(value).translate(_UNSAFE_PATH_CHARS)


def _extract_task_id(result: Any) -> Optional[Any]:
    """Return ``result["result"]["id"]`` from a JSON-RPC reply, or None.

    Non-dict payloads and a missing/null ``result`` member (as in a JSON-RPC
    error reply) yield None instead of raising AttributeError.
    """
    if not isinstance(result, dict):
        return None
    task_result = result.get("result")
    if not isinstance(task_result, dict):
        return None
    return task_result.get("id")


async def fetch_available_agents(url: str, token: Optional[str] = None) -> List[Dict[str, Any]]:
    """Fetch available agents from the gateway.

    Args:
        url: Base URL of the gateway (no trailing slash).
        token: Optional bearer token sent in the ``Authorization`` header.

    Returns:
        The decoded JSON body of ``GET {url}/api/v1/agentCards``.

    Raises:
        httpx.HTTPStatusError: If the gateway answers with a 4xx/5xx status.
        httpx.RequestError: On connection/transport failures, including the
            10-second timeout.
    """
    headers = {}
    if token:
        headers["Authorization"] = f"Bearer {token}"

    async with httpx.AsyncClient(timeout=10.0) as client:
        response = await client.get(f"{url}/api/v1/agentCards", headers=headers)
        response.raise_for_status()
        return response.json()


def get_agent_name_from_cards(agent_cards: List[Dict[str, Any]], preferred_name: str) -> Optional[str]:
    """
    Find a matching agent name from available cards.

    Tries exact match first, then case-insensitive, then partial match
    (either name contains the other). Cards whose ``name`` is missing, null
    or empty are ignored: an empty string is a substring of every name, so
    it would otherwise "partially match" any request. An empty
    ``preferred_name`` matches nothing.

    Returns the exact name if found, or None if not found.
    """
    if not preferred_name:
        return None

    names = [card.get("name") for card in agent_cards if card.get("name")]
    preferred_lower = preferred_name.lower()

    # Try exact match first
    if preferred_name in names:
        return preferred_name

    # Try case-insensitive exact match
    for name in names:
        if name.lower() == preferred_lower:
            return name

    # Try partial match (name contains preferred, or preferred contains name)
    for name in names:
        name_lower = name.lower()
        if preferred_lower in name_lower or name_lower in preferred_lower:
            return name

    return None


def get_mime_type(file_path: Path) -> str:
    """Guess the MIME type from the file name, defaulting to octet-stream."""
    mime_type, _ = mimetypes.guess_type(str(file_path))
    return mime_type or "application/octet-stream"


def read_file_as_base64(file_path: Path) -> str:
    """Read a file and return its content as a base64-encoded string."""
    return base64.b64encode(file_path.read_bytes()).decode("utf-8")


def build_file_parts(file_paths: List[str]) -> List[dict]:
    """Build FilePart dicts (inline base64 bytes) for the given file paths.

    Each path is resolved to an absolute path. A missing path or a path that
    is not a regular file is reported via ``error_exit`` (presumably this
    terminates the CLI — confirm in ``cli.utils``).
    """
    parts = []
    for file_path_str in file_paths:
        file_path = Path(file_path_str).resolve()
        if not file_path.exists():
            error_exit(f"File not found: {file_path}")
        if not file_path.is_file():
            error_exit(f"Not a file: {file_path}")

        parts.append({
            "kind": "file",
            "file": {
                "bytes": read_file_as_base64(file_path),
                "name": file_path.name,
                "mimeType": get_mime_type(file_path),
            },
        })

    return parts


async def download_stim_file(
    url: str, task_id: str, output_dir: Path, headers: dict
) -> Path:
    """Download the STIM file for the task into ``output_dir``.

    The file is saved as ``<task_id>.stim``; path separators in the task ID
    are replaced so the file always lands inside ``output_dir``.

    Returns:
        The path of the written file.

    Raises:
        httpx.HTTPStatusError / httpx.RequestError: On HTTP or transport
            failure (30-second timeout).
    """
    stim_url = f"{url}/api/v1/tasks/{task_id}"

    async with httpx.AsyncClient(timeout=30.0) as client:
        response = await client.get(stim_url, headers=headers)
        response.raise_for_status()

    stim_path = output_dir / f"{_safe_path_component(task_id)}.stim"
    stim_path.write_bytes(response.content)
    return stim_path


async def execute_task(
    message: str,
    url: str,
    agent: str,
    session_id: Optional[str],
    token: Optional[str],
    files: List[str],
    timeout: int,
    output_dir: Optional[Path],
    quiet: bool,
    no_stim: bool,
    debug: bool,
    session_hint: str = "",
) -> int:
    """
    Core task execution: send a task, stream SSE response, save outputs.

    Shared by both 'sam task send' and 'sam task run'.

    Args:
        message: The prompt text to send
        url: Base URL of the gateway (already stripped of trailing slash)
        agent: Resolved agent name
        session_id: Optional session ID (generates new UUID if None)
        token: Optional auth token
        files: List of file paths to attach
        timeout: SSE timeout in seconds
        output_dir: Output directory (auto-created from task ID if None)
        quiet: Suppress streaming output
        no_stim: Skip STIM file download
        debug: Enable debug output
        session_hint: Extra text appended after session ID in summary

    Returns:
        Exit code: 0 for success, 1 for failure. A stream that times out
        still saves partial output but is reported as a failure.
    """
    from .sse_client import SSEClient
    from .message_assembler import MessageAssembler
    from .event_recorder import EventRecorder
    from .artifact_handler import ArtifactHandler

    def _debug(msg: str):
        if debug:
            click.echo(click.style(f"[DEBUG] {msg}", fg="yellow"), err=True)

    # Generate session ID if not provided
    effective_session_id = session_id or str(uuid.uuid4())

    # Build message parts
    parts = [{"kind": "text", "text": message}]

    if files:
        file_parts = build_file_parts(files)
        parts.extend(file_parts)
        if not quiet:
            click.echo(click.style(f"Attached {len(file_parts)} file(s)", fg="blue"))

    # Build JSON-RPC request payload
    payload = {
        "jsonrpc": "2.0",
        "id": f"req-{uuid.uuid4()}",
        "method": "message/stream",
        "params": {
            "message": {
                "role": "user",
                "parts": parts,
                "messageId": f"msg-{uuid.uuid4()}",
                "kind": "message",
                "contextId": effective_session_id,
                "metadata": {"agent_name": agent},
            }
        },
    }

    # Build headers
    headers = {"Content-Type": "application/json"}
    if token:
        headers["Authorization"] = f"Bearer {token}"

    if not quiet:
        click.echo(click.style(f"Sending task to {agent}...", fg="blue"))

    _debug(f"POST {url}/api/v1/message:stream")

    try:
        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.post(
                f"{url}/api/v1/message:stream",
                json=payload,
                headers=headers,
            )
            _debug(f"Response status: {response.status_code}")
            response.raise_for_status()
    except httpx.ConnectError:
        click.echo(click.style(f"Failed to connect to {url}", fg="red"), err=True)
        return 1
    except httpx.HTTPStatusError as e:
        click.echo(
            click.style(f"HTTP error {e.response.status_code}: {e.response.text}", fg="red"),
            err=True,
        )
        return 1
    except httpx.RequestError as e:
        # Timeouts, protocol errors, etc. (ConnectError is handled above and
        # is itself a RequestError subclass, so it must stay first).
        click.echo(click.style(f"Request to {url} failed: {e}", fg="red"), err=True)
        return 1

    # The body was fully read by client.post(), so it can be decoded after
    # the client is closed. json.JSONDecodeError is a ValueError subclass.
    try:
        result = response.json()
    except ValueError:
        click.echo(
            click.style(f"Invalid JSON in response: {response.text}", fg="red"),
            err=True,
        )
        return 1

    # Extract task ID from response
    task_id = _extract_task_id(result)

    if not task_id:
        click.echo(click.style(f"No task ID in response: {result}", fg="red"), err=True)
        return 1

    if not quiet:
        click.echo(click.style(f"Task ID: {task_id}", fg="blue"))
        click.echo()

    # Create output directory (derived from the task ID if not provided).
    if output_dir is None:
        output_dir = Path(f"/tmp/sam-task-{_safe_path_component(task_id)}")
    output_dir.mkdir(parents=True, exist_ok=True)

    # Initialize components
    assembler = MessageAssembler()
    recorder = EventRecorder(output_dir)
    sse_client = SSEClient(url, timeout, token, debug=debug)

    _debug(f"Subscribing to SSE events for task {task_id}")

    # Track response text for saving
    response_text_parts = []
    timed_out = False

    # Subscribe to SSE and process events
    try:
        async for event in sse_client.subscribe(task_id):
            recorder.record_event(event.event_type, event.data)

            msg, new_text = assembler.process_event(event.event_type, event.data)

            if new_text:
                response_text_parts.append(new_text)
                if not quiet:
                    click.echo(new_text, nl=False)
                    sys.stdout.flush()

            if msg.is_complete:
                _debug("Task is complete")
                break

    except httpx.HTTPStatusError as e:
        click.echo(click.style(f"\nSSE connection error: {e}", fg="red"), err=True)
        return 1
    except asyncio.TimeoutError as e:
        # Not fatal here: fall through so partial output is still saved,
        # but remember it so the summary does not claim success.
        timed_out = True
        click.echo(click.style(f"\nTimeout: {e}", fg="yellow"), err=True)
    except Exception as e:
        click.echo(click.style(f"\nSSE error: {e}", fg="red"), err=True)
        return 1

    # Ensure newline after streaming
    if not quiet and response_text_parts:
        click.echo()

    # Get final message state
    final_msg = assembler.get_message()

    # Save recorded events
    recorder.save()

    # Save response text as UTF-8 regardless of the platform's locale, so
    # non-ASCII model output cannot raise UnicodeEncodeError.
    response_path = output_dir / "response.txt"
    response_path.write_text("".join(response_text_parts), encoding="utf-8")

    # Download artifacts (best-effort; failures only produce a warning)
    artifact_handler = ArtifactHandler(url, effective_session_id, output_dir, token)
    try:
        downloaded_artifacts = await artifact_handler.download_all_artifacts()
        if downloaded_artifacts and not quiet:
            click.echo()
            click.echo(click.style("Downloaded artifacts:", fg="green"))
            for artifact in downloaded_artifacts:
                click.echo(f"  - {artifact.filename} ({artifact.size} bytes)")
    except Exception as e:
        if not quiet:
            click.echo(click.style(f"Warning: Could not download artifacts: {e}", fg="yellow"))

    # Fetch STIM file (best-effort)
    if not no_stim:
        try:
            await download_stim_file(url, task_id, output_dir, headers)
        except Exception as e:
            if not quiet:
                click.echo(click.style(f"Warning: Could not download STIM file: {e}", fg="yellow"))

    # Print summary
    click.echo()
    click.echo(click.style("---", fg="cyan"))

    if final_msg.is_error:
        click.echo(click.style("Task failed.", fg="red", bold=True))
        exit_code = 1
    elif timed_out:
        click.echo(click.style("Task timed out before completion.", fg="yellow", bold=True))
        exit_code = 1
    else:
        click.echo(click.style("Task completed successfully.", fg="green", bold=True))
        exit_code = 0

    click.echo(f"Session ID: {click.style(effective_session_id, fg='cyan')}{session_hint}")
    click.echo(f"Task ID: {task_id}")
    click.echo(f"Output directory: {output_dir}")
    click.echo(f"Events recorded: {recorder.get_event_count()}")

    return exit_code