/ src / solace_agent_mesh / agent / sac / task_execution_context.py
task_execution_context.py
  1  """
  2  Encapsulates the runtime state for a single, in-flight agent task.
  3  """
  4  
  5  import asyncio
  6  import threading
  7  from typing import Any, Dict, List, Optional, TYPE_CHECKING
  8  
  9  if TYPE_CHECKING:
 10      from solace_ai_connector.common.message import Message as SolaceMessage
 11  
 12  
 13  class TaskExecutionContext:
 14      """
 15      A class to hold all runtime state and control mechanisms for a single agent task.
 16      This object is created when a task is initiated and destroyed when it completes,
 17      ensuring that all state is properly encapsulated and cleaned up.
 18      """
 19  
 20      def __init__(self, task_id: str, a2a_context: Dict[str, Any]):
 21          """
 22          Initializes the TaskExecutionContext.
 23  
 24          Args:
 25              task_id: The unique logical ID of the task.
 26              a2a_context: The original, rich context dictionary from the A2A request.
 27          """
 28          self.task_id: str = task_id
 29          self.a2a_context: Dict[str, Any] = a2a_context
 30          self.cancellation_event: asyncio.Event = asyncio.Event()
 31          self.streaming_buffer: str = ""
 32          self.run_based_response_buffer: str = ""
 33          self.active_peer_sub_tasks: Dict[str, Dict[str, Any]] = {}
 34          self.parallel_tool_calls: Dict[str, Dict[str, Any]] = {}
 35          self.produced_artifacts: List[Dict[str, Any]] = []
 36          self.artifact_signals_to_return: List[Dict[str, Any]] = []
 37          self.event_loop: Optional[asyncio.AbstractEventLoop] = None
 38          self.lock: threading.Lock = threading.Lock()
 39          self.is_paused: bool = False  # Track if task is paused waiting for peer/auth
 40  
 41          # Token usage tracking
 42          self.total_input_tokens: int = 0
 43          self.total_output_tokens: int = 0
 44          self.total_cached_input_tokens: int = 0
 45          # Prompt tokens from the most recent agent LLM call — represents the
 46          # peak context window occupancy for this task, unlike total_input_tokens
 47          # which sums across every turn and inflates with each peer delegation.
 48          self.last_input_tokens: int = 0
 49          self.token_usage_by_model: Dict[str, Dict[str, int]] = {}
 50          self.token_usage_by_source: Dict[str, Dict[str, int]] = {}
 51  
 52          # Generic flags storage for task-level state
 53          self._flags: Dict[str, Any] = {}
 54  
 55          # Generic security storage (enterprise use only)
 56          self._security_context: Dict[str, Any] = {}
 57  
 58          # Original Solace message for ACK/NACK operations
 59          # Stored here instead of a2a_context to avoid serialization issues
 60          self._original_solace_message: Optional["SolaceMessage"] = None
 61  
 62          # Turn tracking for proper spacing between LLM turns
 63          self._current_invocation_id: Optional[str] = None
 64          self._first_text_seen_in_turn: bool = False
 65          self._need_spacing_before_next_text: bool = False
 66  
 67      def cancel(self) -> None:
 68          """Signals that the task should be cancelled."""
 69          self.cancellation_event.set()
 70  
 71      def is_cancelled(self) -> bool:
 72          """Checks if the cancellation event has been set."""
 73          return self.cancellation_event.is_set()
 74  
 75      def set_paused(self, paused: bool) -> None:
 76          """
 77          Marks the task as paused (waiting for peer response or user input).
 78  
 79          Args:
 80              paused: True if task is paused, False if resuming.
 81          """
 82          with self.lock:
 83              self.is_paused = paused
 84  
 85      def get_is_paused(self) -> bool:
 86          """
 87          Checks if the task is currently paused.
 88  
 89          Returns:
 90              True if task is paused (waiting for peer/auth), False otherwise.
 91          """
 92          with self.lock:
 93              return self.is_paused
 94  
 95      def append_to_streaming_buffer(self, text: str) -> None:
 96          """Appends a chunk of text to the main streaming buffer."""
 97          with self.lock:
 98              self.streaming_buffer += text
 99  
