/ tools / registry.py
registry.py
  1  """Central registry for all hermes-agent tools.
  2  
  3  Each tool file calls ``registry.register()`` at module level to declare its
  4  schema, handler, toolset membership, and availability check.  ``model_tools.py``
  5  queries the registry instead of maintaining its own parallel data structures.
  6  
  7  Import chain (circular-import safe):
  8      tools/registry.py  (no imports from model_tools or tool files)
  9             ^
 10      tools/*.py  (import from tools.registry at module level)
 11             ^
 12      model_tools.py  (imports tools.registry + all tool modules)
 13             ^
 14      run_agent.py, cli.py, batch_runner.py, etc.
 15  """
 16  
 17  import ast
 18  import importlib
 19  import json
 20  import logging
 21  import threading
 22  import time
 23  from pathlib import Path
 24  from typing import Callable, Dict, List, Optional, Set
 25  
# Module logger: used for tool-import warnings, rejected registrations,
# failing availability checks and handler dispatch errors.
logger = logging.getLogger(__name__)
 27  
 28  
 29  def _is_registry_register_call(node: ast.AST) -> bool:
 30      """Return True when *node* is a ``registry.register(...)`` call expression."""
 31      if not isinstance(node, ast.Expr) or not isinstance(node.value, ast.Call):
 32          return False
 33      func = node.value.func
 34      return (
 35          isinstance(func, ast.Attribute)
 36          and func.attr == "register"
 37          and isinstance(func.value, ast.Name)
 38          and func.value.id == "registry"
 39      )
 40  
 41  
 42  def _module_registers_tools(module_path: Path) -> bool:
 43      """Return True when the module contains a top-level ``registry.register(...)`` call.
 44  
 45      Only inspects module-body statements so that helper modules which happen
 46      to call ``registry.register()`` inside a function are not picked up.
 47      """
 48      try:
 49          source = module_path.read_text(encoding="utf-8")
 50          tree = ast.parse(source, filename=str(module_path))
 51      except (OSError, SyntaxError):
 52          return False
 53  
 54      return any(_is_registry_register_call(stmt) for stmt in tree.body)
 55  
 56  
 57  def discover_builtin_tools(tools_dir: Optional[Path] = None) -> List[str]:
 58      """Import built-in self-registering tool modules and return their module names."""
 59      tools_path = Path(tools_dir) if tools_dir is not None else Path(__file__).resolve().parent
 60      module_names = [
 61          f"tools.{path.stem}"
 62          for path in sorted(tools_path.glob("*.py"))
 63          if path.name not in {"__init__.py", "registry.py", "mcp_tool.py"}
 64          and _module_registers_tools(path)
 65      ]
 66  
 67      imported: List[str] = []
 68      for mod_name in module_names:
 69          try:
 70              importlib.import_module(mod_name)
 71              imported.append(mod_name)
 72          except Exception as e:
 73              logger.warning("Could not import tool module %s: %s", mod_name, e)
 74      return imported
 75  
 76  
 77  class ToolEntry:
 78      """Metadata for a single registered tool."""
 79  
 80      __slots__ = (
 81          "name", "toolset", "schema", "handler", "check_fn",
 82          "requires_env", "is_async", "description", "emoji",
 83          "max_result_size_chars",
 84      )
 85  
 86      def __init__(self, name, toolset, schema, handler, check_fn,
 87                   requires_env, is_async, description, emoji,
 88                   max_result_size_chars=None):
 89          self.name = name
 90          self.toolset = toolset
 91          self.schema = schema
 92          self.handler = handler
 93          self.check_fn = check_fn
 94          self.requires_env = requires_env
 95          self.is_async = is_async
 96          self.description = description
 97          self.emoji = emoji
 98          self.max_result_size_chars = max_result_size_chars
 99  
