# server.py
import asyncio
import threading
import time
from typing import Any, Dict, List, Optional

import uvicorn
from a2a.server.apps import A2AFastAPIApplication
from a2a.server.agent_execution import AgentExecutor
from a2a.server.request_handlers import DefaultRequestHandler
from a2a.server.tasks import InMemoryTaskStore
from a2a.types import AgentCard
from fastapi import FastAPI, Request
from starlette.responses import JSONResponse
from solace_ai_connector.common.log import log


class TestA2AAgentServer:
    """
    Manages a runnable, in-process A2A agent for integration testing.

    The supplied ``AgentExecutor`` (intended to be a DeclarativeAgentExecutor
    that responds according to directives provided in the test case) is wrapped
    in an A2A FastAPI application served by uvicorn on a background thread.
    This gives tests predictable, controllable behavior from a downstream A2A
    agent, plus hooks to inspect and steer it:

    * JSON bodies POSTed to ``/a2a`` are recorded in ``captured_requests``;
    * a queue of primed responses can be loaded for the executor to consume
      via ``get_next_primed_response()``;
    * auth headers are captured and can optionally be validated (bearer token
      or API key), including a "fail the first request" mode for retry tests;
    * a one-shot HTTP error can be injected for the next ``/a2a`` request.
    """

    def __init__(
        self, host: str, port: int, agent_card: AgentCard, agent_executor: AgentExecutor
    ):
        """
        Builds the FastAPI app and registers the test middleware.

        The server is not started here; call ``start()``.

        Args:
            host: Host/interface uvicorn binds to.
            port: Port uvicorn binds to.
            agent_card: Card served by the A2A app. Its ``url`` attribute is
                overwritten with ``<base url>/a2a`` (the caller's object is
                mutated).
            agent_executor: Executor that handles incoming A2A requests.
        """
        # 2.2.2: __init__ accepts host, port, AgentCard and the executor
        self.host = host
        self.port = port
        self.agent_card = agent_card
        self.agent_executor = agent_executor

        # 2.2.3: Initialize instance variables
        self._uvicorn_server: Optional[uvicorn.Server] = None
        self._server_thread: Optional[threading.Thread] = None
        # Parsed JSON bodies of /a2a requests; normally JSON-RPC dicts, but any
        # JSON value the client sent is stored as-is.
        self.captured_requests: List[Any] = []
        self._stateful_responses_cache: Dict[str, List[Any]] = {}
        self._stateful_cache_lock = threading.Lock()
        self._primed_responses: List[Dict[str, Any]] = []
        self._primed_responses_lock = threading.Lock()

        # Auth testing state
        self._auth_validation_enabled = False
        self._expected_auth_type: Optional[str] = None  # "bearer", "apikey", None
        self._expected_auth_value: Optional[str] = None
        self._auth_should_fail_once = False  # For testing retry logic
        self._auth_failure_count = 0
        # Each entry: authorization, x_api_key, path (str) and timestamp (float).
        self._captured_auth_headers: List[Dict[str, Any]] = []

        # HTTP error simulation state (one-shot; cleared once served)
        self._http_error_config: Optional[Dict[str, Any]] = None

        # 2.3: A2A Application Setup
        # 2.3.2: Instantiate InMemoryTaskStore
        task_store = InMemoryTaskStore()

        # 2.3.3: Instantiate DefaultRequestHandler
        handler = DefaultRequestHandler(
            agent_executor=self.agent_executor, task_store=task_store
        )

        # 2.3.4: Instantiate A2AFastAPIApplication
        a2a_app_builder = A2AFastAPIApplication(
            agent_card=self.agent_card, http_handler=handler
        )

        # 2.3.5: Build the FastAPI app
        self.app: FastAPI = a2a_app_builder.build(rpc_url="/a2a")

        # 2.3.6: Update the agent card with the correct, full URL
        self.agent_card.url = f"{self.url}/a2a"

        # Middleware ordering: each @app.middleware registration wraps the
        # stack built so far, so the LAST one registered runs FIRST. The
        # per-request order is therefore:
        #   auth_validation -> http_error_simulation -> capture_request -> app
        # Requests rejected by auth or by a simulated HTTP error never reach
        # capture_request and are NOT added to captured_requests (their auth
        # headers are still recorded by auth_validation).

        # 2.3.7: Add request capture middleware (innermost)
        @self.app.middleware("http")
        async def capture_request_middleware(request: Request, call_next):
            if request.url.path == "/a2a":
                try:
                    body = await request.json()
                except Exception as e:
                    log.error(
                        "[TestA2AAgentServer] Failed to capture request body: %s", e
                    )
                else:
                    self.captured_requests.append(body)
                    # Only a JSON object carries a top-level "method"; batch
                    # arrays or other JSON values are captured but have none.
                    method = body.get("method") if isinstance(body, dict) else None
                    log.debug("[TestA2AAgentServer] Captured request: %s", method)
            return await call_next(request)

        # 2.3.7b: Add HTTP error simulation middleware (runs after auth
        # validation but before request capture; see ordering note above)
        @self.app.middleware("http")
        async def http_error_simulation_middleware(request: Request, call_next):
            # Only simulate errors for A2A endpoint
            if request.url.path == "/a2a" and self._http_error_config:
                config = self._http_error_config
                self._http_error_config = None  # One-time use
                log.info(
                    "[TestA2AAgentServer] Simulating HTTP error: status=%d",
                    config["status_code"],
                )
                return JSONResponse(
                    status_code=config["status_code"],
                    content=config.get(
                        "error_body", {"error": f"HTTP {config['status_code']}"}
                    ),
                )
            return await call_next(request)

        # 2.3.8: Add auth validation middleware (outermost; runs first)
        @self.app.middleware("http")
        async def auth_validation_middleware(request: Request, call_next):
            # Skip validation for non-A2A endpoints
            if request.url.path != "/a2a":
                return await call_next(request)

            # Capture auth headers for test assertions
            auth_header = request.headers.get("Authorization", "")
            apikey_header = request.headers.get("X-API-Key", "")

            self._captured_auth_headers.append(
                {
                    "authorization": auth_header,
                    "x_api_key": apikey_header,
                    "path": request.url.path,
                    "timestamp": time.time(),
                }
            )

            # If auth validation is disabled, just pass through
            if not self._auth_validation_enabled:
                return await call_next(request)

            # Test retry logic: fail once, then succeed
            if self._auth_should_fail_once and self._auth_failure_count == 0:
                self._auth_failure_count += 1
                log.info(
                    "[TestA2AAgentServer] Simulating 401 for retry test (first attempt)"
                )
                return JSONResponse(
                    status_code=401,
                    content={
                        "error": "unauthorized",
                        "message": "Invalid or expired token",
                    },
                )

            # Validate bearer token
            if self._expected_auth_type == "bearer":
                bearer_prefix = "Bearer "
                if not auth_header.startswith(bearer_prefix):
                    log.warning(
                        "[TestA2AAgentServer] Missing or malformed Bearer token"
                    )
                    return JSONResponse(
                        status_code=401,
                        content={
                            "error": "unauthorized",
                            "message": "Bearer token required",
                        },
                    )

                # Strip only the leading scheme; str.replace would also remove
                # any "Bearer " occurring inside the token itself.
                token = auth_header[len(bearer_prefix):]
                if self._expected_auth_value and token != self._expected_auth_value:
                    log.warning(
                        "[TestA2AAgentServer] Invalid token. Expected '%s', got '%s'",
                        self._expected_auth_value,
                        token,
                    )
                    return JSONResponse(
                        status_code=401,
                        content={"error": "unauthorized", "message": "Invalid token"},
                    )

            # Validate API key
            elif self._expected_auth_type == "apikey":
                if not apikey_header:
                    log.warning("[TestA2AAgentServer] Missing API key")
                    return JSONResponse(
                        status_code=401,
                        content={
                            "error": "unauthorized",
                            "message": "API key required",
                        },
                    )

                if (
                    self._expected_auth_value
                    and apikey_header != self._expected_auth_value
                ):
                    log.warning("[TestA2AAgentServer] Invalid API key")
                    return JSONResponse(
                        status_code=401,
                        content={"error": "unauthorized", "message": "Invalid API key"},
                    )

            # Auth validation passed
            return await call_next(request)

    @property
    def url(self) -> str:
        """Returns the base URL of the server (``http://host:port``)."""
        return f"http://{self.host}:{self.port}"

    @property
    def started(self) -> bool:
        """True once a uvicorn server exists and reports itself started."""
        return self._uvicorn_server is not None and self._uvicorn_server.started

    def start(self):
        """
        Starts the FastAPI server on a daemon thread with its own event loop.

        Clears captured requests, the stateful cache and primed responses
        first. Returns immediately without waiting for startup; poll
        ``started`` to know when requests can be served. Does nothing (besides
        logging a warning) if the server thread is already alive.
        """
        if self._server_thread is not None and self._server_thread.is_alive():
            log.warning("[TestA2AAgentServer] Server is already running.")
            return

        self.clear_captured_requests()
        self.clear_stateful_cache()
        self.clear_primed_responses()

        config = uvicorn.Config(
            self.app, host=self.host, port=self.port, log_level="warning"
        )
        self._uvicorn_server = uvicorn.Server(config)

        async def async_serve_wrapper():
            # Runs uvicorn until it exits; errors are logged, not propagated,
            # so the thread always reaches its loop-cleanup code.
            try:
                if self._uvicorn_server:
                    await self._uvicorn_server.serve()
            except asyncio.CancelledError:
                log.info("[TestA2AAgentServer] Server.serve() task was cancelled.")
            except Exception as e:
                log.error(
                    "[TestA2AAgentServer] Error during server.serve(): %s",
                    e,
                    exc_info=True,
                )

        def run_server_in_new_loop():
            # The server thread owns a private event loop; once serve()
            # returns, cancel leftover tasks and finalize async generators
            # before closing the loop.
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            try:
                loop.run_until_complete(async_serve_wrapper())
            finally:
                try:
                    all_tasks = asyncio.all_tasks(loop)
                    if all_tasks:
                        for task in all_tasks:
                            task.cancel()
                        loop.run_until_complete(
                            asyncio.gather(*all_tasks, return_exceptions=True)
                        )
                    if hasattr(loop, "shutdown_asyncgens"):
                        loop.run_until_complete(loop.shutdown_asyncgens())
                except Exception as e:
                    log.error(
                        "[TestA2AAgentServer] Error during loop shutdown: %s",
                        e,
                        exc_info=True,
                    )
                finally:
                    loop.close()
                    log.info("[TestA2AAgentServer] Event loop in server thread closed.")

        self._server_thread = threading.Thread(
            target=run_server_in_new_loop, daemon=True
        )
        self._server_thread.start()
        log.info("[TestA2AAgentServer] Starting on %s...", self.url)

    def stop(self):
        """
        Stops the FastAPI server and resets per-test state.

        Sets uvicorn's ``should_exit`` flag and joins the server thread for up
        to 5 seconds (logging a warning if it is still alive). Afterwards the
        primed responses, auth state (including captured auth headers) and
        any pending one-shot HTTP error configuration are cleared.
        """
        if self._uvicorn_server:
            self._uvicorn_server.should_exit = True

        if self._server_thread and self._server_thread.is_alive():
            log.info("[TestA2AAgentServer] Stopping, joining thread...")
            self._server_thread.join(timeout=5.0)
            if self._server_thread.is_alive():
                log.warning("[TestA2AAgentServer] Server thread did not exit cleanly.")
        self._server_thread = None
        self._uvicorn_server = None
        self.clear_primed_responses()
        self.clear_auth_state()
        # An unconsumed one-shot error must not leak into the next test run.
        self.clear_http_error_config()
        log.info("[TestA2AAgentServer] Stopped.")

    def clear_captured_requests(self):
        """Clears the list of captured requests."""
        self.captured_requests.clear()

    def prime_responses(self, responses: List[Dict[str, Any]]):
        """
        Primes the server with a sequence of responses to serve for subsequent requests.
        Each call to this method overwrites any previously primed responses.
        """
        with self._primed_responses_lock:
            self._primed_responses = list(responses)
            log.info(
                "[TestA2AAgentServer] Primed with %d responses.",
                len(self._primed_responses),
            )

    def get_next_primed_response(self) -> Optional[Dict[str, Any]]:
        """
        Pops and returns the oldest primed response, or None if none remain.

        Guarded by a lock so the executor (running on the server thread) and
        the test thread can use the queue concurrently.
        """
        with self._primed_responses_lock:
            if self._primed_responses:
                response = self._primed_responses.pop(0)
                log.debug(
                    "[TestA2AAgentServer] Consumed primed response. %d remaining.",
                    len(self._primed_responses),
                )
                return response
            return None

    def clear_primed_responses(self):
        """Clears the primed response queue."""
        with self._primed_responses_lock:
            self._primed_responses.clear()
            log.debug("[TestA2AAgentServer] Cleared primed responses.")

    def configure_auth_validation(
        self,
        enabled: bool = True,
        auth_type: Optional[str] = None,
        expected_value: Optional[str] = None,
        should_fail_once: bool = False,
    ):
        """
        Configures authentication validation for testing.

        Args:
            enabled: Whether to validate auth headers
            auth_type: "bearer" or "apikey"; any other value skips header checks
            expected_value: The expected token/key value; if falsy, only the
                presence/format of the header is checked
            should_fail_once: If True, first request returns 401, subsequent succeed
        """
        self._auth_validation_enabled = enabled
        self._expected_auth_type = auth_type
        self._expected_auth_value = expected_value
        self._auth_should_fail_once = should_fail_once
        self._auth_failure_count = 0
        log.info(
            "[TestA2AAgentServer] Auth validation configured: "
            "enabled=%s, type=%s, fail_once=%s",
            enabled,
            auth_type,
            should_fail_once,
        )

    def get_captured_auth_headers(self) -> List[Dict[str, Any]]:
        """Returns a shallow copy of all captured auth header records."""
        return self._captured_auth_headers.copy()

    def clear_auth_state(self):
        """Clears all auth-related test state, including captured headers."""
        self._auth_validation_enabled = False
        self._expected_auth_type = None
        self._expected_auth_value = None
        self._auth_should_fail_once = False
        self._auth_failure_count = 0
        self._captured_auth_headers.clear()
        log.debug("[TestA2AAgentServer] Auth state cleared")

    def clear_stateful_cache(self):
        """Clears the stateful response cache."""
        with self._stateful_cache_lock:
            self._stateful_responses_cache.clear()

    def configure_http_error_response(
        self, status_code: int, error_body: Optional[Dict[str, Any]] = None
    ):
        """
        Configures the server to return an HTTP error for the next request.

        This is a one-time configuration - after returning the error once,
        the server returns to normal operation.

        Args:
            status_code: HTTP status code to return (e.g., 500, 503)
            error_body: Optional JSON body to return with the error; defaults
                to ``{"error": "HTTP <status_code>"}`` when falsy
        """
        self._http_error_config = {
            "status_code": status_code,
            "error_body": error_body or {"error": f"HTTP {status_code}"},
        }
        log.info(
            "[TestA2AAgentServer] Configured to return HTTP %d on next request",
            status_code,
        )

    def clear_http_error_config(self):
        """Discards any pending one-shot HTTP error configuration."""
        self._http_error_config = None

    def clear_captured_auth_headers(self):
        """Clears the captured authentication headers list."""
        self._captured_auth_headers.clear()
        log.debug("[TestA2AAgentServer] Cleared captured auth headers.")

    def get_cancel_requests(self) -> List[Dict[str, Any]]:
        """Returns all captured JSON-RPC requests whose method is ``tasks/cancel``."""
        return [
            req
            for req in self.captured_requests
            # Non-object bodies (e.g. batch arrays) cannot be cancel requests.
            if isinstance(req, dict) and req.get("method") == "tasks/cancel"
        ]

    def was_cancel_requested_for_task(self, task_id: str) -> bool:
        """Checks if a cancel request was received whose ``params.id`` equals task_id."""
        for req in self.get_cancel_requests():
            params = req.get("params")
            # Tolerate a missing or explicitly null/non-object "params".
            if isinstance(params, dict) and params.get("id") == task_id:
                return True
        return False