/ src / solace_agent_mesh / agent / tools / dynamic_tool.py
dynamic_tool.py
  1  """
  2  Defines the base classes and helpers for "dynamic" tools.
  3  Dynamic tools allow for programmatic definition of tool names, descriptions,
  4  and parameter schemas, offering more flexibility than standard Python tools.
  5  """
  6  
  7  import logging
  8  from abc import ABC, abstractmethod
  9  from typing import (
 10      Optional,
 11      List,
 12      Callable,
 13      Dict,
 14      Any,
 15      Set,
 16      get_origin,
 17      get_args,
 18      Union,
 19      Literal,
 20      TYPE_CHECKING,
 21      Type,
 22  )
 23  import inspect
 24  
 25  from pydantic import BaseModel
 26  from google.adk.tools import BaseTool, ToolContext
 27  from google.genai import types as adk_types
 28  
 29  from solace_agent_mesh.agent.utils.context_helpers import get_original_session_id
 30  from solace_agent_mesh.agent.utils.tool_context_facade import ToolContextFacade
 31  from .artifact_types import Artifact, is_artifact_type, get_artifact_info, ArtifactTypeInfo
 32  from .artifact_preloading import (
 33      is_tool_context_facade_param as _is_tool_context_facade_param,
 34      resolve_artifact_params,
 35  )
 36  
 37  from ...common.utils.embeds import (
 38      resolve_embeds_in_string,
 39      evaluate_embed,
 40      EARLY_EMBED_TYPES,
 41      LATE_EMBED_TYPES,
 42      EMBED_DELIMITER_OPEN,
 43  )
 44  from ...common.utils.embeds.types import ResolutionMode
 45  
 46  log = logging.getLogger(__name__)
 47  
 48  
 49  if TYPE_CHECKING:
 50      from ..sac.component import SamAgentComponent
 51      from .tool_config_types import AnyToolConfig
 52  
 53  
 54  # --- Base Class for Programmatic Tools ---
 55  
 56  
 57  class DynamicTool(BaseTool, ABC):
 58      """
 59      Base class for dynamic tools that can define their own function names,
 60      descriptions, and parameter schemas programmatically.
 61      """
 62  
 63      config_model: Optional[Type[BaseModel]] = None
 64  
 65      def __init__(self, tool_config: Optional[Union[dict, BaseModel]] = None):
 66          # Initialize with placeholder values, will be overridden by properties
 67          super().__init__(
 68              name="dynamic_tool_placeholder", description="dynamic_tool_placeholder"
 69          )
 70          self.tool_config = tool_config or {}
 71  
 72      async def init(
 73          self, component: "SamAgentComponent", tool_config: "AnyToolConfig"
 74      ) -> None:
 75          """
 76          (Optional) Asynchronously initializes resources for the tool.
 77          This method is called once when the agent starts up.
 78          The `component` provides access to agent-wide state, and `tool_config`
 79          is the validated Pydantic model instance if `config_model` is defined.
 80          """
 81          pass
 82  
 83      async def cleanup(
 84          self, component: "SamAgentComponent", tool_config: "AnyToolConfig"
 85      ) -> None:
 86          """
 87          (Optional) Asynchronously cleans up resources used by the tool.
 88          This method is called once when the agent is shutting down.
 89          """
 90          pass
 91  
 92      @property
 93      @abstractmethod
 94      def tool_name(self) -> str:
 95          """Return the function name that the LLM will call."""
 96          pass
 97  
 98      @property
 99      @abstractmethod
