registry.py
"""Central registry for all hermes-agent tools.

Each tool file calls ``registry.register()`` at module level to declare its
schema, handler, toolset membership, and availability check.  ``model_tools.py``
queries the registry instead of maintaining its own parallel data structures.

Import chain (circular-import safe):
    tools/registry.py  (no imports from model_tools or tool files)
           ^
    tools/*.py  (import from tools.registry at module level)
           ^
    model_tools.py  (imports tools.registry + all tool modules)
           ^
    run_agent.py, cli.py, batch_runner.py, etc.
"""

import ast
import importlib
import json
import logging
import threading
import time
from pathlib import Path
from typing import Callable, Dict, List, Optional, Set

logger = logging.getLogger(__name__)


def _is_registry_register_call(node: ast.AST) -> bool:
    """Return True when *node* is a ``registry.register(...)`` call expression.

    Only a bare expression statement whose call target is literally the
    attribute ``register`` on the name ``registry`` matches; assignments,
    aliased names, and nested attribute chains do not.
    """
    if not isinstance(node, ast.Expr) or not isinstance(node.value, ast.Call):
        return False
    func = node.value.func
    return (
        isinstance(func, ast.Attribute)
        and func.attr == "register"
        and isinstance(func.value, ast.Name)
        and func.value.id == "registry"
    )


def _module_registers_tools(module_path: Path) -> bool:
    """Return True when the module contains a top-level ``registry.register(...)`` call.

    Only inspects module-body statements so that helper modules which happen
    to call ``registry.register()`` inside a function are not picked up.

    A file that cannot be read, decoded as UTF-8, or parsed is reported as
    "does not register tools" (False) instead of raising, so one broken file
    cannot abort discovery of every other tool module.
    """
    try:
        source = module_path.read_text(encoding="utf-8")
        tree = ast.parse(source, filename=str(module_path))
    except (OSError, SyntaxError, ValueError):
        # ValueError covers UnicodeDecodeError (file is not valid UTF-8) and
        # the "null bytes in source" error that ast.parse raises as a plain
        # ValueError on some interpreter versions.
        return False

    return any(_is_registry_register_call(stmt) for stmt in tree.body)


def discover_builtin_tools(tools_dir: Optional[Path] = None) -> List[str]:
    """Import built-in self-registering tool modules and return their module names.

    Args:
        tools_dir: Directory to scan for ``*.py`` files.  Defaults to the
            directory containing this file.

    Returns:
        The ``tools.<stem>`` names of the modules that were imported
        successfully, in sorted filename order.  Modules whose import raises
        are logged with a warning and omitted.
    """
    tools_path = Path(tools_dir) if tools_dir is not None else Path(__file__).resolve().parent
    module_names = [
        f"tools.{path.stem}"
        for path in sorted(tools_path.glob("*.py"))
        if path.name not in {"__init__.py", "registry.py", "mcp_tool.py"}
        and _module_registers_tools(path)
    ]

    imported: List[str] = []
    for mod_name in module_names:
        try:
            importlib.import_module(mod_name)
            imported.append(mod_name)
        except Exception as e:
            # Best-effort: a single failing tool module must not prevent the
            # remaining modules from being imported.
            logger.warning("Could not import tool module %s: %s", mod_name, e)
    return imported


class ToolEntry:
    """Metadata for a single registered tool.

    A plain attribute holder; ``__slots__`` keeps instances free of a
    per-instance ``__dict__``.
    """

    __slots__ = (
        "name", "toolset", "schema", "handler", "check_fn",
        "requires_env", "is_async", "description", "emoji",
        "max_result_size_chars",
    )

    def __init__(self, name, toolset, schema, handler, check_fn,
                 requires_env, is_async, description, emoji,
                 max_result_size_chars=None):
        self.name = name
        self.toolset = toolset
        self.schema = schema
        self.handler = handler
        self.check_fn = check_fn
        self.requires_env = requires_env
        self.is_async = is_async
        self.description = description
        self.emoji = emoji
        self.max_result_size_chars = max_result_size_chars


# ---------------------------------------------------------------------------
# check_fn TTL cache
#
# check_fn callables like tools/terminal_tool.check_terminal_requirements
# probe external state (Docker daemon, Modal SDK install, playwright binary
# availability).
# For a long-lived CLI or gateway process, calling them on
# every get_definitions() is pure waste — external state changes on human
# timescales.  Cache results for ~30 s so env-var flips via ``hermes tools``
# or live credential file changes propagate within a turn or two without
# requiring any explicit invalidation.
# ---------------------------------------------------------------------------

