# tool_wrapper.py
"""
Defines the ADKToolWrapper, a consolidated wrapper for ADK tools.
"""

import logging
import asyncio
import contextvars
import functools
import inspect
from typing import Any, Callable, Dict, List, Optional, Literal, Set, Tuple

from ...common.utils.embeds import (
    resolve_embeds_in_string,
    evaluate_embed,
    EARLY_EMBED_TYPES,
    LATE_EMBED_TYPES,
    EMBED_DELIMITER_OPEN,
)
from ...common.utils.embeds.types import ResolutionMode
from ..tools.artifact_types import Artifact, is_artifact_type, get_artifact_info, ArtifactTypeInfo
from ..tools.artifact_preloading import (
    is_tool_context_facade_param as _is_tool_context_facade_param,
    resolve_artifact_params,
)
from ..utils.tool_context_facade import ToolContextFacade

# Observability instrumentation
from solace_ai_connector.common.observability import MonitorLatency
from ...common.observability import ToolMonitor

log = logging.getLogger(__name__)


class ADKToolWrapper:
    """
    A consolidated wrapper for ADK tools that handles:
    1. Preserving original function metadata (__doc__, __signature__) for ADK.
    2. Resolving embeds in string arguments before execution (early-stage
       embeds only, or early and late embeds, depending on ``resolution_type``).
    3. Injecting tool-specific configuration.
    4. Providing a resilient try/except block to catch all execution errors.
    5. Pre-loading artifacts for Artifact-typed parameters and injecting a
       ToolContextFacade for parameters annotated with it.

    Unexpected exceptions raised while preparing arguments (embed resolution,
    facade injection, artifact pre-loading) or while running the tool are
    logged and returned as a ``{"status": "error", ...}`` dict instead of
    propagating to the caller.
    """

    def __init__(
        self,
        original_func: Callable,
        tool_config: Optional[Dict],
        tool_name: str,
        origin: str,
        raw_string_args: Optional[List[str]] = None,
        resolution_type: Literal["early", "all"] = "all",
        artifact_args: Optional[List[str]] = None,
    ):
        """
        Args:
            original_func: The tool callable to wrap (sync or async).
            tool_config: Optional configuration dict. It is passed to embed
                resolution and to the ToolContextFacade, and is injected as the
                ``tool_config`` kwarg when the function's signature declares it.
            tool_name: Name used for logging, latency monitoring and error
                responses, and as the fallback ``__name__``.
            origin: Stored as ``self.origin``; not used inside this class.
            raw_string_args: Names of arguments whose string values are passed
                through without embed resolution.
            resolution_type: ``"early"`` resolves only EARLY_EMBED_TYPES;
                ``"all"`` also resolves LATE_EMBED_TYPES.
            artifact_args: Names of parameters to treat as artifacts even
                without an Artifact type annotation.
        """
        # Ensure __name__ attribute is always set before functools.update_wrapper
        self.__name__ = tool_name

        # Copy metadata FIRST. update_wrapper also merges original_func.__dict__
        # into this instance, so doing it before the wrapper's own state is
        # assigned prevents attributes on the wrapped callable (for example the
        # private fields of another ADKToolWrapper) from overwriting that state.
        try:
            functools.update_wrapper(self, original_func)
        except AttributeError as e:
            log.debug(
                "Could not fully update wrapper for tool '%s': %s. Using fallback attributes.",
                tool_name,
                e,
            )
            # Ensure essential attributes are set even if update_wrapper fails
            self.__name__ = tool_name
            self.__doc__ = getattr(original_func, "__doc__", None)

        self._original_func = original_func
        self._tool_config = tool_config or {}
        self._tool_name = tool_name
        self._resolution_type = resolution_type
        self.origin = origin
        self._raw_string_args = set(raw_string_args) if raw_string_args else set()
        self._is_async = inspect.iscoroutinefunction(original_func)

        self._types_to_resolve = EARLY_EMBED_TYPES
        if self._resolution_type == "all":
            self._types_to_resolve = EARLY_EMBED_TYPES.union(LATE_EMBED_TYPES)

        try:
            self.__code__ = original_func.__code__
            self.__globals__ = original_func.__globals__
            self.__defaults__ = getattr(original_func, "__defaults__", None)
            self.__kwdefaults__ = getattr(original_func, "__kwdefaults__", None)
            self.__closure__ = getattr(original_func, "__closure__", None)
        except AttributeError:
            log.debug(
                "Could not delegate all dunder attributes for tool '%s'. This is normal for some built-in or C-based functions.",
                self._tool_name,
            )

        try:
            self.__signature__ = inspect.signature(original_func)
            self._accepts_tool_config = "tool_config" in self.__signature__.parameters
        except (ValueError, TypeError):
            self.__signature__ = None
            self._accepts_tool_config = False
            log.warning("Could not determine signature for tool '%s'.", self._tool_name)

        # Positional parameter names of the ORIGINAL callable, captured before the
        # exposed signature is sanitized (which may drop the facade parameter).
        # Used to map positional call arguments back to names so that
        # raw_string_args is honoured for positional arguments as well.
        self._positional_param_names: List[str] = []
        if self.__signature__ is not None:
            positional_kinds = (
                inspect.Parameter.POSITIONAL_ONLY,
                inspect.Parameter.POSITIONAL_OR_KEYWORD,
            )
            self._positional_param_names = [
                param.name
                for param in self.__signature__.parameters.values()
                if param.kind in positional_kinds
            ]

        # Initialize artifact params from explicit config
        # Maps param name to ArtifactTypeInfo
        self._artifact_params: Dict[str, ArtifactTypeInfo] = {}
        if artifact_args:
            for name in artifact_args:
                self._artifact_params[name] = ArtifactTypeInfo(is_artifact=True)

        # Track if the function expects a ToolContextFacade
        self._ctx_facade_param_name: Optional[str] = None

        # Auto-detect Artifact and ToolContextFacade type annotations
        self._detect_special_params()

        # Sanitize __signature__ so ADK doesn't see types it can't parse.
        # Replace Artifact annotations with str, remove ToolContextFacade params.
        self._sanitize_signature()

    def _sanitize_signature(self) -> None:
        """Replace Artifact type annotations with str and remove framework-injected
        params (ToolContextFacade) from the exposed signature so that ADK's
        automatic function declaration parser doesn't choke on unknown types."""
        if self.__signature__ is None:
            return

        new_params = []
        for param_name, param in self.__signature__.parameters.items():
            # Remove ToolContextFacade params — they are injected by the framework
            if param_name == self._ctx_facade_param_name:
                continue

            # Replace Artifact annotations with str
            if param_name in self._artifact_params:
                new_param = param.replace(annotation=str)
                new_params.append(new_param)
            else:
                new_params.append(param)

        self.__signature__ = self.__signature__.replace(parameters=new_params)

    @property
    def _artifact_args(self) -> Set[str]:
        """Property returning set of artifact param names."""
        return set(self._artifact_params.keys())

    def _detect_special_params(self) -> None:
        """
        Detect special parameter types:
        - Artifact / List[Artifact]: Will have artifact pre-loaded
        - ToolContextFacade: Will have facade injected automatically
        """
        if self.__signature__ is None:
            return

        for param_name, param in self.__signature__.parameters.items():
            if param_name in ("tool_context", "tool_config", "kwargs", "self", "cls"):
                continue

            # Check for Artifact (including List[Artifact])
            artifact_type_info = get_artifact_info(param.annotation)
            if artifact_type_info.is_artifact:
                self._artifact_params[param_name] = artifact_type_info
                if artifact_type_info.is_list:
                    log.debug(
                        "[ADKToolWrapper:%s] Detected List[Artifact] param: %s",
                        self._tool_name,
                        param_name,
                    )
                else:
                    log.debug(
                        "[ADKToolWrapper:%s] Detected Artifact param: %s",
                        self._tool_name,
                        param_name,
                    )

            # Check for ToolContextFacade
            if _is_tool_context_facade_param(param.annotation):
                self._ctx_facade_param_name = param_name
                log.debug(
                    "[ADKToolWrapper:%s] Detected ToolContextFacade param: %s",
                    self._tool_name,
                    param_name,
                )

        if self._artifact_params:
            log.info(
                "[ADKToolWrapper:%s] Will pre-load artifacts for params: %s",
                self._tool_name,
                list(self._artifact_params.keys()),
            )

        if self._ctx_facade_param_name:
            log.info(
                "[ADKToolWrapper:%s] Will inject ToolContextFacade as '%s'",
                self._tool_name,
                self._ctx_facade_param_name,
            )

    async def __call__(self, *args, **kwargs):
        """
        Resolve embeds, inject framework arguments, and run the wrapped tool.

        Args:
            *args: Positional arguments forwarded to the wrapped callable.
            **kwargs: Keyword arguments forwarded to the wrapped callable. The
                special ``_override_embed_context`` key is removed; when truthy
                it is used instead of ``tool_context`` as the context for embed
                resolution, facade injection and artifact pre-loading.

        Returns:
            The wrapped callable's result; the error value returned by
            ``resolve_artifact_params`` if artifact pre-loading reports one; or a
            ``{"status": "error", "message": ..., "tool_name": ...}`` dict if
            preparing the arguments or running the tool raises.
        """
        # Allow overriding the context for embed resolution, e.g., when called from a callback
        _override_embed_context = kwargs.pop("_override_embed_context", None)
        log_identifier = f"[ADKToolWrapper:{self._tool_name}]"
        context_for_embeds = _override_embed_context or kwargs.get("tool_context")

        try:
            resolved_args, resolved_kwargs = await self._resolve_call_embeds(
                args, kwargs, context_for_embeds, log_identifier
            )
            self._inject_framework_kwargs(resolved_kwargs, context_for_embeds, log_identifier)

            # Pre-load artifacts for Artifact parameters
            if self._artifact_params and context_for_embeds:
                error = await resolve_artifact_params(
                    artifact_params=self._artifact_params,
                    resolved_kwargs=resolved_kwargs,
                    tool_context=context_for_embeds,
                    tool_name=self._tool_name,
                    log_identifier=log_identifier,
                )
                if error is not None:
                    return error
        except Exception as e:
            # Honour the resilience contract: failures while preparing arguments
            # are reported like execution failures instead of escaping the wrapper.
            log.exception("%s Failed to prepare tool arguments: %s", log_identifier, e)
            return self._error_response(e)

        try:
            # Instrument tool execution latency
            with MonitorLatency(ToolMonitor.create(name=self._tool_name)):
                if self._is_async:
                    return await self._original_func(*resolved_args, **resolved_kwargs)
                # Sync tools run in the default executor so they don't block the
                # event loop. run_in_executor does not propagate contextvars by
                # itself, so execute the call inside a copy of the current context
                # (the same approach asyncio.to_thread uses).
                loop = asyncio.get_running_loop()
                call_context = contextvars.copy_context()
                return await loop.run_in_executor(
                    None,
                    functools.partial(
                        call_context.run,
                        self._original_func,
                        *resolved_args,
                        **resolved_kwargs,
                    ),
                )
        except Exception as e:
            log.exception("%s Tool execution failed: %s", log_identifier, e)
            return self._error_response(e)

    async def _resolve_call_embeds(
        self,
        args: tuple,
        kwargs: Dict[str, Any],
        context: Any,
        log_identifier: str,
    ) -> Tuple[List[Any], Dict[str, Any]]:
        """
        Return (resolved positional args, resolved kwargs copy) with embeds in
        string values resolved.

        Values whose parameter name is listed in ``raw_string_args`` are passed
        through untouched, whether supplied positionally or by keyword. When
        ``context`` is falsy, nothing is resolved and a warning is logged.
        """
        resolved_kwargs = kwargs.copy()
        if not context:
            log.warning(
                "%s ToolContext not found. Skipping embed resolution for all args.",
                log_identifier,
            )
            return list(args), resolved_kwargs

        resolved_args: List[Any] = []
        for index, arg in enumerate(args):
            # Extra positional args (e.g. captured by *args) have no name.
            param_name = (
                self._positional_param_names[index]
                if index < len(self._positional_param_names)
                else None
            )
            if param_name in self._raw_string_args and isinstance(arg, str):
                log.debug(
                    "%s Skipping embed resolution for raw string arg '%s'",
                    log_identifier,
                    param_name,
                )
                resolved_args.append(arg)
            elif isinstance(arg, str) and EMBED_DELIMITER_OPEN in arg:
                resolved_args.append(await self._resolve_string(arg, context, log_identifier))
            else:
                resolved_args.append(arg)

        for key, value in kwargs.items():
            if key in self._raw_string_args and isinstance(value, str):
                log.debug(
                    "%s Skipping embed resolution for raw string kwarg '%s'",
                    log_identifier,
                    key,
                )
            elif isinstance(value, str) and EMBED_DELIMITER_OPEN in value:
                log.debug("%s Resolving embeds for kwarg '%s'", log_identifier, key)
                resolved_kwargs[key] = await self._resolve_string(value, context, log_identifier)

        return resolved_args, resolved_kwargs

    async def _resolve_string(self, text: str, context: Any, log_identifier: str) -> Any:
        """Resolve embeds in a single string using this wrapper's resolution settings."""
        resolved_text, _, _ = await resolve_embeds_in_string(
            text=text,
            context=context,
            resolver_func=evaluate_embed,
            types_to_resolve=self._types_to_resolve,
            resolution_mode=ResolutionMode.TOOL_PARAMETER,
            log_identifier=log_identifier,
            config=self._tool_config,
        )
        return resolved_text

    def _inject_framework_kwargs(
        self,
        resolved_kwargs: Dict[str, Any],
        context: Any,
        log_identifier: str,
    ) -> None:
        """Add ``tool_config`` and the ToolContextFacade to ``resolved_kwargs`` (in place)
        when the wrapped function expects them."""
        if self._accepts_tool_config:
            resolved_kwargs["tool_config"] = self._tool_config
        elif self._tool_config:
            log.warning(
                "%s Tool was provided a 'tool_config' but its function signature does not accept it. The config will be ignored.",
                log_identifier,
            )

        # Inject ToolContextFacade if the function expects it
        if self._ctx_facade_param_name and context:
            facade = ToolContextFacade(
                tool_context=context,
                tool_config=self._tool_config,
            )
            resolved_kwargs[self._ctx_facade_param_name] = facade
            log.debug(
                "%s Injected ToolContextFacade as '%s'",
                log_identifier,
                self._ctx_facade_param_name,
            )

    def _error_response(self, error: Exception) -> Dict[str, Any]:
        """Build the structured error dict returned to the caller on unexpected failures."""
        return {
            "status": "error",
            "message": f"Tool '{self._tool_name}' failed with an unexpected error: {str(error)}",
            "tool_name": self._tool_name,
        }