100      def tool_description(self) -> str:
101          """Return the description of what this tool does."""
102          pass
103  
104      @property
105      @abstractmethod
106      def parameters_schema(self) -> adk_types.Schema:
107          """Return the ADK Schema defining the tool's parameters."""
108          pass
109  
110      @property
111      def raw_string_args(self) -> List[str]:
112          """
113          Return a list of argument names that should not have embeds resolved.
114          Subclasses can override this property.
115          """
116          return []
117  
118      @property
119      def resolution_type(self) -> Literal["early", "all"]:
120          """
121          Determines which embeds to resolve. 'early' resolves simple embeds like
122          math and uuid. 'all' also resolves 'artifact_content'.
123          Defaults to 'early'.
124          """
125          return "early"
126  
127      @property
128      def artifact_args(self) -> List[str]:
129          """
130          Return a list of argument names that should have artifacts pre-loaded.
131          The framework will load the artifact before invoking the tool,
132          replacing the filename with an Artifact object containing content and metadata.
133  
134          Subclasses can override this property to specify which parameters
135          should have artifacts pre-loaded.
136  
137          Returns:
138              List of parameter names to pre-load artifacts for.
139          """
140          return []
141  
142      @property
143      def artifact_params(self) -> Dict[str, ArtifactTypeInfo]:
144          """
145          Return detailed information about artifact parameters.
146  
147          This maps parameter names to ArtifactTypeInfo objects that indicate:
148          - is_list: Whether the parameter expects a list of artifacts
149          - is_optional: Whether the parameter is optional
150  
151          Subclasses can override this for fine-grained control over artifact loading.
152          Default implementation creates basic info from artifact_args.
153  
154          Returns:
155              Dict mapping parameter names to ArtifactTypeInfo.
156          """
157          # Default: create basic info from artifact_args
158          return {name: ArtifactTypeInfo(is_artifact=True) for name in self.artifact_args}
159  
160      @property
161      def ctx_facade_param_name(self) -> Optional[str]:
162          """
163          Return the parameter name that should receive a ToolContextFacade.
164  
165          If not None, the framework will create and inject a ToolContextFacade
166          instance for this parameter before invoking the tool.
167  
168          Subclasses can override this property to specify the parameter name.
169          Default is None (no facade injection).
170  
171          Returns:
172              Parameter name for ToolContextFacade injection, or None.
173          """
174          return None
175  
176      def _get_declaration(self) -> Optional[Any]:
177          """
178          Generate the FunctionDeclaration for this dynamic tool.
179          This follows the same pattern as PeerAgentTool and MCP tools.
180          """
181          # Update the tool name and description to match what the module defines
182          self.name = self.tool_name
183          self.description = self.tool_description
184          
185          return adk_types.FunctionDeclaration(
186              name=self.tool_name,
187              description=self.tool_description,
188              parameters=self.parameters_schema,
189          )
190  
191      async def run_async(
192          self, *, args: Dict[str, Any], tool_context: ToolContext
193      ) -> Dict[str, Any]:
194          """
195          Asynchronously runs the tool with the given arguments.
196          This method resolves embeds in arguments and then delegates the call
197          to the abstract _run_async_impl.
198          """
199          log_identifier = f"[DynamicTool:{self.tool_name}]"
200          resolved_kwargs = args.copy()
201  
202          types_to_resolve = EARLY_EMBED_TYPES
203          if self.resolution_type == "all":
204              types_to_resolve = EARLY_EMBED_TYPES.union(LATE_EMBED_TYPES)
205  
206          # Unlike ADKToolWrapper, DynamicTools receive all args in a single dict.
207          # We iterate through this dict to resolve embeds.
208          for key, value in args.items():
209              if key in self.raw_string_args and isinstance(value, str):
210                  log.debug(
211                      "%s Skipping embed resolution for raw string kwarg '%s'",
212                      log_identifier,
213                      key,
214                  )
215              elif isinstance(value, str) and EMBED_DELIMITER_OPEN in value:
216                  log.debug("%s Resolving embeds for kwarg '%s'", log_identifier, key)
217                  # Create the resolution context
218                  if hasattr(tool_context, "_invocation_context"):
219                      # Use the invocation context if available
220                      invocation_context = tool_context._invocation_context
221                  else:
222                      # Error if no invocation context is found
223                      raise RuntimeError(
224                          f"{log_identifier} No invocation context found in ToolContext. Cannot resolve embeds."
225                      )
226                  session_context = invocation_context.session
227                  if not session_context:
228                      raise RuntimeError(
229                          f"{log_identifier} No session context found in invocation context. Cannot resolve embeds."
230                      )
231                  resolution_context = {
232                      "artifact_service": invocation_context.artifact_service,
233                      "session_context": {
234                          "session_id": get_original_session_id(invocation_context),
235                          "user_id": session_context.user_id,
236                          "app_name": session_context.app_name,
237                      },
238                  }
239                  resolved_value, _, _ = await resolve_embeds_in_string(
240                      text=value,
241                      context=resolution_context,
242                      resolver_func=evaluate_embed,
243                      types_to_resolve=types_to_resolve,
244                      resolution_mode=ResolutionMode.TOOL_PARAMETER,
245                      log_identifier=log_identifier,
246                      config=self.tool_config,
247                  )
248                  resolved_kwargs[key] = resolved_value
249  
250          # Pre-load artifacts for Artifact parameters
251          artifact_param_info = self.artifact_params
252          if artifact_param_info:
253              error = await resolve_artifact_params(
254                  artifact_params=artifact_param_info,
255                  resolved_kwargs=resolved_kwargs,
256                  tool_context=tool_context,
257                  tool_name=self.tool_name,
258                  log_identifier=log_identifier,
259              )
260              if error is not None:
261                  return error
262  
263          # Inject ToolContextFacade if the tool expects it
264          ctx_param = self.ctx_facade_param_name
265          if ctx_param:
266              facade = ToolContextFacade(
267                  tool_context=tool_context,
268                  tool_config=self.tool_config if isinstance(self.tool_config, dict) else {},
269              )
270              resolved_kwargs[ctx_param] = facade
271              log.debug(
272                  "%s Injected ToolContextFacade as '%s'",
273                  log_identifier,
274                  ctx_param,
275              )
276  
277          return await self._run_async_impl(
278              args=resolved_kwargs, tool_context=tool_context, credential=None
279          )
280  
281      @abstractmethod
282      async def _run_async_impl(
283          self, args: dict, tool_context: ToolContext, credential: Optional[str] = None
284      ) -> dict:
285          """
286          Implement the actual tool logic.
287          Must return a dictionary response.
288          """
289          pass
290  
291  
292  # --- Internal Adapter for Function-Based Tools ---
293  
294  
class _SchemaDetectionResult:
    """
    Mutable record filled in by ``_get_schema_from_signature``.

    Captures the "special" parameters found while introspecting a function
    signature, so the caller can configure artifact pre-loading and
    ToolContextFacade injection.
    """

    def __init__(self):
        # Name of the parameter annotated with ToolContextFacade, if any.
        self.ctx_facade_param_name: Optional[str] = None
        # Parameter name -> ArtifactTypeInfo (carries is_list / is_optional).
        self.artifact_params: Dict[str, ArtifactTypeInfo] = {}
        # Slot for a generated ADK schema; _get_schema_from_signature returns
        # its schema instead of storing it here.
        self.schema: Optional[adk_types.Schema] = None

    @property
    def artifact_args(self) -> Set[str]:
        """Return a new set holding the names of all Artifact-typed parameters."""
        return {name for name in self.artifact_params}
