server.py
  1  import asyncio
  2  import threading
  3  import time
  4  from typing import Any, Dict, List, Optional
  5  
  6  import uvicorn
  7  from a2a.server.apps import A2AFastAPIApplication
  8  from a2a.server.agent_execution import AgentExecutor
  9  from a2a.server.request_handlers import DefaultRequestHandler
 10  from a2a.server.tasks import InMemoryTaskStore
 11  from a2a.types import AgentCard
 12  from fastapi import FastAPI, Request
 13  from starlette.responses import JSONResponse
 14  from solace_ai_connector.common.log import log
 15  
 16  
 17  class TestA2AAgentServer:
 18      """
 19      Manages a runnable, in-process A2A agent for integration testing.
 20  
 21      This server uses a DeclarativeAgentExecutor to respond to requests based on
 22      directives provided in the test case, allowing for predictable and
 23      controllable behavior of a downstream A2A agent.
 24      """
 25  
 26      def __init__(
 27          self, host: str, port: int, agent_card: AgentCard, agent_executor: AgentExecutor
 28      ):
 29          # 2.2.2: __init__ accepts host, port, and AgentCard
 30          self.host = host
 31          self.port = port
 32          self.agent_card = agent_card
 33          self.agent_executor = agent_executor
 34  
 35          # 2.2.3: Initialize instance variables
 36          self._uvicorn_server: Optional[uvicorn.Server] = None
 37          self._server_thread: Optional[threading.Thread] = None
 38          self.captured_requests: List[Dict[str, Any]] = []
 39          self._stateful_responses_cache: Dict[str, List[Any]] = {}
 40          self._stateful_cache_lock = threading.Lock()
 41          self._primed_responses: List[Dict[str, Any]] = []
 42          self._primed_responses_lock = threading.Lock()
 43  
 44          # Auth testing state
 45          self._auth_validation_enabled = False
 46          self._expected_auth_type: Optional[str] = None  # "bearer", "apikey", None
 47          self._expected_auth_value: Optional[str] = None
 48          self._auth_should_fail_once = False  # For testing retry logic
 49          self._auth_failure_count = 0
 50          self._captured_auth_headers: List[Dict[str, str]] = []
 51  
 52          # HTTP error simulation state
 53          self._http_error_config: Optional[Dict[str, Any]] = None
 54  
 55          # 2.3: A2A Application Setup
 56          # 2.3.2: Instantiate InMemoryTaskStore
 57          task_store = InMemoryTaskStore()
 58  
 59          # 2.3.3: Instantiate DefaultRequestHandler
 60          handler = DefaultRequestHandler(
 61              agent_executor=self.agent_executor, task_store=task_store
 62          )
 63  
 64          # 2.3.4: Instantiate A2AFastAPIApplication
 65          a2a_app_builder = A2AFastAPIApplication(
 66              agent_card=self.agent_card, http_handler=handler
 67          )
 68  
 69          # 2.3.5: Build the FastAPI app
 70          self.app: FastAPI = a2a_app_builder.build(rpc_url="/a2a")
 71  
 72          # 2.3.6: Update the agent card with the correct, full URL
 73          self.agent_card.url = f"{self.url}/a2a"
 74  
 75          # 2.3.7: Add request capture middleware
 76          @self.app.middleware("http")
 77          async def capture_request_middleware(request: Request, call_next):
 78              if request.url.path == "/a2a":
 79                  try:
 80                      body = await request.json()
 81                      self.captured_requests.append(body)
 82                      log.debug(
 83                          "[TestA2AAgentServer] Captured request: %s",
 84                          body.get("method"),
 85                      )
 86                  except Exception as e:
 87                      log.error(
 88                          "[TestA2AAgentServer] Failed to capture request body: %s", e
 89                      )
 90              response = await call_next(request)
 91              return response
 92  
 93          # 2.3.7b: Add HTTP error simulation middleware (runs before other middleware)
 94          @self.app.middleware("http")
 95          async def http_error_simulation_middleware(request: Request, call_next):
 96              # Only simulate errors for A2A endpoint
 97              if request.url.path == "/a2a" and self._http_error_config:
 98                  config = self._http_error_config
 99                  self._http_error_config = None  # One-time use