100  
# ---------------------------------------------------------------------------
# check_fn TTL cache
#
# check_fn callables like tools/terminal_tool.check_terminal_requirements
# probe external state (Docker daemon, Modal SDK install, playwright binary
# availability). For a long-lived CLI or gateway process, calling them on
# every get_definitions() is pure waste — external state changes on human
# timescales. Cache results for ~30 s so env-var flips via ``hermes tools``
# or live credential file changes propagate within a turn or two without
# requiring explicit invalidation; invalidate_check_fn_cache() is available
# for callers that want a change to take effect immediately.
# ---------------------------------------------------------------------------

# Lifetime of a cached check_fn result, measured in time.monotonic() seconds.
_CHECK_FN_TTL_SECONDS = 30.0
# check_fn callable -> (monotonic timestamp of the probe, bool result).
# Keyed by the callable object itself; within this module entries are only
# overwritten or dropped wholesale by invalidate_check_fn_cache().
_check_fn_cache: Dict[Callable, tuple[float, bool]] = {}
# Guards reads and writes of _check_fn_cache.
_check_fn_cache_lock = threading.Lock()
116  
117  
def _check_fn_cached(fn: Callable) -> bool:
    """Return ``bool(fn())``, memoized per callable for ``_CHECK_FN_TTL_SECONDS``.

    If a cached value for *fn* is younger than the TTL, it is returned without
    calling *fn*. Otherwise *fn* runs outside the lock, so a slow probe does
    not block other lookups, and its result is stored. The stored timestamp
    is the one taken at the start of this call. Concurrent callers that miss
    at the same time may each invoke *fn*; the last writer wins.

    An exception raised by *fn* is logged at debug level and treated as
    False, and that False result is cached like any other.
    """
    now = time.monotonic()
    with _check_fn_cache_lock:
        cached = _check_fn_cache.get(fn)
    # Cached tuples are immutable, so the freshness test needs no lock.
    if cached is not None:
        ts, value = cached
        if now - ts < _CHECK_FN_TTL_SECONDS:
            return value
    try:
        value = bool(fn())
    except Exception:
        # Best-effort probe: report unavailable, but leave a trace for debugging.
        logger.debug("check_fn %r raised; treating as unavailable", fn, exc_info=True)
        value = False
    with _check_fn_cache_lock:
        _check_fn_cache[fn] = (now, value)
    return value
134  
135  
def invalidate_check_fn_cache() -> None:
    """Forget every memoized ``check_fn`` result.

    After this call, each check_fn is re-run on its next cached lookup instead
    of reusing a result younger than the TTL. Use it after configuration
    changes that alter which tools are usable (e.g. ``hermes tools enable``).
    """
    with _check_fn_cache_lock:
        # Clear in place so every reference to the shared dict sees the reset.
        _check_fn_cache.clear()
