/ src / solace_agent_mesh / agent / adk / tool_wrapper.py
tool_wrapper.py
  1  """
  2  Defines the ADKToolWrapper, a consolidated wrapper for ADK tools.
  3  """
  4  
  5  import logging
  6  import asyncio
  7  import functools
  8  import inspect
  9  from typing import Any, Callable, Dict, List, Optional, Literal, Set
 10  
 11  from ...common.utils.embeds import (
 12      resolve_embeds_in_string,
 13      evaluate_embed,
 14      EARLY_EMBED_TYPES,
 15      LATE_EMBED_TYPES,
 16      EMBED_DELIMITER_OPEN,
 17  )
 18  from ...common.utils.embeds.types import ResolutionMode
 19  from ..tools.artifact_types import Artifact, is_artifact_type, get_artifact_info, ArtifactTypeInfo
 20  from ..tools.artifact_preloading import (
 21      is_tool_context_facade_param as _is_tool_context_facade_param,
 22      resolve_artifact_params,
 23  )
 24  from ..utils.tool_context_facade import ToolContextFacade
 25  
 26  # Observability instrumentation
 27  from solace_ai_connector.common.observability import MonitorLatency
 28  from ...common.observability import ToolMonitor
 29  
# Module-level logger, named after this module's import path.
log = logging.getLogger(__name__)
 31  
 32  
 33  class ADKToolWrapper:
 34      """
 35      A consolidated wrapper for ADK tools that handles:
 36      1. Preserving original function metadata (__doc__, __signature__) for ADK.
 37      2. Resolving early-stage embeds in string arguments before execution.
 38      3. Injecting tool-specific configuration.
 39      4. Providing a resilient try/except block to catch all execution errors.
 40      """
 41  
 42      def __init__(
 43          self,
 44          original_func: Callable,
 45          tool_config: Optional[Dict],
 46          tool_name: str,
 47          origin: str,
 48          raw_string_args: Optional[List[str]] = None,
 49          resolution_type: Literal["early", "all"] = "all",
 50          artifact_args: Optional[List[str]] = None,
 51      ):
 52          self._original_func = original_func
 53          self._tool_config = tool_config or {}
 54          self._tool_name = tool_name
 55          self._resolution_type = resolution_type
 56          self.origin = origin
 57          self._raw_string_args = set(raw_string_args) if raw_string_args else set()
 58          self._is_async = inspect.iscoroutinefunction(original_func)
 59  
 60          self._types_to_resolve = EARLY_EMBED_TYPES
 61  
 62          if self._resolution_type == "all":
 63              self._types_to_resolve = EARLY_EMBED_TYPES.union(LATE_EMBED_TYPES)
 64  
 65          # Ensure __name__ attribute is always set before functools.update_wrapper
 66          self.__name__ = tool_name
 67  
 68          try:
 69              functools.update_wrapper(self, original_func)
 70          except AttributeError as e:
 71              log.debug(
 72                  "Could not fully update wrapper for tool '%s': %s. Using fallback attributes.",
 73                  self._tool_name,
 74                  e,
 75              )
 76              # Ensure essential attributes are set even if update_wrapper fails
 77              self.__name__ = tool_name
 78              self.__doc__ = getattr(original_func, "__doc__", None)
 79  
 80          try:
 81              self.__code__ = original_func.__code__
 82              self.__globals__ = original_func.__globals__
 83              self.__defaults__ = getattr(original_func, "__defaults__", None)
 84              self.__kwdefaults__ = getattr(original_func, "__kwdefaults__", None)
 85              self.__closure__ = getattr(original_func, "__closure__", None)
 86          except AttributeError:
 87              log.debug(
 88                  "Could not delegate all dunder attributes for tool '%s'. This is normal for some built-in or C-based functions.",
 89                  self._tool_name,
 90              )
 91  
 92          try:
 93              self.__signature__ = inspect.signature(original_func)
 94              self._accepts_tool_config = "tool_config" in self.__signature__.parameters
 95          except (ValueError, TypeError):
 96              self.__signature__ = None
 97              self._accepts_tool_config = False
 98              log.warning("Could not determine signature for tool '%s'.", self._tool_name)
 99  