100      def flush_streaming_buffer(self) -> str:
101          """Returns the entire content of the streaming buffer and clears it."""
102          with self.lock:
103              content = self.streaming_buffer
104              self.streaming_buffer = ""
105              return content
106  
107      def get_streaming_buffer_content(self) -> str:
108          """Returns the current buffer content without clearing it."""
109          with self.lock:
110              return self.streaming_buffer
111  
112      def append_to_run_based_buffer(self, text: str) -> None:
113          """Appends a chunk of processed text to the run-based response buffer."""
114          with self.lock:
115              self.run_based_response_buffer += text
116  
117      def register_peer_sub_task(
118          self, sub_task_id: str, correlation_data: Dict[str, Any]
119      ) -> None:
120          """Adds a new peer sub-task's correlation data to the tracking dictionary."""
121          with self.lock:
122              self.active_peer_sub_tasks[sub_task_id] = correlation_data
123  
124      def claim_sub_task_completion(self, sub_task_id: str) -> Optional[Dict[str, Any]]:
125          """
126          Atomically retrieves and removes a sub-task's correlation data.
127          This is the core atomic operation to prevent race conditions.
128          Returns the correlation data if the claim is successful, otherwise None.
129          """
130          with self.lock:
131              return self.active_peer_sub_tasks.pop(sub_task_id, None)
132  
133      def register_parallel_call_sent(self, invocation_id: str) -> None:
134          """
135          Registers that a new parallel tool call has been sent for a specific invocation.
136          Initializes the tracking dictionary for the invocation if it's the first call,
137          otherwise increments the total.
138          """
139          with self.lock:
140              if invocation_id not in self.parallel_tool_calls:
141                  self.parallel_tool_calls[invocation_id] = {
142                      "total": 1,
143                      "completed": 0,
144                      "results": [],
145                  }
146              else:
147                  self.parallel_tool_calls[invocation_id]["total"] += 1
148  
149      def handle_peer_timeout(
150          self,
151          sub_task_id: str,
152          correlation_data: Dict,
153          timeout_sec: int,
154          invocation_id: str,
155      ) -> bool:
156          """
157          Handles a timeout for a specific peer sub-task for a given invocation.
158  
159          Updates the parallel call tracker with a formatted error message and returns
160          True if all peer calls for that invocation are now complete.
161  
162          Args:
163              sub_task_id: The ID of the sub-task that timed out.
164              correlation_data: The correlation data associated with the sub-task.
165              timeout_sec: The timeout duration in seconds.
166              invocation_id: The ID of the invocation that initiated the parallel calls.
167  
168          Returns:
169              A boolean indicating if all parallel calls for the invocation are now complete.
170          """
171          peer_tool_name = correlation_data.get("peer_tool_name", "unknown_tool")
172          timeout_message = f"Request to peer agent tool '{peer_tool_name}' timed out after {timeout_sec} seconds."
173  
174          # The payload must be a dictionary with a 'result' key containing the simple string.
175          # This ensures the ADK framework presents it to the LLM as a simple text response.
176          simple_error_payload = {"result": timeout_message}
177  
178          current_result = {
179              "adk_function_call_id": correlation_data.get("adk_function_call_id"),
180              "peer_tool_name": peer_tool_name,
181              "payload": simple_error_payload,
182          }
183          return self.record_parallel_result(current_result, invocation_id)
184  
185      def record_parallel_result(self, result: Dict, invocation_id: str) -> bool:
186          """
187          Records a result for a parallel tool call for a specific invocation
188          and returns True if all calls for that invocation are now complete.
189          """
190          with self.lock:
191              invocation_state = self.parallel_tool_calls.get(invocation_id)
192              if not invocation_state:
193                  # This can happen if a response arrives after a timeout has cleaned up.
194                  return False
195  
196              invocation_state["results"].append(result)
197              invocation_state["completed"] += 1
198              return invocation_state["completed"] >= invocation_state["total"]
199  
200      def clear_parallel_invocation_state(self, invocation_id: str) -> None:
201          """
202          Removes the state for a completed parallel tool call invocation.
203          """
204          with self.lock:
205              if invocation_id in self.parallel_tool_calls:
206                  del self.parallel_tool_calls[invocation_id]
207  
208      def register_produced_artifact(self, filename: str, version: int) -> None:
209          """Adds a newly created artifact to the tracking list."""
210          with self.lock:
211              self.produced_artifacts.append({"filename": filename, "version": version})
212  
213      def add_artifact_signal(self, signal: Dict[str, Any]) -> None:
214          """Adds an artifact return signal to the list in a thread-safe manner."""
215          with self.lock:
216              self.artifact_signals_to_return.append(signal)
217  
218      def get_and_clear_artifact_signals(self) -> List[Dict[str, Any]]:
219          """
220          Retrieves all pending artifact signals and clears the list atomically.
221          """
222          with self.lock:
223              signals = list(self.artifact_signals_to_return)  # Create a copy
224              self.artifact_signals_to_return.clear()
225              return signals
226  
227      def set_event_loop(self, loop: asyncio.AbstractEventLoop) -> None:
228          """Stores a reference to the task's event loop."""
229          with self.lock:
230              self.event_loop = loop
231  
232      def get_event_loop(self) -> Optional[asyncio.AbstractEventLoop]:
233          """Retrieves the stored event loop."""
234          with self.lock:
235              return self.event_loop
236  
237      def record_token_usage(
238          self,
239          input_tokens: int,
240          output_tokens: int,
241          model: str,
242          source: str = "agent",
243          tool_name: Optional[str] = None,
244          cached_input_tokens: int = 0,
245          max_input_tokens: Optional[int] = None,
246      ) -> None:
247          """
248          Records token usage for an LLM call.
249  
250          Args:
251              input_tokens: Number of input/prompt tokens.
252              output_tokens: Number of output/completion tokens.
253              model: Model identifier used for this call.
254              source: Source of the LLM call ("agent" or "tool").
255              tool_name: Tool name if source is "tool".
256              cached_input_tokens: Number of cached input tokens (optional).
257              max_input_tokens: Context window size for this model (optional).
258                  Stamped once per model so downstream consumers (e.g. the chat
259                  context-usage indicator) can render a usage-vs-limit ratio
260                  without cross-service lookups.
261          """
262          with self.lock:
263              # Update totals
264              self.total_input_tokens += input_tokens
265              self.total_output_tokens += output_tokens
266              self.total_cached_input_tokens += cached_input_tokens
267              if source == "agent":
268                  self.last_input_tokens = input_tokens
269  
270              # Track by model
271              if model not in self.token_usage_by_model:
272                  self.token_usage_by_model[model] = {
273                      "input_tokens": 0,
274                      "output_tokens": 0,
275                      "cached_input_tokens": 0,
276                  }
277              self.token_usage_by_model[model]["input_tokens"] += input_tokens
278              self.token_usage_by_model[model]["output_tokens"] += output_tokens
279              self.token_usage_by_model[model]["cached_input_tokens"] += cached_input_tokens
280              # Stamp max_input_tokens once; prefer the first non-null value we see
281              # so a later sub-call without the limit doesn't clear it.
282              if max_input_tokens and not self.token_usage_by_model[model].get("max_input_tokens"):
283                  self.token_usage_by_model[model]["max_input_tokens"] = int(max_input_tokens)
284              
285              # Track by source
286              source_key = f"{source}:{tool_name}" if tool_name else source
287              if source_key not in self.token_usage_by_source:
288                  self.token_usage_by_source[source_key] = {
289                      "input_tokens": 0,
290                      "output_tokens": 0,
291                      "cached_input_tokens": 0,
292                  }
293              self.token_usage_by_source[source_key]["input_tokens"] += input_tokens
294              self.token_usage_by_source[source_key]["output_tokens"] += output_tokens
295              self.token_usage_by_source[source_key]["cached_input_tokens"] += cached_input_tokens
296  
297      def get_token_usage_summary(self) -> Dict[str, Any]:
298          """
299          Returns a summary of all token usage for this task.
300          
301          Returns:
302              Dictionary containing total token counts and breakdowns by model and source.
303          """
304          with self.lock:
305              return {
306                  "total_input_tokens": self.total_input_tokens,
307                  "total_output_tokens": self.total_output_tokens,
308                  "total_cached_input_tokens": self.total_cached_input_tokens,
309                  "total_tokens": self.total_input_tokens + self.total_output_tokens,
310                  "last_input_tokens": self.last_input_tokens,
311                  "by_model": dict(self.token_usage_by_model),
312                  "by_source": dict(self.token_usage_by_source),
313              }
314      
315      def set_security_data(self, key: str, value: Any) -> None:
316          """
317          Store opaque security data (enterprise use only).
318          
319          This method provides a secure storage mechanism for enterprise security features
320          such as authentication tokens. The stored data is isolated per task and
321          automatically cleaned up when the task completes.
322          
323          Args:
324              key: Storage key for the security data
325              value: Security data to store (opaque to open source code)
326          """
327          with self.lock:
328              self._security_context[key] = value
329      
330      def get_security_data(self, key: str, default: Any = None) -> Any:
331          """
332          Retrieve opaque security data (enterprise use only).
333          
334          This method retrieves security data that was previously stored using
335          set_security_data(). The data is opaque to open source code.
336          
337          Args:
338              key: Storage key for the security data
339              default: Default value to return if key not found
340              
341          Returns:
342              The stored security data, or default if not found
343          """
344          with self.lock:
345              return self._security_context.get(key, default)
346      
347      def clear_security_data(self) -> None:
348          """
349          Clear all security data.
350  
351          This method is provided for completeness but is not explicitly called.
352          Security data is automatically cleaned up when the TaskExecutionContext
353          is removed from active_tasks and garbage collected.
354          """
355          with self.lock:
356              self._security_context.clear()
357  
358      def set_original_solace_message(self, message: Optional["SolaceMessage"]) -> None:
359          """
360          Store the original Solace message for this task.
361  
362          This message is used for ACK/NACK operations when the task completes.
363          Stored separately from a2a_context to avoid serialization issues when
364          the context is persisted to the ADK session state.
365  
366          Args:
367              message: The Solace message that initiated this task, or None
368          """
369          with self.lock:
370              self._original_solace_message = message
371  
372      def get_original_solace_message(self) -> Optional["SolaceMessage"]:
373          """
374          Retrieve the original Solace message for this task.
375  
376          Returns:
377              The Solace message that initiated this task, or None if not available
378          """
379          with self.lock:
380              return self._original_solace_message
381  
382      def check_and_update_invocation(self, new_invocation_id: str) -> bool:
383          """
384          Check if this is a new turn (different invocation_id) and update tracking.
385  
386          Args:
387              new_invocation_id: The invocation_id from the current ADK event
388  
389          Returns:
390              True if this is a new turn (invocation_id changed), False otherwise
391          """
392          with self.lock:
393              is_new_turn = (
394                  self._current_invocation_id is not None
395                  and self._current_invocation_id != new_invocation_id
396              )
397  
398              if is_new_turn:
399                  # Mark that we need spacing before the next text
400                  self._need_spacing_before_next_text = True
401  
402              if is_new_turn or self._current_invocation_id is None:
403                  self._current_invocation_id = new_invocation_id
404                  self._first_text_seen_in_turn = False
405  
406              return is_new_turn
407  
408      def is_first_text_in_turn(self) -> bool:
409          """
410          Check if this is the first text we're seeing in the current turn,
411          and mark it as seen.
412  
413          Returns:
414              True if this is the first text in the turn, False otherwise
415          """
416          with self.lock:
417              if not self._first_text_seen_in_turn:
418                  self._first_text_seen_in_turn = True
419                  return True
420              return False
421  
422      def should_add_turn_spacing(self) -> bool:
423          """
424          Check if we need to add spacing before the next text (because it's a new turn).
425          This flag is set when a new invocation starts and cleared after spacing is added.
426  
427          Returns:
428              True if spacing should be added, False otherwise
429          """
430          with self.lock:
431              if self._need_spacing_before_next_text:
432                  self._need_spacing_before_next_text = False
433                  return True
434              return False
435  
436      def set_flag(self, key: str, value: Any) -> None:
437          """
438          Set a task-level flag.
439  
440          This method provides a generic mechanism for storing task-level state
441          that needs to persist across different parts of the task execution.
442  
443          Args:
444              key: The flag name
445              value: The flag value
446          """
447          with self.lock:
448              self._flags[key] = value
449  
450      def get_flag(self, key: str, default: Any = None) -> Any:
451          """
452          Get a task-level flag.
453  
454          Args:
455              key: The flag name
456              default: Default value to return if flag not found
457  
458          Returns:
459              The flag value, or default if not found
460          """
461          with self.lock:
462              return self._flags.get(key, default)