141  
142  
class ToolRegistry:
    """Singleton registry that collects tool schemas + handlers from tool files.

    Mutating methods take ``self._lock`` and bump ``self._generation``. Read
    helpers copy state under the same lock and then work on the copy, so a
    concurrent MCP refresh cannot change a collection mid-iteration.
    """

    def __init__(self):
        self._tools: Dict[str, ToolEntry] = {}
        self._toolset_checks: Dict[str, Callable] = {}
        self._toolset_aliases: Dict[str, str] = {}
        # MCP dynamic refresh can mutate the registry while other threads are
        # reading tool metadata, so keep mutations serialized and readers on
        # stable snapshots. RLock: nested acquisition by one thread is safe.
        self._lock = threading.RLock()
        # Monotonically-increasing generation counter. Bumped on every
        # mutation (register / deregister / register_toolset_alias / MCP
        # refresh). External callers (e.g. get_tool_definitions) can memoize
        # against it: a cache entry keyed on the generation is valid for as
        # long as the generation hasn't changed.
        self._generation: int = 0

    def _snapshot_state(self) -> tuple[List[ToolEntry], Dict[str, Callable]]:
        """Return a coherent snapshot of registry entries and toolset checks."""
        with self._lock:
            return list(self._tools.values()), dict(self._toolset_checks)

    def _snapshot_entries(self) -> List[ToolEntry]:
        """Return a stable snapshot of registered tool entries."""
        return self._snapshot_state()[0]

    def _snapshot_toolset_checks(self) -> Dict[str, Callable]:
        """Return a stable snapshot of toolset availability checks."""
        return self._snapshot_state()[1]

    def _evaluate_toolset_check(self, toolset: str, check: Callable | None) -> bool:
        """Run a toolset check and return its truthiness.

        A missing check (None/falsy) means the toolset is available. A check
        that raises means it is unavailable; the failure is logged at debug
        level.
        """
        if not check:
            return True
        try:
            return bool(check())
        except Exception:
            logger.debug("Toolset %s check raised; marking unavailable", toolset)
            return False

    def get_entry(self, name: str) -> Optional[ToolEntry]:
        """Return a registered tool entry by name, or None."""
        with self._lock:
            return self._tools.get(name)

    def get_registered_toolset_names(self) -> List[str]:
        """Return sorted unique toolset names present in the registry."""
        return sorted({entry.toolset for entry in self._snapshot_entries()})

    def get_tool_names_for_toolset(self, toolset: str) -> List[str]:
        """Return sorted tool names registered under a given toolset."""
        return sorted(
            entry.name for entry in self._snapshot_entries()
            if entry.toolset == toolset
        )

    def register_toolset_alias(self, alias: str, toolset: str) -> None:
        """Register an explicit alias for a canonical toolset name.

        An existing alias pointing elsewhere is overwritten, with a warning.
        """
        with self._lock:
            existing = self._toolset_aliases.get(alias)
            if existing and existing != toolset:
                logger.warning(
                    "Toolset alias collision: '%s' (%s) overwritten by %s",
                    alias, existing, toolset,
                )
            self._toolset_aliases[alias] = toolset
            self._generation += 1

    def get_registered_toolset_aliases(self) -> Dict[str, str]:
        """Return a snapshot of ``{alias: canonical_toolset}`` mappings."""
        with self._lock:
            return dict(self._toolset_aliases)

    def get_toolset_alias_target(self, alias: str) -> Optional[str]:
        """Return the canonical toolset name for an alias, or None."""
        with self._lock:
            return self._toolset_aliases.get(alias)

    # ------------------------------------------------------------------
    # Registration
    # ------------------------------------------------------------------

    def register(
        self,
        name: str,
        toolset: str,
        schema: dict,
        handler: Callable,
        check_fn: Optional[Callable] = None,
        requires_env: Optional[list] = None,
        is_async: bool = False,
        description: str = "",
        emoji: str = "",
        max_result_size_chars: int | float | None = None,
    ) -> None:
        """Register a tool.  Called at module-import time by each tool file.

        Re-registering a name within the same toolset replaces the entry.
        A name already owned by a *different* toolset is rejected, unless
        both toolsets are MCP toolsets (``mcp-`` prefix), in which case the
        newer entry wins. A rejection is logged at error level and leaves the
        registry unchanged. The first truthy *check_fn* seen for a toolset
        becomes that toolset's availability check. An empty *description*
        falls back to ``schema["description"]``.
        """
        with self._lock:
            existing = self._tools.get(name)
            if existing and existing.toolset != toolset:
                # Allow MCP-to-MCP overwrites (legitimate: server refresh,
                # or two MCP servers with overlapping tool names).
                both_mcp = (
                    existing.toolset.startswith("mcp-")
                    and toolset.startswith("mcp-")
                )
                if both_mcp:
                    logger.debug(
                        "Tool '%s': MCP toolset '%s' overwriting MCP toolset '%s'",
                        name, toolset, existing.toolset,
                    )
                else:
                    # Reject shadowing — prevent plugins/MCP from overwriting
                    # built-in tools or vice versa.
                    logger.error(
                        "Tool registration REJECTED: '%s' (toolset '%s') would "
                        "shadow existing tool from toolset '%s'. Deregister the "
                        "existing tool first if this is intentional.",
                        name, toolset, existing.toolset,
                    )
                    return
            self._tools[name] = ToolEntry(
                name=name,
                toolset=toolset,
                schema=schema,
                handler=handler,
                check_fn=check_fn,
                requires_env=requires_env or [],
                is_async=is_async,
                description=description or schema.get("description", ""),
                emoji=emoji,
                max_result_size_chars=max_result_size_chars,
            )
            if check_fn and toolset not in self._toolset_checks:
                self._toolset_checks[toolset] = check_fn
            self._generation += 1

    def deregister(self, name: str) -> None:
        """Remove a tool from the registry.

        Also cleans up the toolset check if no other tools remain in the
        same toolset.  Used by MCP dynamic tool discovery to nuke-and-repave
        when a server sends ``notifications/tools/list_changed``.
        """
        with self._lock:
            entry = self._tools.pop(name, None)
            if entry is None:
                return
            # Drop the toolset check and aliases if this was the last tool in
            # that toolset.
            toolset_still_exists = any(
                e.toolset == entry.toolset for e in self._tools.values()
            )
            if not toolset_still_exists:
                self._toolset_checks.pop(entry.toolset, None)
                self._toolset_aliases = {
                    alias: target
                    for alias, target in self._toolset_aliases.items()
                    if target != entry.toolset
                }
            self._generation += 1
        logger.debug("Deregistered tool: %s", name)

    # ------------------------------------------------------------------
    # Schema retrieval
    # ------------------------------------------------------------------

    def get_definitions(self, tool_names: Set[str], quiet: bool = False) -> List[dict]:
        """Return OpenAI-format tool schemas for the requested tool names.

        Names that are not registered are skipped silently; results are ordered
        by tool name. Only tools whose ``check_fn()`` is truthy (or that have
        no check_fn) are included. ``check_fn()`` results are cached for ~30 s
        via :func:`_check_fn_cached` to amortize repeat probes
        (check_terminal_requirements probes modal/docker, browser checks probe
        playwright, etc.); TTL chosen so env-var changes (``hermes tools enable
        foo``) still take effect in near-real-time without forcing a full
        cache flush on every call. *quiet* suppresses the debug log line for
        tools whose check fails.
        """
        result = []
        # Per-call cache on top of the 30 s TTL — handles repeat probes of the
        # same check_fn within one definitions pass without re-reading the
        # TTL clock.
        check_results: Dict[Callable, bool] = {}
        entries_by_name = {entry.name: entry for entry in self._snapshot_entries()}
        for name in sorted(tool_names):
            entry = entries_by_name.get(name)
            if not entry:
                continue
            if entry.check_fn:
                if entry.check_fn not in check_results:
                    check_results[entry.check_fn] = _check_fn_cached(entry.check_fn)
                if not check_results[entry.check_fn]:
                    if not quiet:
                        logger.debug("Tool %s unavailable (check failed)", name)
                    continue
            # The registered name is authoritative: it overrides any "name"
            # key already in the schema and fills it in when missing.
            schema_with_name = {**entry.schema, "name": entry.name}
            result.append({"type": "function", "function": schema_with_name})
        return result

    # ------------------------------------------------------------------
    # Dispatch
    # ------------------------------------------------------------------

    def dispatch(self, name: str, args: dict, **kwargs) -> str:
        """Execute a tool handler by name.

        * Async handlers are bridged automatically via ``_run_async()``.
        * All exceptions are caught and returned as ``{"error": "..."}``
          for consistent error format.
        """
        entry = self.get_entry(name)
        if not entry:
            return json.dumps({"error": f"Unknown tool: {name}"})
        try:
            if entry.is_async:
                from model_tools import _run_async
                return _run_async(entry.handler(args, **kwargs))
            return entry.handler(args, **kwargs)
        except Exception as e:
            logger.exception("Tool %s dispatch error: %s", name, e)
            return json.dumps({"error": f"Tool execution failed: {type(e).__name__}: {e}"})

    # ------------------------------------------------------------------
    # Query helpers  (replace redundant dicts in model_tools.py)
    # ------------------------------------------------------------------

    def get_max_result_size(self, name: str, default: int | float | None = None) -> int | float:
        """Return per-tool max result size, or *default* (or global default)."""
        entry = self.get_entry(name)
        if entry and entry.max_result_size_chars is not None:
            return entry.max_result_size_chars
        if default is not None:
            return default
        from tools.budget_config import DEFAULT_RESULT_SIZE_CHARS
        return DEFAULT_RESULT_SIZE_CHARS

    def get_all_tool_names(self) -> List[str]:
        """Return sorted list of all registered tool names."""
        return sorted(entry.name for entry in self._snapshot_entries())

    def get_schema(self, name: str) -> Optional[dict]:
        """Return a tool's raw schema dict, bypassing check_fn filtering.

        Useful for token estimation and introspection where availability
        doesn't matter — only the schema content does. The returned dict is
        the registered object itself, not a copy.
        """
        entry = self.get_entry(name)
        return entry.schema if entry else None

    def get_toolset_for_tool(self, name: str) -> Optional[str]:
        """Return the toolset a tool belongs to, or None."""
        entry = self.get_entry(name)
        return entry.toolset if entry else None

    def get_emoji(self, name: str, default: str = "⚡") -> str:
        """Return the emoji for a tool, or *default* if unset."""
        entry = self.get_entry(name)
        return (entry.emoji if entry and entry.emoji else default)

    def get_tool_to_toolset_map(self) -> Dict[str, str]:
        """Return ``{tool_name: toolset_name}`` for every registered tool."""
        return {entry.name: entry.toolset for entry in self._snapshot_entries()}

    def is_toolset_available(self, toolset: str) -> bool:
        """Check if a toolset's requirements are met.

        Returns False (rather than crashing) when the check function raises
        an unexpected exception (e.g. network error, missing import, bad config).
        """
        with self._lock:
            check = self._toolset_checks.get(toolset)
        return self._evaluate_toolset_check(toolset, check)

    def check_toolset_requirements(self) -> Dict[str, bool]:
        """Return ``{toolset: available_bool}`` for every toolset."""
        entries, toolset_checks = self._snapshot_state()
        toolsets = sorted({entry.toolset for entry in entries})
        return {
            toolset: self._evaluate_toolset_check(toolset, toolset_checks.get(toolset))
            for toolset in toolsets
        }

    def get_available_toolsets(self) -> Dict[str, dict]:
        """Return toolset metadata for UI display."""
        toolsets: Dict[str, dict] = {}
        entries, toolset_checks = self._snapshot_state()
        for entry in entries:
            ts = entry.toolset
            if ts not in toolsets:
                toolsets[ts] = {
                    "available": self._evaluate_toolset_check(
                        ts, toolset_checks.get(ts)
                    ),
                    "tools": [],
                    "description": "",
                    "requirements": [],
                }
            toolsets[ts]["tools"].append(entry.name)
            if entry.requires_env:
                for env in entry.requires_env:
                    if env not in toolsets[ts]["requirements"]:
                        toolsets[ts]["requirements"].append(env)
        return toolsets

    def get_toolset_requirements(self) -> Dict[str, dict]:
        """Build a TOOLSET_REQUIREMENTS-compatible dict for backward compat."""
        result: Dict[str, dict] = {}
        entries, toolset_checks = self._snapshot_state()
        for entry in entries:
            ts = entry.toolset
            if ts not in result:
                result[ts] = {
                    "name": ts,
                    "env_vars": [],
                    "check_fn": toolset_checks.get(ts),
                    "setup_url": None,
                    "tools": [],
                }
            if entry.name not in result[ts]["tools"]:
                result[ts]["tools"].append(entry.name)
            for env in entry.requires_env:
                if env not in result[ts]["env_vars"]:
                    result[ts]["env_vars"].append(env)
        return result

    def check_tool_availability(self, quiet: bool = False):
        """Return ``(available_toolsets, unavailable_info)`` like the old function.

        ``available_toolsets`` lists the toolsets whose check passes, in
        first-registration order. Each ``unavailable_info`` item is a dict
        with three keys:

        * ``name`` — the toolset name.
        * ``env_vars`` — a fresh list holding the de-duplicated union of the
          env vars required by any of the toolset's tools.
        * ``tools`` — the toolset's tool names.

        *quiet* is accepted for backward compatibility and is not used.
        """
        entries, toolset_checks = self._snapshot_state()
        # Group entries per toolset once, keeping first-registration order so
        # checks run (and results are listed) in the same order as before.
        grouped: Dict[str, List[ToolEntry]] = {}
        for entry in entries:
            grouped.setdefault(entry.toolset, []).append(entry)

        available = []
        unavailable = []
        for ts, members in grouped.items():
            if self._evaluate_toolset_check(ts, toolset_checks.get(ts)):
                available.append(ts)
                continue
            # Union across ALL tools in the toolset (not just the first one),
            # copied so callers cannot mutate an entry's requires_env list.
            env_vars: List[str] = []
            for member in members:
                for env in member.requires_env:
                    if env not in env_vars:
                        env_vars.append(env)
            unavailable.append({
                "name": ts,
                "env_vars": env_vars,
                "tools": [member.name for member in members],
            })
        return available, unavailable