308  
309  
def _get_schema_from_signature(
    func: Callable,
    artifact_args: Optional[Set[str]] = None,
    detection_result: Optional[_SchemaDetectionResult] = None,
) -> adk_types.Schema:
    """
    Introspects a function's signature and generates an ADK Schema for its parameters.

    Excluded from the schema:
    - parameters named ``tool_context``, ``tool_config``, ``kwargs``, ``self``
      or ``cls``;
    - variadic ``*args`` / ``**kwargs`` parameters (any name);
    - a parameter annotated with ToolContextFacade (injected by the framework).

    Artifact-typed parameters are exposed as filenames: STRING, or ARRAY of
    STRING for list-of-artifact annotations. A parameter is marked required
    when it has no default and is not ``Optional[T]`` / ``T | None``.

    Args:
        func: The function to introspect.
        artifact_args: Optional set to populate with param names that have
            an Artifact type annotation (these will be pre-loaded).
        detection_result: Optional result object to populate with all
            detected special params (artifact params and the facade param).

    Returns:
        An OBJECT schema with one property per exposed parameter.
    """
    import types as _py_types

    sig = inspect.signature(func)
    properties = {}
    required = []

    type_map = {
        str: adk_types.Type.STRING,
        int: adk_types.Type.INTEGER,
        float: adk_types.Type.NUMBER,
        bool: adk_types.Type.BOOLEAN,
        list: adk_types.Type.ARRAY,
        dict: adk_types.Type.OBJECT,
    }
    # Element types for which ``List[T]`` gets an explicit ``items`` schema.
    scalar_types = (
        adk_types.Type.STRING,
        adk_types.Type.INTEGER,
        adk_types.Type.NUMBER,
        adk_types.Type.BOOLEAN,
    )

    # Optional[T] is Union[T, None]. On Python 3.10+ the PEP 604 form
    # ``T | None`` has origin types.UnionType instead of typing.Union.
    pep604_union = getattr(_py_types, "UnionType", None)
    union_origins = (Union,) if pep604_union is None else (Union, pep604_union)

    # Names the framework supplies itself (or that never come from the LLM).
    framework_param_names = ("tool_context", "tool_config", "kwargs", "self", "cls")
    # *args / **kwargs cannot be expressed as named schema properties.
    variadic_kinds = (
        inspect.Parameter.VAR_POSITIONAL,
        inspect.Parameter.VAR_KEYWORD,
    )

    for param in sig.parameters.values():
        if param.name in framework_param_names or param.kind in variadic_kinds:
            continue

        param_type = param.annotation
        is_optional = False

        # Handle Optional[T] / Union[T, None] / T | None
        origin = get_origin(param_type)
        type_args = get_args(param_type)
        if origin in union_origins and type(None) in type_args:
            is_optional = True
            # Keep the first non-None member as the effective type.
            param_type = next((t for t in type_args if t is not type(None)), Any)

        # Check for ToolContextFacade - exclude from schema (injected by framework)
        if _is_tool_context_facade_param(param_type):
            if detection_result is not None:
                detection_result.ctx_facade_param_name = param.name
            log.debug(
                "Detected ToolContextFacade param '%s' in %s, excluding from schema",
                param.name,
                func.__name__,
            )
            continue  # Don't add to schema - framework injects this

        # Check for Artifact type - translate to appropriate schema for LLM.
        # The original annotation is inspected so List/Optional wrappers are seen.
        original_annotation = param.annotation
        artifact_type_info = get_artifact_info(original_annotation)

        if artifact_type_info.is_artifact:
            if artifact_args is not None:
                artifact_args.add(param.name)
            if detection_result is not None:
                detection_result.artifact_params[param.name] = artifact_type_info

            if artifact_type_info.is_list:
                # List[Artifact] -> array of strings (filenames)
                log.debug(
                    "Detected List[Artifact] param '%s' in %s, translating to ARRAY of STRING",
                    param.name,
                    func.__name__,
                )
                properties[param.name] = adk_types.Schema(
                    type=adk_types.Type.ARRAY,
                    items=adk_types.Schema(type=adk_types.Type.STRING),
                    nullable=is_optional or artifact_type_info.is_optional,
                )
            else:
                # Single Artifact -> string (filename)
                log.debug(
                    "Detected Artifact param '%s' in %s, translating to STRING",
                    param.name,
                    func.__name__,
                )
                properties[param.name] = adk_types.Schema(
                    type=adk_types.Type.STRING,
                    nullable=is_optional or artifact_type_info.is_optional,
                )
        else:
            # Resolve generic aliases (e.g., List[str] -> list, Dict[str, Any] -> dict)
            resolved_type = get_origin(param_type) or param_type
            # Default to string if type is not supported or specified (e.g., Any)
            adk_type = type_map.get(resolved_type) or adk_types.Type.STRING
            schema_fields = {"type": adk_type, "nullable": is_optional}

            if adk_type == adk_types.Type.ARRAY:
                # List[<scalar>] -> describe the element type via ``items``;
                # bare ``list`` or non-scalar elements keep the plain ARRAY.
                element_args = get_args(param_type)
                element_type = None
                if len(element_args) == 1 and isinstance(element_args[0], type):
                    element_type = type_map.get(element_args[0])
                if element_type in scalar_types:
                    schema_fields["items"] = adk_types.Schema(type=element_type)

            properties[param.name] = adk_types.Schema(**schema_fields)

        if param.default is inspect.Parameter.empty and not is_optional:
            required.append(param.name)

    return adk_types.Schema(
        type=adk_types.Type.OBJECT,
        properties=properties,
        required=required,
    )