100          # Initialize artifact params from explicit config
101          # Maps param name to ArtifactTypeInfo
102          self._artifact_params: Dict[str, ArtifactTypeInfo] = {}
103          if artifact_args:
104              for name in artifact_args:
105                  self._artifact_params[name] = ArtifactTypeInfo(is_artifact=True)
106  
107          # Track if the function expects a ToolContextFacade
108          self._ctx_facade_param_name: Optional[str] = None
109  
110          # Auto-detect Artifact and ToolContextFacade type annotations
111          self._detect_special_params()
112  
113          # Sanitize __signature__ so ADK doesn't see types it can't parse.
114          # Replace Artifact annotations with str, remove ToolContextFacade params.
115          self._sanitize_signature()
116  
117      def _sanitize_signature(self) -> None:
118          """Replace Artifact type annotations with str and remove framework-injected
119          params (ToolContextFacade) from the exposed signature so that ADK's
120          automatic function declaration parser doesn't choke on unknown types."""
121          if self.__signature__ is None:
122              return
123  
124          new_params = []
125          for param_name, param in self.__signature__.parameters.items():
126              # Remove ToolContextFacade params — they are injected by the framework
127              if param_name == self._ctx_facade_param_name:
128                  continue
129  
130              # Replace Artifact annotations with str
131              if param_name in self._artifact_params:
132                  new_param = param.replace(annotation=str)
133                  new_params.append(new_param)
134              else:
135                  new_params.append(param)
136  
137          self.__signature__ = self.__signature__.replace(parameters=new_params)
138  
139      @property
140      def _artifact_args(self) -> Set[str]:
141          """Property returning set of artifact param names."""
142          return set(self._artifact_params.keys())
143  
144      def _detect_special_params(self) -> None:
145          """
146          Detect special parameter types:
147          - Artifact / List[Artifact]: Will have artifact pre-loaded
148          - ToolContextFacade: Will have facade injected automatically
149          """
150          if self.__signature__ is None:
151              return
152  
153          for param_name, param in self.__signature__.parameters.items():
154              if param_name in ("tool_context", "tool_config", "kwargs", "self", "cls"):
155                  continue
156  
157              # Check for Artifact (including List[Artifact])
158              artifact_type_info = get_artifact_info(param.annotation)
159              if artifact_type_info.is_artifact:
160                  self._artifact_params[param_name] = artifact_type_info
161                  if artifact_type_info.is_list:
162                      log.debug(
163                          "[ADKToolWrapper:%s] Detected List[Artifact] param: %s",
164                          self._tool_name,
165                          param_name,
166                      )
167                  else:
168                      log.debug(
169                          "[ADKToolWrapper:%s] Detected Artifact param: %s",
170                          self._tool_name,
171                          param_name,
172                      )
173  
174              # Check for ToolContextFacade
175              if _is_tool_context_facade_param(param.annotation):
176                  self._ctx_facade_param_name = param_name
177                  log.debug(
178                      "[ADKToolWrapper:%s] Detected ToolContextFacade param: %s",
179                      self._tool_name,
180                      param_name,
181                  )
182  
183          if self._artifact_params:
184              log.info(
185                  "[ADKToolWrapper:%s] Will pre-load artifacts for params: %s",
186                  self._tool_name,
187                  list(self._artifact_params.keys()),
188              )
189  
190          if self._ctx_facade_param_name:
191              log.info(
192                  "[ADKToolWrapper:%s] Will inject ToolContextFacade as '%s'",
193                  self._tool_name,
194                  self._ctx_facade_param_name,
195              )
196  
197      async def __call__(self, *args, **kwargs):
198          # Allow overriding the context for embed resolution, e.g., when called from a callback
199          _override_embed_context = kwargs.pop("_override_embed_context", None)
200          log_identifier = f"[ADKToolWrapper:{self._tool_name}]"
201  
202          context_for_embeds = _override_embed_context or kwargs.get("tool_context")
203          resolved_args = []
204          resolved_kwargs = kwargs.copy()
205  
206          if context_for_embeds:
207              # Resolve positional args
208              for arg in args:
209                  if isinstance(arg, str) and EMBED_DELIMITER_OPEN in arg:
210                      resolved_arg, _, _ = await resolve_embeds_in_string(
211                          text=arg,
212                          context=context_for_embeds,
213                          resolver_func=evaluate_embed,
214                          types_to_resolve=self._types_to_resolve,
215                          resolution_mode=ResolutionMode.TOOL_PARAMETER,
216                          log_identifier=log_identifier,
217                          config=self._tool_config,
218                      )
219                      resolved_args.append(resolved_arg)
220                  else:
221                      resolved_args.append(arg)
222  
223              for key, value in kwargs.items():
224                  if key in self._raw_string_args and isinstance(value, str):
225                      log.debug(
226                          "%s Skipping embed resolution for raw string kwarg '%s'",
227                          log_identifier,
228                          key,
229                      )
230                  elif isinstance(value, str) and EMBED_DELIMITER_OPEN in value:
231                      log.debug("%s Resolving embeds for kwarg '%s'", log_identifier, key)
232                      resolved_value, _, _ = await resolve_embeds_in_string(
233                          text=value,
234                          context=context_for_embeds,
235                          resolver_func=evaluate_embed,
236                          types_to_resolve=self._types_to_resolve,
237                          resolution_mode=ResolutionMode.TOOL_PARAMETER,
238                          log_identifier=log_identifier,
239                          config=self._tool_config,
240                      )
241                      resolved_kwargs[key] = resolved_value
242          else:
243              log.warning(
244                  "%s ToolContext not found. Skipping embed resolution for all args.",
245                  log_identifier,
246              )
247              resolved_args = list(args)
248  
249          if self._accepts_tool_config:
250              resolved_kwargs["tool_config"] = self._tool_config
251          elif self._tool_config:
252              log.warning(
253                  "%s Tool was provided a 'tool_config' but its function signature does not accept it. The config will be ignored.",
254                  log_identifier,
255              )
256  
257          # Inject ToolContextFacade if the function expects it
258          if self._ctx_facade_param_name and context_for_embeds:
259              facade = ToolContextFacade(
260                  tool_context=context_for_embeds,
261                  tool_config=self._tool_config,
262              )
263              resolved_kwargs[self._ctx_facade_param_name] = facade
264              log.debug(
265                  "%s Injected ToolContextFacade as '%s'",
266                  log_identifier,
267                  self._ctx_facade_param_name,
268              )
269  
270          # Pre-load artifacts for Artifact parameters
271          if self._artifact_params and context_for_embeds:
272              error = await resolve_artifact_params(
273                  artifact_params=self._artifact_params,
274                  resolved_kwargs=resolved_kwargs,
275                  tool_context=context_for_embeds,
276                  tool_name=self._tool_name,
277                  log_identifier=log_identifier,
278              )
279              if error is not None:
280                  return error
281  
282          try:
283              # Instrument tool execution latency
284              with MonitorLatency(ToolMonitor.create(name=self._tool_name)):
285                  if self._is_async:
286                      return await self._original_func(*resolved_args, **resolved_kwargs)
287                  else:
288                      loop = asyncio.get_running_loop()
289                      return await loop.run_in_executor(
290                          None,
291                          functools.partial(
292                              self._original_func, *resolved_args, **resolved_kwargs
293                          ),
294                      )
295          except Exception as e:
296              log.exception("%s Tool execution failed: %s", log_identifier, e)
297              return {
298                  "status": "error",
299                  "message": f"Tool '{self._tool_name}' failed with an unexpected error: {str(e)}",
300                  "tool_name": self._tool_name,
301              }