489  
490  
# Module-level singleton. Tool modules call ``registry.register(...)`` on it
# at import time, and model_tools.py queries it instead of keeping its own
# parallel data structures.
registry = ToolRegistry()
493  
494  
495  # ---------------------------------------------------------------------------
496  # Helpers for tool response serialization
497  # ---------------------------------------------------------------------------
498  # Every tool handler must return a JSON string.  These helpers eliminate the
499  # boilerplate ``json.dumps({"error": msg}, ensure_ascii=False)`` that appears
500  # hundreds of times across tool files.
501  #
502  # Usage:
503  #   from tools.registry import registry, tool_error, tool_result
504  #
505  #   return tool_error("something went wrong")
506  #   return tool_error("not found", code=404)
507  #   return tool_result(success=True, data=payload)
508  #   return tool_result(items)            # pass a dict directly
509  
510  
def tool_error(message, **extra) -> str:
    """Serialize an error payload for a tool handler to return.

    *message* is coerced with ``str()`` and stored under ``"error"``, which is
    always the first key. Keyword arguments become additional top-level keys;
    a keyword named ``error`` would replace the message. Non-ASCII characters
    are emitted as-is rather than escaped.

    >>> tool_error("file not found")
    '{"error": "file not found"}'
    >>> tool_error("bad input", success=False)
    '{"error": "bad input", "success": false}'
    """
    payload = {"error": str(message), **extra}
    return json.dumps(payload, ensure_ascii=False)
523  
524  
def tool_result(data=None, **kwargs) -> str:
    """Serialize a successful tool payload to a JSON string.

    Pass either a ready-made object as *data* or build the payload from
    keyword arguments — not both. When *data* is not ``None`` it is
    serialized as-is and any keyword arguments are ignored. Non-ASCII
    characters are emitted as-is rather than escaped.

    >>> tool_result(success=True, count=42)
    '{"success": true, "count": 42}'
    >>> tool_result({"key": "value"})
    '{"key": "value"}'
    """
    payload = kwargs if data is None else data
    return json.dumps(payload, ensure_ascii=False)