414  
415  
class _FunctionAsDynamicTool(DynamicTool):
    """
    Internal adapter to wrap a standard Python function as a DynamicTool.

    The tool name is the function's ``__name__``, the description is its
    docstring (via ``inspect.getdoc``), and the parameter schema is derived
    from its signature by ``_get_schema_from_signature``. Both coroutine
    functions and plain functions are supported.
    """

    def __init__(
        self,
        func: Callable,
        tool_config: Optional[Union[dict, BaseModel]] = None,
        provider_instance: Optional[Any] = None,
    ):
        super().__init__(tool_config=tool_config)
        self._func = func
        self._provider_instance = provider_instance

        # Cache the signature once; _run_async_impl consults it on every call.
        self._func_signature = inspect.signature(func)

        # Detect special params during schema generation
        self._detection_result = _SchemaDetectionResult()
        self._schema = _get_schema_from_signature(
            func,
            detection_result=self._detection_result,
        )

        if self._detection_result.artifact_args:
            log.info(
                "[_FunctionAsDynamicTool:%s] Will pre-load artifacts for params: %s",
                func.__name__,
                list(self._detection_result.artifact_args),
            )

        if self._detection_result.ctx_facade_param_name:
            log.info(
                "[_FunctionAsDynamicTool:%s] Will inject ToolContextFacade as '%s'",
                func.__name__,
                self._detection_result.ctx_facade_param_name,
            )

        # A function whose first parameter is named "self" is treated as an
        # instance method to be bound to the provider instance at call time.
        params = list(self._func_signature.parameters.values())
        self._is_instance_method = bool(params) and params[0].name == "self"

    @property
    def tool_name(self) -> str:
        return self._func.__name__

    @property
    def tool_description(self) -> str:
        return inspect.getdoc(self._func) or ""

    @property
    def parameters_schema(self) -> adk_types.Schema:
        return self._schema

    @property
    def artifact_args(self) -> List[str]:
        """Return the detected Artifact parameters."""
        return list(self._detection_result.artifact_args)

    @property
    def artifact_params(self) -> Dict[str, ArtifactTypeInfo]:
        """Return detailed info about Artifact parameters (including is_list)."""
        return self._detection_result.artifact_params

    @property
    def ctx_facade_param_name(self) -> Optional[str]:
        """Return the detected ToolContextFacade parameter name."""
        return self._detection_result.ctx_facade_param_name

    async def _run_async_impl(
        self,
        args: dict,
        tool_context: ToolContext,
        credential: Optional[str] = None,
    ) -> dict:
        """
        Invoke the wrapped function with the resolved arguments.

        ``tool_context`` and ``tool_config`` are added to ``args`` (in place)
        when the function declares parameters with those names. Instance
        methods are called with the provider instance as their first argument.
        If the function returns an awaitable, it is awaited.
        """
        params = self._func_signature.parameters
        if "tool_context" in params:
            args["tool_context"] = tool_context
        if "tool_config" in params:
            args["tool_config"] = self.tool_config

        # Compare against None: a provider that defines __bool__/__len__ and
        # evaluates falsy must still be bound as `self`.
        if self._provider_instance is not None and self._is_instance_method:
            result = self._func(self._provider_instance, **args)
        else:
            # It's a static method or a standalone function
            result = self._func(**args)

        # Coroutine functions return an awaitable; plain functions do not.
        if inspect.isawaitable(result):
            result = await result
        return result