_CHECK_FN_TTL_SECONDS = 30.0
_check_fn_cache: Dict[Callable, tuple[float, bool]] = {}
_check_fn_cache_lock = threading.Lock()


def _check_fn_cached(fn: Callable) -> bool:
    """Return ``bool(fn())``, TTL-cached across calls.

    Exceptions raised by *fn* are swallowed and treated as ``False`` (the
    failure is logged at debug level).  The probe itself runs outside the
    cache lock so a slow check does not block other threads.
    """
    now = time.monotonic()
    with _check_fn_cache_lock:
        cached = _check_fn_cache.get(fn)
        if cached is not None:
            ts, value = cached
            if now - ts < _CHECK_FN_TTL_SECONDS:
                return value
    try:
        value = bool(fn())
    except Exception:
        logger.debug("check_fn %r raised; treating as unavailable", fn, exc_info=True)
        value = False
    with _check_fn_cache_lock:
        # Stamp with the completion time, not the start time: a probe that
        # takes longer than the TTL would otherwise be stored already expired
        # and re-run on every call.
        _check_fn_cache[fn] = (time.monotonic(), value)
    return value


def invalidate_check_fn_cache() -> None:
    """Drop all cached ``check_fn`` results.  Call after config changes that
    affect tool availability (e.g. ``hermes tools enable``)."""
    with _check_fn_cache_lock:
        _check_fn_cache.clear()


