# task_context.py
"""
Manages context for tasks being processed by a gateway.
"""

import logging
import threading
from typing import Callable, Dict, List, Optional, Any, Tuple

log = logging.getLogger(__name__)


class TaskContextManager:
    """
    Stores and retrieves arbitrary context associated with an A2A task_id.

    This context is typically provided by a specific gateway implementation
    (e.g., Slack channel/thread, HTTP session details) and is needed to
    route responses back correctly to the external system.

    The manager is thread-safe with respect to its task_id -> context mapping:
    every read or write of that mapping happens under an internal lock. The
    lock does NOT protect the context dictionaries themselves --
    ``store_context`` keeps the caller's dict by reference and ``get_context``
    returns that same object, whereas ``scan_contexts`` returns shallow copies.
    """

    def __init__(self):
        """Initializes the TaskContextManager with an empty context mapping."""
        self._contexts: Dict[str, Dict[str, Any]] = {}
        # Non-reentrant lock: never invoke caller-supplied code while holding it.
        self._lock = threading.Lock()
        log.debug("[TaskContextManager] Initialized.")

    def store_context(self, task_id: str, context_data: Dict[str, Any]) -> None:
        """
        Stores context data for a given task ID.

        Any existing context for the same task_id is replaced. The dictionary
        is stored by reference, not copied.

        Args:
            task_id: The unique identifier for the task.
            context_data: A dictionary containing the context to store.
        """
        with self._lock:
            self._contexts[task_id] = context_data
        log.debug("[TaskContextManager] Stored context for task_id: %s", task_id)

    def get_context(self, task_id: str) -> Optional[Dict[str, Any]]:
        """
        Retrieves the context data for a given task ID.

        Args:
            task_id: The unique identifier for the task.

        Returns:
            The stored context dictionary itself (not a copy) if found,
            otherwise None. Mutations made by the caller are visible to later
            lookups.
        """
        with self._lock:
            context = self._contexts.get(task_id)
        log.debug(
            "[TaskContextManager] Retrieved context for task_id: %s (Found: %s)",
            task_id,
            context is not None,
        )
        return context

    def remove_context(self, task_id: str) -> Optional[Dict[str, Any]]:
        """
        Removes and returns the context data for a given task ID.

        Args:
            task_id: The unique identifier for the task.

        Returns:
            The removed context dictionary if one was stored, otherwise None.
        """
        with self._lock:
            context = self._contexts.pop(task_id, None)
        log.debug(
            "[TaskContextManager] Removed context for task_id: %s (Found: %s)",
            task_id,
            context is not None,
        )
        return context

    def scan_contexts(
        self, predicate: Callable[[str, Dict[str, Any]], bool]
    ) -> List[Tuple[str, Dict[str, Any]]]:
        """Returns (task_id, context) pairs where predicate(task_id, context) is True.

        The current (task_id, context) entries are snapshotted under the lock,
        and ``predicate`` is evaluated afterwards, outside the lock. As a
        result, the predicate may safely call back into this manager (e.g.
        ``get_context`` or ``remove_context``) without deadlocking on the
        non-reentrant lock. A slow predicate also does not block other threads.

        Args:
            predicate: A function that takes (task_id, context_data) and returns True
                for contexts that should be included in the result.

        Returns:
            A list of (task_id, context_data) tuples matching the predicate, in
            the mapping's iteration order at snapshot time. Each context_data is
            a shallow copy, so mutating it does not alter the stored context.
        """
        with self._lock:
            # Copy the item references only; user code runs after release.
            snapshot = list(self._contexts.items())
        return [
            (tid, ctx.copy())
            for tid, ctx in snapshot
            if predicate(tid, ctx)
        ]

    def clear_all_contexts_for_testing(self) -> None:
        """Removes all stored contexts. For testing purposes."""
        with self._lock:
            self._contexts.clear()
        log.debug("[TaskContextManager] All contexts cleared for testing.")