task_execution_context.py
"""
Encapsulates the runtime state for a single, in-flight agent task.
"""

import asyncio
import threading
from typing import Any, Dict, List, Optional, TYPE_CHECKING

if TYPE_CHECKING:
    from solace_ai_connector.common.message import Message as SolaceMessage


class TaskExecutionContext:
    """
    A class to hold all runtime state and control mechanisms for a single agent task.
    This object is created when a task is initiated and destroyed when it completes,
    ensuring that all state is properly encapsulated and cleaned up.

    All mutable state is guarded by a single ``threading.Lock`` because the task
    may be touched from both the asyncio event loop and worker threads.
    """

    def __init__(self, task_id: str, a2a_context: Dict[str, Any]):
        """
        Initializes the TaskExecutionContext.

        Args:
            task_id: The unique logical ID of the task.
            a2a_context: The original, rich context dictionary from the A2A request.
        """
        self.task_id: str = task_id
        self.a2a_context: Dict[str, Any] = a2a_context
        # Cooperative cancellation signal checked by the task's coroutines.
        self.cancellation_event: asyncio.Event = asyncio.Event()
        # Raw streamed LLM text awaiting flush to the client.
        self.streaming_buffer: str = ""
        # Processed text accumulated across a full run (vs. per-chunk streaming).
        self.run_based_response_buffer: str = ""
        # sub_task_id -> correlation data for in-flight peer delegations.
        self.active_peer_sub_tasks: Dict[str, Dict[str, Any]] = {}
        # invocation_id -> {"total", "completed", "results"} for parallel tool calls.
        self.parallel_tool_calls: Dict[str, Dict[str, Any]] = {}
        # Artifacts created during this task: [{"filename", "version"}, ...].
        self.produced_artifacts: List[Dict[str, Any]] = []
        # Pending artifact-return signals, drained atomically at response time.
        self.artifact_signals_to_return: List[Dict[str, Any]] = []
        self.event_loop: Optional[asyncio.AbstractEventLoop] = None
        # Single lock guarding every mutable attribute on this context.
        self.lock: threading.Lock = threading.Lock()
        self.is_paused: bool = False  # Track if task is paused waiting for peer/auth

        # Token usage tracking
        self.total_input_tokens: int = 0
        self.total_output_tokens: int = 0
        self.total_cached_input_tokens: int = 0
        # Prompt tokens from the most recent agent LLM call — represents the
        # peak context window occupancy for this task, unlike total_input_tokens
        # which sums across every turn and inflates with each peer delegation.
        self.last_input_tokens: int = 0
        self.token_usage_by_model: Dict[str, Dict[str, int]] = {}
        self.token_usage_by_source: Dict[str, Dict[str, int]] = {}

        # Generic flags storage for task-level state
        self._flags: Dict[str, Any] = {}

        # Generic security storage (enterprise use only)
        self._security_context: Dict[str, Any] = {}

        # Original Solace message for ACK/NACK operations
        # Stored here instead of a2a_context to avoid serialization issues
        self._original_solace_message: Optional["SolaceMessage"] = None

        # Turn tracking for proper spacing between LLM turns
        self._current_invocation_id: Optional[str] = None
        self._first_text_seen_in_turn: bool = False
        self._need_spacing_before_next_text: bool = False

    def cancel(self) -> None:
        """Signals that the task should be cancelled."""
        self.cancellation_event.set()

    def is_cancelled(self) -> bool:
        """Checks if the cancellation event has been set."""
        return self.cancellation_event.is_set()

    def set_paused(self, paused: bool) -> None:
        """
        Marks the task as paused (waiting for peer response or user input).

        Args:
            paused: True if task is paused, False if resuming.
        """
        with self.lock:
            self.is_paused = paused

    def get_is_paused(self) -> bool:
        """
        Checks if the task is currently paused.

        Returns:
            True if task is paused (waiting for peer/auth), False otherwise.
        """
        with self.lock:
            return self.is_paused

    def append_to_streaming_buffer(self, text: str) -> None:
        """Appends a chunk of text to the main streaming buffer."""
        with self.lock:
            self.streaming_buffer += text

    def flush_streaming_buffer(self) -> str:
        """Returns the entire content of the streaming buffer and clears it."""
        with self.lock:
            content = self.streaming_buffer
            self.streaming_buffer = ""
            return content

    def get_streaming_buffer_content(self) -> str:
        """Returns the current buffer content without clearing it."""
        with self.lock:
            return self.streaming_buffer

    def append_to_run_based_buffer(self, text: str) -> None:
        """Appends a chunk of processed text to the run-based response buffer."""
        with self.lock:
            self.run_based_response_buffer += text

    def register_peer_sub_task(
        self, sub_task_id: str, correlation_data: Dict[str, Any]
    ) -> None:
        """Adds a new peer sub-task's correlation data to the tracking dictionary."""
        with self.lock:
            self.active_peer_sub_tasks[sub_task_id] = correlation_data

    def claim_sub_task_completion(self, sub_task_id: str) -> Optional[Dict[str, Any]]:
        """
        Atomically retrieves and removes a sub-task's correlation data.
        This is the core atomic operation to prevent race conditions.
        Returns the correlation data if the claim is successful, otherwise None.
        """
        with self.lock:
            return self.active_peer_sub_tasks.pop(sub_task_id, None)

    def register_parallel_call_sent(self, invocation_id: str) -> None:
        """
        Registers that a new parallel tool call has been sent for a specific invocation.
        Initializes the tracking dictionary for the invocation if it's the first call,
        otherwise increments the total.
        """
        with self.lock:
            if invocation_id not in self.parallel_tool_calls:
                self.parallel_tool_calls[invocation_id] = {
                    "total": 1,
                    "completed": 0,
                    "results": [],
                }
            else:
                self.parallel_tool_calls[invocation_id]["total"] += 1

    def handle_peer_timeout(
        self,
        sub_task_id: str,
        correlation_data: Dict,
        timeout_sec: int,
        invocation_id: str,
    ) -> bool:
        """
        Handles a timeout for a specific peer sub-task for a given invocation.

        Updates the parallel call tracker with a formatted error message and returns
        True if all peer calls for that invocation are now complete.

        Args:
            sub_task_id: The ID of the sub-task that timed out.
            correlation_data: The correlation data associated with the sub-task.
            timeout_sec: The timeout duration in seconds.
            invocation_id: The ID of the invocation that initiated the parallel calls.

        Returns:
            A boolean indicating if all parallel calls for the invocation are now complete.
        """
        peer_tool_name = correlation_data.get("peer_tool_name", "unknown_tool")
        timeout_message = f"Request to peer agent tool '{peer_tool_name}' timed out after {timeout_sec} seconds."

        # The payload must be a dictionary with a 'result' key containing the simple string.
        # This ensures the ADK framework presents it to the LLM as a simple text response.
        simple_error_payload = {"result": timeout_message}

        current_result = {
            "adk_function_call_id": correlation_data.get("adk_function_call_id"),
            "peer_tool_name": peer_tool_name,
            "payload": simple_error_payload,
        }
        return self.record_parallel_result(current_result, invocation_id)

    def record_parallel_result(self, result: Dict, invocation_id: str) -> bool:
        """
        Records a result for a parallel tool call for a specific invocation
        and returns True if all calls for that invocation are now complete.
        """
        with self.lock:
            invocation_state = self.parallel_tool_calls.get(invocation_id)
            if not invocation_state:
                # This can happen if a response arrives after a timeout has cleaned up.
                return False

            invocation_state["results"].append(result)
            invocation_state["completed"] += 1
            return invocation_state["completed"] >= invocation_state["total"]

    def clear_parallel_invocation_state(self, invocation_id: str) -> None:
        """
        Removes the state for a completed parallel tool call invocation.
        """
        with self.lock:
            if invocation_id in self.parallel_tool_calls:
                del self.parallel_tool_calls[invocation_id]

    def register_produced_artifact(self, filename: str, version: int) -> None:
        """Adds a newly created artifact to the tracking list."""
        with self.lock:
            self.produced_artifacts.append({"filename": filename, "version": version})

    def add_artifact_signal(self, signal: Dict[str, Any]) -> None:
        """Adds an artifact return signal to the list in a thread-safe manner."""
        with self.lock:
            self.artifact_signals_to_return.append(signal)

    def get_and_clear_artifact_signals(self) -> List[Dict[str, Any]]:
        """
        Retrieves all pending artifact signals and clears the list atomically.
        """
        with self.lock:
            signals = list(self.artifact_signals_to_return)  # Create a copy
            self.artifact_signals_to_return.clear()
            return signals

    def set_event_loop(self, loop: asyncio.AbstractEventLoop) -> None:
        """Stores a reference to the task's event loop."""
        with self.lock:
            self.event_loop = loop

    def get_event_loop(self) -> Optional[asyncio.AbstractEventLoop]:
        """Retrieves the stored event loop."""
        with self.lock:
            return self.event_loop

    def record_token_usage(
        self,
        input_tokens: int,
        output_tokens: int,
        model: str,
        source: str = "agent",
        tool_name: Optional[str] = None,
        cached_input_tokens: int = 0,
        max_input_tokens: Optional[int] = None,
    ) -> None:
        """
        Records token usage for an LLM call.

        Args:
            input_tokens: Number of input/prompt tokens.
            output_tokens: Number of output/completion tokens.
            model: Model identifier used for this call.
            source: Source of the LLM call ("agent" or "tool").
            tool_name: Tool name if source is "tool".
            cached_input_tokens: Number of cached input tokens (optional).
            max_input_tokens: Context window size for this model (optional).
                Stamped once per model so downstream consumers (e.g. the chat
                context-usage indicator) can render a usage-vs-limit ratio
                without cross-service lookups.
        """
        with self.lock:
            # Update totals
            self.total_input_tokens += input_tokens
            self.total_output_tokens += output_tokens
            self.total_cached_input_tokens += cached_input_tokens
            # Only agent-level calls reflect the full conversation prompt;
            # tool sub-calls would understate peak context occupancy.
            if source == "agent":
                self.last_input_tokens = input_tokens

            # Track by model
            if model not in self.token_usage_by_model:
                self.token_usage_by_model[model] = {
                    "input_tokens": 0,
                    "output_tokens": 0,
                    "cached_input_tokens": 0,
                }
            self.token_usage_by_model[model]["input_tokens"] += input_tokens
            self.token_usage_by_model[model]["output_tokens"] += output_tokens
            self.token_usage_by_model[model]["cached_input_tokens"] += cached_input_tokens
            # Stamp max_input_tokens once; prefer the first non-null value we see
            # so a later sub-call without the limit doesn't clear it.
            if max_input_tokens and not self.token_usage_by_model[model].get("max_input_tokens"):
                self.token_usage_by_model[model]["max_input_tokens"] = int(max_input_tokens)

            # Track by source
            source_key = f"{source}:{tool_name}" if tool_name else source
            if source_key not in self.token_usage_by_source:
                self.token_usage_by_source[source_key] = {
                    "input_tokens": 0,
                    "output_tokens": 0,
                    "cached_input_tokens": 0,
                }
            self.token_usage_by_source[source_key]["input_tokens"] += input_tokens
            self.token_usage_by_source[source_key]["output_tokens"] += output_tokens
            self.token_usage_by_source[source_key]["cached_input_tokens"] += cached_input_tokens

    def get_token_usage_summary(self) -> Dict[str, Any]:
        """
        Returns a summary of all token usage for this task.

        The breakdown dictionaries are copied one level deep so the caller
        receives a stable snapshot; previously only the outer dict was copied,
        leaving the inner per-model/per-source dicts aliased to live state
        that could mutate (or be mutated) after the lock was released.

        Returns:
            Dictionary containing total token counts and breakdowns by model and source.
        """
        with self.lock:
            return {
                "total_input_tokens": self.total_input_tokens,
                "total_output_tokens": self.total_output_tokens,
                "total_cached_input_tokens": self.total_cached_input_tokens,
                "total_tokens": self.total_input_tokens + self.total_output_tokens,
                "last_input_tokens": self.last_input_tokens,
                "by_model": {m: dict(v) for m, v in self.token_usage_by_model.items()},
                "by_source": {s: dict(v) for s, v in self.token_usage_by_source.items()},
            }

    def set_security_data(self, key: str, value: Any) -> None:
        """
        Store opaque security data (enterprise use only).

        This method provides a secure storage mechanism for enterprise security features
        such as authentication tokens. The stored data is isolated per task and
        automatically cleaned up when the task completes.

        Args:
            key: Storage key for the security data
            value: Security data to store (opaque to open source code)
        """
        with self.lock:
            self._security_context[key] = value

    def get_security_data(self, key: str, default: Any = None) -> Any:
        """
        Retrieve opaque security data (enterprise use only).

        This method retrieves security data that was previously stored using
        set_security_data(). The data is opaque to open source code.

        Args:
            key: Storage key for the security data
            default: Default value to return if key not found

        Returns:
            The stored security data, or default if not found
        """
        with self.lock:
            return self._security_context.get(key, default)

    def clear_security_data(self) -> None:
        """
        Clear all security data.

        This method is provided for completeness but is not explicitly called.
        Security data is automatically cleaned up when the TaskExecutionContext
        is removed from active_tasks and garbage collected.
        """
        with self.lock:
            self._security_context.clear()

    def set_original_solace_message(self, message: Optional["SolaceMessage"]) -> None:
        """
        Store the original Solace message for this task.

        This message is used for ACK/NACK operations when the task completes.
        Stored separately from a2a_context to avoid serialization issues when
        the context is persisted to the ADK session state.

        Args:
            message: The Solace message that initiated this task, or None
        """
        with self.lock:
            self._original_solace_message = message

    def get_original_solace_message(self) -> Optional["SolaceMessage"]:
        """
        Retrieve the original Solace message for this task.

        Returns:
            The Solace message that initiated this task, or None if not available
        """
        with self.lock:
            return self._original_solace_message

    def check_and_update_invocation(self, new_invocation_id: str) -> bool:
        """
        Check if this is a new turn (different invocation_id) and update tracking.

        Args:
            new_invocation_id: The invocation_id from the current ADK event

        Returns:
            True if this is a new turn (invocation_id changed), False otherwise
        """
        with self.lock:
            # The very first invocation is NOT a "new turn" — there was no
            # previous turn to add spacing after.
            is_new_turn = (
                self._current_invocation_id is not None
                and self._current_invocation_id != new_invocation_id
            )

            if is_new_turn:
                # Mark that we need spacing before the next text
                self._need_spacing_before_next_text = True

            if is_new_turn or self._current_invocation_id is None:
                self._current_invocation_id = new_invocation_id
                self._first_text_seen_in_turn = False

            return is_new_turn

    def is_first_text_in_turn(self) -> bool:
        """
        Check if this is the first text we're seeing in the current turn,
        and mark it as seen.

        Returns:
            True if this is the first text in the turn, False otherwise
        """
        with self.lock:
            if not self._first_text_seen_in_turn:
                self._first_text_seen_in_turn = True
                return True
            return False

    def should_add_turn_spacing(self) -> bool:
        """
        Check if we need to add spacing before the next text (because it's a new turn).
        This flag is set when a new invocation starts and cleared after spacing is added.

        Returns:
            True if spacing should be added, False otherwise
        """
        with self.lock:
            if self._need_spacing_before_next_text:
                self._need_spacing_before_next_text = False
                return True
            return False

    def set_flag(self, key: str, value: Any) -> None:
        """
        Set a task-level flag.

        This method provides a generic mechanism for storing task-level state
        that needs to persist across different parts of the task execution.

        Args:
            key: The flag name
            value: The flag value
        """
        with self.lock:
            self._flags[key] = value

    def get_flag(self, key: str, default: Any = None) -> Any:
        """
        Get a task-level flag.

        Args:
            key: The flag name
            default: Default value to return if flag not found

        Returns:
            The flag value, or default if not found
        """
        with self.lock:
            return self._flags.get(key, default)