100                  log.info(
101                      "[TestA2AAgentServer] Simulating HTTP error: status=%d",
102                      config["status_code"],
103                  )
104                  return JSONResponse(
105                      status_code=config["status_code"],
106                      content=config.get(
107                          "error_body", {"error": f"HTTP {config['status_code']}"}
108                      ),
109                  )
110              return await call_next(request)
111  
112          # 2.3.8: Add auth validation middleware
113          @self.app.middleware("http")
114          async def auth_validation_middleware(request: Request, call_next):
115              # Skip validation for non-A2A endpoints
116              if request.url.path != "/a2a":
117                  return await call_next(request)
118  
119              # Capture auth headers for test assertions
120              auth_header = request.headers.get("Authorization", "")
121              apikey_header = request.headers.get("X-API-Key", "")
122  
123              self._captured_auth_headers.append(
124                  {
125                      "authorization": auth_header,
126                      "x_api_key": apikey_header,
127                      "path": request.url.path,
128                      "timestamp": time.time(),
129                  }
130              )
131  
132              # If auth validation is disabled, just pass through
133              if not self._auth_validation_enabled:
134                  return await call_next(request)
135  
136              # Test retry logic: fail once, then succeed
137              if self._auth_should_fail_once and self._auth_failure_count == 0:
138                  self._auth_failure_count += 1
139                  log.info(
140                      "[TestA2AAgentServer] Simulating 401 for retry test (first attempt)"
141                  )
142                  return JSONResponse(
143                      status_code=401,
144                      content={
145                          "error": "unauthorized",
146                          "message": "Invalid or expired token",
147                      },
148                  )
149  
150              # Validate bearer token
151              if self._expected_auth_type == "bearer":
152                  if not auth_header.startswith("Bearer "):
153                      log.warning(
154                          "[TestA2AAgentServer] Missing or malformed Bearer token"
155                      )
156                      return JSONResponse(
157                          status_code=401,
158                          content={
159                              "error": "unauthorized",
160                              "message": "Bearer token required",
161                          },
162                      )
163  
164                  token = auth_header.replace("Bearer ", "")
165                  if self._expected_auth_value and token != self._expected_auth_value:
166                      log.warning(
167                          "[TestA2AAgentServer] Invalid token. Expected '%s', got '%s'",
168                          self._expected_auth_value,
169                          token,
170                      )
171                      return JSONResponse(
172                          status_code=401,
173                          content={"error": "unauthorized", "message": "Invalid token"},
174                      )
175  
176              # Validate API key
177              elif self._expected_auth_type == "apikey":
178                  if not apikey_header:
179                      log.warning("[TestA2AAgentServer] Missing API key")
180                      return JSONResponse(
181                          status_code=401,
182                          content={
183                              "error": "unauthorized",
184                              "message": "API key required",
185                          },
186                      )
187  
188                  if (
189                      self._expected_auth_value
190                      and apikey_header != self._expected_auth_value
191                  ):
192                      log.warning("[TestA2AAgentServer] Invalid API key")
193                      return JSONResponse(
194                          status_code=401,
195                          content={"error": "unauthorized", "message": "Invalid API key"},
196                      )
197  
198              # Auth validation passed
199              return await call_next(request)
200  
201      @property
202      def url(self) -> str:
203          """Returns the base URL of the running server."""
204          return f"http://{self.host}:{self.port}"
205  
206      @property
207      def started(self) -> bool:
208          """Checks if the uvicorn server instance is started."""
209          return self._uvicorn_server is not None and self._uvicorn_server.started
210  
211      def start(self):
212          """Starts the FastAPI server in a separate thread."""
213          if self._server_thread is not None and self._server_thread.is_alive():
214              log.warning("[TestA2AAgentServer] Server is already running.")
215              return
216  
217          self.clear_captured_requests()
218          self.clear_stateful_cache()
219          self.clear_primed_responses()
220  
221          config = uvicorn.Config(
222              self.app, host=self.host, port=self.port, log_level="warning"
223          )
224          self._uvicorn_server = uvicorn.Server(config)
225  
226          async def async_serve_wrapper():
227              try:
228                  if self._uvicorn_server:
229                      await self._uvicorn_server.serve()
230              except asyncio.CancelledError:
231                  log.info("[TestA2AAgentServer] Server.serve() task was cancelled.")
232              except Exception as e:
233                  log.error(
234                      f"[TestA2AAgentServer] Error during server.serve(): {e}",
235                      exc_info=True,
236                  )
237  
238          def run_server_in_new_loop():
239              loop = asyncio.new_event_loop()
240              asyncio.set_event_loop(loop)
241              try:
242                  loop.run_until_complete(async_serve_wrapper())
243              finally:
244                  try:
245                      all_tasks = asyncio.all_tasks(loop)
246                      if all_tasks:
247                          for task in all_tasks:
248                              task.cancel()
249                          loop.run_until_complete(
250                              asyncio.gather(*all_tasks, return_exceptions=True)
251                          )
252                      if hasattr(loop, "shutdown_asyncgens"):
253                          loop.run_until_complete(loop.shutdown_asyncgens())
254                  except Exception as e:
255                      log.error(
256                          f"[TestA2AAgentServer] Error during loop shutdown: {e}",
257                          exc_info=True,
258                      )
259                  finally:
260                      loop.close()
261                      log.info("[TestA2AAgentServer] Event loop in server thread closed.")
262  
263          self._server_thread = threading.Thread(
264              target=run_server_in_new_loop, daemon=True
265          )
266          self._server_thread.start()
267          log.info(f"[TestA2AAgentServer] Starting on http://{self.host}:{self.port}...")
268  
269      def stop(self):
270          """Stops the FastAPI server."""
271          if self._uvicorn_server:
272              self._uvicorn_server.should_exit = True
273  
274          if self._server_thread and self._server_thread.is_alive():
275              log.info("[TestA2AAgentServer] Stopping, joining thread...")
276              self._server_thread.join(timeout=5.0)
277              if self._server_thread.is_alive():
278                  log.warning("[TestA2AAgentServer] Server thread did not exit cleanly.")
279          self._server_thread = None
280          self._uvicorn_server = None
281          self.clear_primed_responses()
282          self.clear_auth_state()
283          log.info("[TestA2AAgentServer] Stopped.")
284  
285      def clear_captured_requests(self):
286          """Clears the list of captured requests."""
287          self.captured_requests.clear()
288  
289      def prime_responses(self, responses: List[Dict[str, Any]]):
290          """
291          Primes the server with a sequence of responses to serve for subsequent requests.
292          Each call to this method overwrites any previously primed responses.
293          """
294          with self._primed_responses_lock:
295              self._primed_responses = list(responses)
296              log.info(
297                  "[TestA2AAgentServer] Primed with %d responses.",
298                  len(self._primed_responses),
299              )
300  
301      def get_next_primed_response(self) -> Optional[Dict[str, Any]]:
302          """
303          Retrieves the next available primed response in a thread-safe manner.
304          This is intended to be called by the agent executor.
305          """
306          with self._primed_responses_lock:
307              if self._primed_responses:
308                  response = self._primed_responses.pop(0)
309                  log.debug(
310                      "[TestA2AAgentServer] Consumed primed response. %d remaining.",
311                      len(self._primed_responses),
312                  )
313                  return response
314          return None
315  
316      def clear_primed_responses(self):
317          """Clears the primed response queue."""
318          with self._primed_responses_lock:
319              self._primed_responses.clear()
320              log.debug("[TestA2AAgentServer] Cleared primed responses.")
321  
322      def configure_auth_validation(
323          self,
324          enabled: bool = True,
325          auth_type: Optional[str] = None,
326          expected_value: Optional[str] = None,
327          should_fail_once: bool = False,
328      ):
329          """
330          Configures authentication validation for testing.
331  
332          Args:
333              enabled: Whether to validate auth headers
334              auth_type: "bearer" or "apikey"
335              expected_value: The expected token/key value
336              should_fail_once: If True, first request returns 401, subsequent succeed
337          """
338          self._auth_validation_enabled = enabled
339          self._expected_auth_type = auth_type
340          self._expected_auth_value = expected_value
341          self._auth_should_fail_once = should_fail_once
342          self._auth_failure_count = 0
343          log.info(
344              "[TestA2AAgentServer] Auth validation configured: "
345              "enabled=%s, type=%s, fail_once=%s",
346              enabled,
347              auth_type,
348              should_fail_once,
349          )
350  
351      def get_captured_auth_headers(self) -> List[Dict[str, str]]:
352          """Returns all captured authentication headers for test assertions."""
353          return self._captured_auth_headers.copy()
354  
355      def clear_auth_state(self):
356          """Clears all auth-related test state."""
357          self._auth_validation_enabled = False
358          self._expected_auth_type = None
359          self._expected_auth_value = None
360          self._auth_should_fail_once = False
361          self._auth_failure_count = 0
362          self._captured_auth_headers.clear()
363          log.debug("[TestA2AAgentServer] Auth state cleared")
364  
365      def clear_stateful_cache(self):
366          """Clears the stateful response cache."""
367          with self._stateful_cache_lock:
368              self._stateful_responses_cache.clear()
369  
370      def configure_http_error_response(
371          self, status_code: int, error_body: Optional[Dict[str, Any]] = None
372      ):
373          """
374          Configures the server to return an HTTP error for the next request.
375  
376          This is a one-time configuration - after returning the error once,
377          the server returns to normal operation.
378  
379          Args:
380              status_code: HTTP status code to return (e.g., 500, 503)
381              error_body: Optional JSON body to return with the error
382          """
383          self._http_error_config = {
384              "status_code": status_code,
385              "error_body": error_body or {"error": f"HTTP {status_code}"},
386          }
387          log.info(
388              "[TestA2AAgentServer] Configured to return HTTP %d on next request",
389              status_code,
390          )
391  
392      def clear_captured_auth_headers(self):
393          """Clears the captured authentication headers list."""
394          self._captured_auth_headers.clear()
395          log.debug("[TestA2AAgentServer] Cleared captured auth headers.")
396  
397      def get_cancel_requests(self) -> List[Dict[str, Any]]:
398          """Returns all captured cancel requests."""
399          return [
400              req for req in self.captured_requests if req.get("method") == "tasks/cancel"
401          ]
402  
403      def was_cancel_requested_for_task(self, task_id: str) -> bool:
404          """Checks if a cancel request was received for a specific task ID."""
405          cancel_requests = self.get_cancel_requests()
406          for req in cancel_requests:
407              params = req.get("params", {})
408              if params.get("id") == task_id:
409                  return True
410          return False