class ToolRegistry:
    """Singleton registry that collects tool schemas + handlers from tool files."""

    def __init__(self):
        self._tools: Dict[str, ToolEntry] = {}
        self._toolset_checks: Dict[str, Callable] = {}
        self._toolset_aliases: Dict[str, str] = {}
        # MCP dynamic refresh can mutate the registry while other threads are
        # reading tool metadata, so keep mutations serialized and readers on
        # stable snapshots.
        self._lock = threading.RLock()
        # Monotonically-increasing generation counter.  Bumped on every
        # mutation (register / deregister / register_toolset_alias / MCP
        # refresh).  External callers (e.g. get_tool_definitions) can memoize
        # against it: a cache entry keyed on the generation is valid for as
        # long as the generation hasn't changed.
        self._generation: int = 0

    def _snapshot_state(self) -> tuple[List[ToolEntry], Dict[str, Callable]]:
        """Return a coherent snapshot of registry entries and toolset checks."""
        with self._lock:
            return list(self._tools.values()), dict(self._toolset_checks)

    def _snapshot_entries(self) -> List[ToolEntry]:
        """Return a stable snapshot of registered tool entries."""
        return self._snapshot_state()[0]

    def _snapshot_toolset_checks(self) -> Dict[str, Callable]:
        """Return a stable snapshot of toolset availability checks."""
        return self._snapshot_state()[1]

    def _evaluate_toolset_check(self, toolset: str, check: Callable | None) -> bool:
        """Run a toolset availability check.

        A missing check (``None``) means the toolset is available; a check
        that raises means it is unavailable.  Otherwise the truthiness of the
        check's return value decides.
        """
        if not check:
            return True
        try:
            return bool(check())
        except Exception:
            logger.debug("Toolset %s check raised; marking unavailable", toolset)
            return False

    def get_entry(self, name: str) -> Optional[ToolEntry]:
        """Return a registered tool entry by name, or None."""
        with self._lock:
            return self._tools.get(name)

    def get_registered_toolset_names(self) -> List[str]:
        """Return sorted unique toolset names present in the registry."""
        return sorted({entry.toolset for entry in self._snapshot_entries()})

    def get_tool_names_for_toolset(self, toolset: str) -> List[str]:
        """Return sorted tool names registered under a given toolset."""
        return sorted(
            entry.name for entry in self._snapshot_entries()
            if entry.toolset == toolset
        )

    def register_toolset_alias(self, alias: str, toolset: str) -> None:
        """Register an explicit alias for a canonical toolset name.

        Re-pointing an existing alias at a different toolset is allowed but
        logged as a warning.
        """
        with self._lock:
            existing = self._toolset_aliases.get(alias)
            if existing and existing != toolset:
                logger.warning(
                    "Toolset alias collision: '%s' (%s) overwritten by %s",
                    alias, existing, toolset,
                )
            self._toolset_aliases[alias] = toolset
            self._generation += 1

    def get_registered_toolset_aliases(self) -> Dict[str, str]:
        """Return a snapshot of ``{alias: canonical_toolset}`` mappings."""
        with self._lock:
            return dict(self._toolset_aliases)

    def get_toolset_alias_target(self, alias: str) -> Optional[str]:
        """Return the canonical toolset name for an alias, or None."""
        with self._lock:
            return self._toolset_aliases.get(alias)

    # ------------------------------------------------------------------
    # Registration
    # ------------------------------------------------------------------

    def register(
        self,
        name: str,
        toolset: str,
        schema: dict,
        handler: Callable,
        check_fn: Optional[Callable] = None,
        requires_env: Optional[list] = None,
        is_async: bool = False,
        description: str = "",
        emoji: str = "",
        max_result_size_chars: int | float | None = None,
    ):
        """Register a tool.  Called at module-import time by each tool file.

        Re-registering a name under the same toolset replaces the entry.
        Registering a name that already belongs to a *different* toolset is
        rejected (logged as an error, registry unchanged) unless both
        toolsets are MCP toolsets (names starting with ``mcp-``).

        The first ``check_fn`` registered for a toolset becomes that
        toolset's availability check.  When *description* is empty, the
        schema's ``"description"`` value is used.
        """
        with self._lock:
            existing = self._tools.get(name)
            if existing and existing.toolset != toolset:
                # Allow MCP-to-MCP overwrites (legitimate: server refresh,
                # or two MCP servers with overlapping tool names).
                both_mcp = (
                    existing.toolset.startswith("mcp-")
                    and toolset.startswith("mcp-")
                )
                if both_mcp:
                    logger.debug(
                        "Tool '%s': MCP toolset '%s' overwriting MCP toolset '%s'",
                        name, toolset, existing.toolset,
                    )
                else:
                    # Reject shadowing — prevent plugins/MCP from overwriting
                    # built-in tools or vice versa.
                    logger.error(
                        "Tool registration REJECTED: '%s' (toolset '%s') would "
                        "shadow existing tool from toolset '%s'. Deregister the "
                        "existing tool first if this is intentional.",
                        name, toolset, existing.toolset,
                    )
                    return
            self._tools[name] = ToolEntry(
                name=name,
                toolset=toolset,
                schema=schema,
                handler=handler,
                check_fn=check_fn,
                requires_env=requires_env or [],
                is_async=is_async,
                description=description or schema.get("description", ""),
                emoji=emoji,
                max_result_size_chars=max_result_size_chars,
            )
            if check_fn and toolset not in self._toolset_checks:
                self._toolset_checks[toolset] = check_fn
            self._generation += 1

    def deregister(self, name: str) -> None:
        """Remove a tool from the registry.

        Also cleans up the toolset check if no other tools remain in the
        same toolset.  Used by MCP dynamic tool discovery to nuke-and-repave
        when a server sends ``notifications/tools/list_changed``.
        """
        with self._lock:
            entry = self._tools.pop(name, None)
            if entry is None:
                return
            # Drop the toolset check and aliases if this was the last tool in
            # that toolset.
            toolset_still_exists = any(
                e.toolset == entry.toolset for e in self._tools.values()
            )
            if not toolset_still_exists:
                self._toolset_checks.pop(entry.toolset, None)
                self._toolset_aliases = {
                    alias: target
                    for alias, target in self._toolset_aliases.items()
                    if target != entry.toolset
                }
            self._generation += 1
        logger.debug("Deregistered tool: %s", name)

    # ------------------------------------------------------------------
    # Schema retrieval
    # ------------------------------------------------------------------

    def get_definitions(self, tool_names: Set[str], quiet: bool = False) -> List[dict]:
        """Return OpenAI-format tool schemas for the requested tool names.

        Only tools whose ``check_fn()`` returns True (or have no check_fn)
        are included.  ``check_fn()`` results are cached for ~30 s via
        :func:`_check_fn_cached` to amortize repeat probes
        (check_terminal_requirements probes modal/docker, browser checks
        probe playwright, etc.); TTL chosen so env-var changes
        (``hermes tools enable foo``) still take effect in near-real-time
        without forcing a full cache flush on every call.

        Unknown names are skipped.  Results are ordered by tool name.  When
        *quiet* is true the per-tool "unavailable" debug log is suppressed.
        """
        result = []
        # Per-call cache on top of the 30 s TTL — handles repeat probes of the
        # same check_fn within one definitions pass without re-reading the
        # TTL clock.
        check_results: Dict[Callable, bool] = {}
        entries_by_name = {entry.name: entry for entry in self._snapshot_entries()}
        for name in sorted(tool_names):
            entry = entries_by_name.get(name)
            if not entry:
                continue
            if entry.check_fn:
                if entry.check_fn not in check_results:
                    check_results[entry.check_fn] = _check_fn_cached(entry.check_fn)
                if not check_results[entry.check_fn]:
                    if not quiet:
                        logger.debug("Tool %s unavailable (check failed)", name)
                    continue
            # The emitted schema always carries the registered name: it is
            # added when missing and overrides any "name" already present.
            schema_with_name = {**entry.schema, "name": entry.name}
            result.append({"type": "function", "function": schema_with_name})
        return result

    # ------------------------------------------------------------------
    # Dispatch
    # ------------------------------------------------------------------

    def dispatch(self, name: str, args: dict, **kwargs) -> str:
        """Execute a tool handler by name.

        * Async handlers are bridged automatically via ``_run_async()``.
        * All exceptions are caught and returned as ``{"error": "..."}``
          for consistent error format.
        * An unregistered *name* yields an ``Unknown tool`` error payload.
        """
        entry = self.get_entry(name)
        if not entry:
            return json.dumps({"error": f"Unknown tool: {name}"})
        try:
            if entry.is_async:
                # Imported lazily: model_tools imports this module, so a
                # top-level import here would be circular.
                from model_tools import _run_async
                return _run_async(entry.handler(args, **kwargs))
            return entry.handler(args, **kwargs)
        except Exception as e:
            logger.exception("Tool %s dispatch error: %s", name, e)
            return json.dumps({"error": f"Tool execution failed: {type(e).__name__}: {e}"})

    # ------------------------------------------------------------------
    # Query helpers  (replace redundant dicts in model_tools.py)
    # ------------------------------------------------------------------

    def get_max_result_size(self, name: str, default: int | float | None = None) -> int | float:
        """Return per-tool max result size, or *default* (or global default)."""
        entry = self.get_entry(name)
        if entry and entry.max_result_size_chars is not None:
            return entry.max_result_size_chars
        if default is not None:
            return default
        from tools.budget_config import DEFAULT_RESULT_SIZE_CHARS
        return DEFAULT_RESULT_SIZE_CHARS

    def get_all_tool_names(self) -> List[str]:
        """Return sorted list of all registered tool names."""
        return sorted(entry.name for entry in self._snapshot_entries())

    def get_schema(self, name: str) -> Optional[dict]:
        """Return a tool's raw schema dict, bypassing check_fn filtering.

        Useful for token estimation and introspection where availability
        doesn't matter — only the schema content does.
        """
        entry = self.get_entry(name)
        return entry.schema if entry else None

    def get_toolset_for_tool(self, name: str) -> Optional[str]:
        """Return the toolset a tool belongs to, or None."""
        entry = self.get_entry(name)
        return entry.toolset if entry else None

    def get_emoji(self, name: str, default: str = "⚡") -> str:
        """Return the emoji for a tool, or *default* if unset."""
        entry = self.get_entry(name)
        return (entry.emoji if entry and entry.emoji else default)

    def get_tool_to_toolset_map(self) -> Dict[str, str]:
        """Return ``{tool_name: toolset_name}`` for every registered tool."""
        return {entry.name: entry.toolset for entry in self._snapshot_entries()}

    def is_toolset_available(self, toolset: str) -> bool:
        """Check if a toolset's requirements are met.

        Returns False (rather than crashing) when the check function raises
        an unexpected exception (e.g. network error, missing import, bad config).
        """
        with self._lock:
            check = self._toolset_checks.get(toolset)
        return self._evaluate_toolset_check(toolset, check)

    def check_toolset_requirements(self) -> Dict[str, bool]:
        """Return ``{toolset: available_bool}`` for every toolset."""
        entries, toolset_checks = self._snapshot_state()
        toolsets = sorted({entry.toolset for entry in entries})
        return {
            toolset: self._evaluate_toolset_check(toolset, toolset_checks.get(toolset))
            for toolset in toolsets
        }

    def get_available_toolsets(self) -> Dict[str, dict]:
        """Return toolset metadata for UI display."""
        toolsets: Dict[str, dict] = {}
        entries, toolset_checks = self._snapshot_state()
        for entry in entries:
            ts = entry.toolset
            if ts not in toolsets:
                toolsets[ts] = {
                    "available": self._evaluate_toolset_check(
                        ts, toolset_checks.get(ts)
                    ),
                    "tools": [],
                    "description": "",
                    "requirements": [],
                }
            toolsets[ts]["tools"].append(entry.name)
            if entry.requires_env:
                for env in entry.requires_env:
                    if env not in toolsets[ts]["requirements"]:
                        toolsets[ts]["requirements"].append(env)
        return toolsets

    def get_toolset_requirements(self) -> Dict[str, dict]:
        """Build a TOOLSET_REQUIREMENTS-compatible dict for backward compat."""
        result: Dict[str, dict] = {}
        entries, toolset_checks = self._snapshot_state()
        for entry in entries:
            ts = entry.toolset
            if ts not in result:
                result[ts] = {
                    "name": ts,
                    "env_vars": [],
                    "check_fn": toolset_checks.get(ts),
                    "setup_url": None,
                    "tools": [],
                }
            if entry.name not in result[ts]["tools"]:
                result[ts]["tools"].append(entry.name)
            for env in entry.requires_env:
                if env not in result[ts]["env_vars"]:
                    result[ts]["env_vars"].append(env)
        return result

    def check_tool_availability(self, quiet: bool = False):
        """Return (available_toolsets, unavailable_info) like the old function.

        ``available_toolsets`` is a list of toolset names whose check passed
        (or that have no check).  ``unavailable_info`` is a list of dicts with
        keys ``name``, ``env_vars`` and ``tools``.  ``env_vars`` is the
        de-duplicated union of ``requires_env`` across every tool in the
        toolset, matching :meth:`get_toolset_requirements`, and is a fresh
        list that callers may mutate freely.  Both lists are in first-seen
        toolset order.  *quiet* is accepted for interface compatibility and
        is not used.
        """
        available = []
        unavailable = []
        entries, toolset_checks = self._snapshot_state()

        # Group tool names and env vars per toolset in one pass; dicts keep
        # insertion order, so toolsets stay in first-seen order.
        tools_by_toolset: Dict[str, List[str]] = {}
        env_by_toolset: Dict[str, List[str]] = {}
        for entry in entries:
            tools_by_toolset.setdefault(entry.toolset, []).append(entry.name)
            env_vars = env_by_toolset.setdefault(entry.toolset, [])
            for env in entry.requires_env:
                if env not in env_vars:
                    env_vars.append(env)

        for ts, tool_names in tools_by_toolset.items():
            if self._evaluate_toolset_check(ts, toolset_checks.get(ts)):
                available.append(ts)
            else:
                unavailable.append({
                    "name": ts,
                    "env_vars": env_by_toolset[ts],
                    "tools": tool_names,
                })
        return available, unavailable


