# src/solace_agent_mesh/gateway/base/task_context.py
"""
Manages per-task context (e.g. where to route replies) for tasks being
processed by a gateway.
"""
 4  
 5  import logging
 6  import threading
 7  from typing import Callable, Dict, List, Optional, Any, Tuple
 8  
 9  log = logging.getLogger(__name__)
10  
11  
class TaskContextManager:
    """
    Stores and retrieves arbitrary context associated with an A2A task_id.

    This context is typically provided by a specific gateway implementation
    (e.g., Slack channel/thread, HTTP session details) and is needed to
    route responses back correctly to the external system.

    Every read and write of the internal task_id -> context mapping is
    serialized by a single lock, so one manager can be shared between
    threads. The context dictionaries themselves are stored and handed back
    by reference (``scan_contexts`` is the exception: it returns shallow
    copies), so mutating a context obtained from ``get_context`` is NOT
    protected by the manager's lock.
    """

    def __init__(self):
        """Initializes the TaskContextManager with an empty context store."""
        # task_id -> context dictionary supplied by the gateway.
        self._contexts: Dict[str, Dict[str, Any]] = {}
        # Guards all access to ``self._contexts``. threading.Lock is not
        # reentrant, so no caller-supplied code may run while it is held.
        self._lock = threading.Lock()
        log.debug("[TaskContextManager] Initialized.")

    def store_context(self, task_id: str, context_data: Dict[str, Any]) -> None:
        """
        Stores context data for a given task ID.

        Any context previously stored under the same ``task_id`` is replaced.
        The dictionary is stored as-is (by reference), not copied.

        Args:
            task_id: The unique identifier for the task.
            context_data: A dictionary containing the context to store.
        """
        with self._lock:
            self._contexts[task_id] = context_data
        log.debug("[TaskContextManager] Stored context for task_id: %s", task_id)

    def get_context(self, task_id: str) -> Optional[Dict[str, Any]]:
        """
        Retrieves the context data for a given task ID.

        Args:
            task_id: The unique identifier for the task.

        Returns:
            The stored context dictionary itself (not a copy) if found,
            otherwise None.
        """
        with self._lock:
            context = self._contexts.get(task_id)
        log.debug(
            "[TaskContextManager] Retrieved context for task_id: %s (Found: %s)",
            task_id,
            context is not None,
        )
        return context

    def remove_context(self, task_id: str) -> Optional[Dict[str, Any]]:
        """
        Removes and returns the context data for a given task ID.

        Args:
            task_id: The unique identifier for the task.

        Returns:
            The context dictionary that was stored for ``task_id``, or None
            if nothing was stored (removing an unknown task is not an error).
        """
        with self._lock:
            context = self._contexts.pop(task_id, None)
        log.debug(
            "[TaskContextManager] Removed context for task_id: %s (Found: %s)",
            task_id,
            context is not None,
        )
        return context

    def scan_contexts(
        self, predicate: Callable[[str, Dict[str, Any]], bool]
    ) -> List[Tuple[str, Dict[str, Any]]]:
        """Returns (task_id, context) pairs where predicate(task_id, context) is True.

        The scan works on a point-in-time snapshot of the stored entries, and
        the predicate is evaluated WITHOUT holding the manager's lock. The
        predicate may therefore safely call other methods of this manager.
        Contexts stored or removed concurrently may or may not be reflected
        in the result.

        Args:
            predicate: A function that takes (task_id, context_data) and returns True
                      for contexts that should be included in the result.

        Returns:
            A list of (task_id, context_data) tuples matching the predicate,
            where each context_data is a shallow copy of the stored dict.
        """
        # Copy the (task_id, context) pairs while locked, then release the
        # lock before running caller code: the lock is non-reentrant, so a
        # predicate calling back into the manager would otherwise deadlock,
        # and a slow predicate would stall every other thread.
        with self._lock:
            snapshot = list(self._contexts.items())
        return [
            (tid, ctx.copy())
            for tid, ctx in snapshot
            if predicate(tid, ctx)
        ]

    def clear_all_contexts_for_testing(self) -> None:
        """Removes all stored contexts. For testing purposes."""
        with self._lock:
            self._contexts.clear()
        log.debug("[TaskContextManager] All contexts cleared for testing.")