/ src / python / txtai / agent / tool / factory.py
factory.py
  1  """
  2  Factory module
  3  """
  4  
  5  import inspect
  6  
  7  from types import FunctionType, MethodType
  8  
  9  import mcpadapt.core
 10  
 11  from mcpadapt.smolagents_adapter import SmolAgentsAdapter
 12  from smolagents import PythonInterpreterTool, Tool, tool as CreateTool, UserInputTool, WebSearchTool
 13  from transformers.utils import chat_template_utils, TypeHintParsingException
 14  
 15  from ...embeddings import Embeddings
 16  from .bash import BashTool
 17  from .edit import EditTool
 18  from .embeddings import EmbeddingsTool
 19  from .function import FunctionTool
 20  from .glob import GlobTool
 21  from .grep import GrepTool
 22  from .read import ReadTool
 23  from .skill import SkillTool
 24  from .todo import TodoWriteTool
 25  from .write import WriteTool
 26  
 27  
 28  class ToolFactory:
 29      """
 30      Methods to create tools.
 31      """
 32  
 33      # Default toolkit
 34      DEFAULTS = {
 35          "bash": BashTool(),
 36          "edit": EditTool(),
 37          "glob": GlobTool(),
 38          "grep": GrepTool(),
 39          "python": PythonInterpreterTool(),
 40          "question": UserInputTool(),
 41          "read": ReadTool(),
 42          "todowrite": TodoWriteTool(),
 43          "websearch": WebSearchTool(),
 44          "write": WriteTool(),
 45      }
 46  
 47      # Backwards compatible mappings
 48      DEFAULTS["webview"] = DEFAULTS["read"]
 49  
 50      @staticmethod
 51      def create(config):
 52          """
 53          Creates a new list of tools. This method iterates of the `tools` configuration option and creates a Tool instance
 54          for each entry. This supports the following:
 55  
 56            - Tool instance
 57            - Dictionary with `name`, `description`, `inputs`, `output` and `target` function configuration
 58            - String with a tool alias name
 59  
 60          Returns:
 61              list of tools
 62          """
 63  
 64          tools = []
 65          for tool in config.pop("tools", []):
 66              # Create tool from function and it's documentation
 67              if not isinstance(tool, Tool) and (isinstance(tool, (FunctionType, MethodType)) or hasattr(tool, "__call__")):
 68                  tool = ToolFactory.createtool(tool)
 69  
 70              # Create tool from input dictionary
 71              elif isinstance(tool, dict):
 72                  # Get target function
 73                  target = tool.get("target")
 74  
 75                  # Create tool from input dictionary
 76                  tool = (
 77                      EmbeddingsTool(tool)
 78                      if isinstance(target, Embeddings) or any(x in tool for x in ["container", "path"])
 79                      else ToolFactory.createtool(target, tool)
 80                  )
 81  
 82              # Get default tool, if applicable
 83              elif isinstance(tool, str) and tool in ToolFactory.DEFAULTS:
 84                  tool = ToolFactory.DEFAULTS[tool]
 85  
 86              # Get ALL default tools, if applicable
 87              elif isinstance(tool, str) and tool == "defaults":
 88                  tools.extend(set(ToolFactory.DEFAULTS.values()))
 89                  tool = None
 90  
 91              # Support importing MCP tool collections
 92              elif isinstance(tool, str) and tool.startswith("http"):
 93                  tools.extend(mcpadapt.core.MCPAdapt({"url": tool}, SmolAgentsAdapter()).tools())
 94                  tool = None
 95  
 96              # Load skill.md files
 97              elif isinstance(tool, str) and tool.endswith(".md"):
 98                  tool = SkillTool(tool)
 99  
100              # Add tool
101              if tool:
102                  tools.append(tool)
103  
104          return tools
105  
106      @staticmethod
107      def createtool(target, config=None):
108          """
109          Creates a new Tool.
110  
111          Args:
112              target: target object or function
113              config: optional tool configuration
114  
115          Returns:
116              Tool
117          """
118  
119          try:
120              # Try to create using CreateTool function - this fails when no annotations are available
121              return CreateTool(target)
122          except (TypeHintParsingException, TypeError):
123              return ToolFactory.fromdocs(target, config if config else {})
124  
125      @staticmethod
126      def fromdocs(target, config):
127          """
128          Creates a tool from method documentation.
129  
130          Args:
131              target: target object or function
132              config: tool configuration
133  
134          Returns:
135              Tool
136          """
137  
138          # Get function name and target - use target if it's a function or method, else use target.__call__
139          name = target.__name__ if isinstance(target, (FunctionType, MethodType)) or not hasattr(target, "__call__") else target.__class__.__name__
140          target = target if isinstance(target, (FunctionType, MethodType)) or not hasattr(target, "__call__") else target.__call__
141  
142          # Extract target documentation
143          doc = inspect.getdoc(target)
144          description, parameters, _ = chat_template_utils.parse_google_format_docstring(doc.strip()) if doc else (None, {}, None)
145  
146          # Get list of required parameters
147          signature = inspect.signature(target)
148          inputs = {}
149          for pname, param in signature.parameters.items():
150              if (
151                  param.default == inspect.Parameter.empty
152                  and pname in parameters
153                  and param.kind not in (inspect.Parameter.VAR_KEYWORD, inspect.Parameter.VAR_POSITIONAL)
154              ):
155                  inputs[pname] = {"type": "any", "description": parameters[pname]}
156  
157          # Create function tool
158          return FunctionTool(
159              {
160                  "name": config.get("name", name.lower()),
161                  "description": config.get("description", description),
162                  "inputs": config.get("inputs", inputs),
163                  "target": config.get("target", target),
164              }
165          )