# Module-level singleton
registry = ToolRegistry()


# ---------------------------------------------------------------------------
# Helpers for tool response serialization
# ---------------------------------------------------------------------------
# Every tool handler must return a JSON string.  These helpers eliminate the
# boilerplate ``json.dumps({"error": msg}, ensure_ascii=False)`` that appears
# hundreds of times across tool files.
#
# Usage:
#     from tools.registry import registry, tool_error, tool_result
#
#     return tool_error("something went wrong")
#     return tool_error("not found", code=404)
#     return tool_result(success=True, data=payload)
#     return tool_result(items)  # pass a dict directly


def tool_error(message, **extra) -> str:
    """Serialize an error payload for a tool handler to return.

    *message* is stringified and stored under ``"error"``; any keyword
    arguments are merged in after it (a keyword named ``error`` replaces the
    message).  Non-ASCII characters are emitted as-is.

    >>> tool_error("file not found")
    '{"error": "file not found"}'
    >>> tool_error("bad input", success=False)
    '{"error": "bad input", "success": false}'
    """
    payload = {"error": str(message), **extra}
    return json.dumps(payload, ensure_ascii=False)


def tool_result(data=None, **kwargs) -> str:
    """Serialize a success payload for a tool handler to return.

    Pass either one positional value or keyword arguments.  When *data* is
    anything other than ``None`` it is serialized on its own and the keyword
    arguments are ignored; otherwise the keyword arguments form the object.

    >>> tool_result(success=True, count=42)
    '{"success": true, "count": 42}'
    >>> tool_result({"key": "value"})
    '{"key": "value"}'
    """
    payload = kwargs if data is None else data
    return json.dumps(payload, ensure_ascii=False)