506  
507  
508  # --- Base Class for Tool Providers ---
509  
510  
class DynamicToolProvider(ABC):
    """
    Base class for dynamic tool providers that can generate a list of tools
    programmatically from a single configuration block.

    Tools come from two sources: functions registered with ``register_tool``
    (wrapped automatically) and the custom DynamicTool instances returned by
    ``create_tools``.
    """

    config_model: Optional[Type[BaseModel]] = None
    # Default (empty) registry. A class gets its own list on its first
    # register_tool call, so registrations never leak into other classes.
    _decorated_tools: List[Callable] = []

    @classmethod
    def register_tool(cls, func: Callable) -> Callable:
        """
        A decorator to register a standard async function as a tool.
        The decorated function's signature and docstring will be used to
        create the tool definition.

        Args:
            func: The function to register.

        Returns:
            ``func`` unchanged.
        """
        # Give `cls` its own list the first time it registers a tool. Checking
        # the class's own namespace (rather than identity with the base list)
        # stops a sub-subclass from appending into its parent's list.
        if "_decorated_tools" not in vars(cls):
            inherited = cls._decorated_tools
            if inherited is DynamicToolProvider._decorated_tools:
                # Direct subclasses start empty, as before.
                cls._decorated_tools = []
            else:
                # Deeper subclasses start from a copy of their parent's tools.
                cls._decorated_tools = list(inherited)

        cls._decorated_tools.append(func)
        return func

    def _create_tools_from_decorators(
        self, tool_config: Optional[Union[dict, BaseModel]] = None
    ) -> List[DynamicTool]:
        """
        Internal helper to convert decorated functions into DynamicTool instances.

        Each registered function is wrapped in a ``_FunctionAsDynamicTool``
        bound to this provider instance.
        """
        return [
            _FunctionAsDynamicTool(func, tool_config, provider_instance=self)
            for func in self._decorated_tools
        ]

    def get_all_tools_for_framework(
        self, tool_config: Optional[Union[dict, BaseModel]] = None
    ) -> List[DynamicTool]:
        """
        Framework-internal method that automatically combines decorated tools with custom tools.
        This is called by the ADK setup code, not by users.

        Args:
            tool_config: The configuration dictionary from the agent's YAML file.

        Returns:
            A list of all DynamicTool objects (decorated first, then custom).
        """
        # Get tools from decorators automatically
        decorated_tools = self._create_tools_from_decorators(tool_config)

        # Get custom tools from the user's implementation. Treat a None return
        # (e.g. an override that only does `pass`) as "no custom tools", and
        # accept any iterable, not just a list.
        custom_tools = self.create_tools(tool_config) or []

        return [*decorated_tools, *custom_tools]

    @abstractmethod
    def create_tools(self, tool_config: Optional[Union[dict, BaseModel]] = None) -> List[DynamicTool]:
        """
        Generate and return a list of custom DynamicTool instances.

        Note: Tools registered with the @register_tool decorator are automatically
        included by the framework - you don't need to handle them here.

        Args:
            tool_config: The configuration dictionary from the agent's YAML file.

        Returns:
            A list of custom DynamicTool objects (decorated tools are added automatically).
        """
        pass