test_registry.py
"""Tests for the central tool registry."""

import json
import threading
from pathlib import Path
from unittest.mock import patch

from tools.registry import ToolRegistry, discover_builtin_tools


def _dummy_handler(args, **kwargs):
    """Handler stub: ignore the arguments and return ``{"ok": true}`` as JSON."""
    return json.dumps({"ok": True})


def _make_schema(name="test_tool"):
    """Build a minimal tool schema named *name* that takes no parameters."""
    return {
        "name": name,
        "description": f"A {name}",
        "parameters": {"type": "object", "properties": {}},
    }


def _raising_check(exc_type, *exc_args):
    """Return a zero-argument ``check_fn`` that raises ``exc_type(*exc_args)``.

    A fresh exception instance is created on every call. This replaces the
    ``lambda: (_ for _ in ()).throw(exc)`` trick, which raises from inside a
    lambda but obscures the intent.
    """

    def check():
        raise exc_type(*exc_args)

    return check


def _race_read_against_write(read, write):
    """Run ``read(reg)`` and ``write(reg)`` in two threads so the write lands mid-read.

    The registry gets two tools:

    * ``alpha`` in toolset ``gated``, whose check_fn blocks.
    * ``beta`` in toolset ``plain``, with no check_fn.

    The writer thread waits until the reader is inside ``alpha``'s check_fn,
    then performs ``write(reg)`` and releases it. The check_fn waits at most
    1 s for that release.

    Exceptions from either thread are collected in ``errors``. An uncaught
    exception in a ``threading.Thread`` only reaches
    ``threading.excepthook``; it is not raised in the calling test, so
    without this capture a failed write would surface later as an unrelated
    assertion or ``KeyError``.

    Returns:
        ``(result, errors, writer_finished_during_check)``:

        * *result* is the reader's return value, or ``None`` if the reader
          raised.
        * *errors* lists exceptions from both threads.
        * *writer_finished_during_check* is ``True`` only if the write
          completed while the check_fn was blocked. It is ``None`` if the
          check_fn never ran.
    """
    reg = ToolRegistry()
    check_started = threading.Event()
    writer_done = threading.Event()
    errors = []
    outcome = {}

    def blocking_check():
        check_started.set()
        # Hold the reader inside the availability check until the writer
        # has mutated the registry (or 1 s elapses).
        outcome["writer_finished"] = writer_done.wait(timeout=1)
        return True

    reg.register(
        name="alpha",
        toolset="gated",
        schema=_make_schema("alpha"),
        handler=_dummy_handler,
        check_fn=blocking_check,
    )
    reg.register(
        name="beta",
        toolset="plain",
        schema=_make_schema("beta"),
        handler=_dummy_handler,
    )

    def reader():
        try:
            outcome["result"] = read(reg)
        except Exception as exc:  # pragma: no cover - exercised on failure only
            errors.append(exc)

    def writer():
        try:
            # Explicit raise rather than ``assert`` so the guard survives -O.
            if not check_started.wait(timeout=1):
                raise AssertionError("reader never entered the blocking check_fn")
            write(reg)
            writer_done.set()
        except Exception as exc:  # pragma: no cover - exercised on failure only
            errors.append(exc)

    reader_thread = threading.Thread(target=reader)
    writer_thread = threading.Thread(target=writer)
    reader_thread.start()
    writer_thread.start()
    reader_thread.join(timeout=2)
    writer_thread.join(timeout=2)

    assert not reader_thread.is_alive(), "reader thread did not finish"
    assert not writer_thread.is_alive(), "writer thread did not finish"
    return outcome.get("result"), errors, outcome.get("writer_finished")


class TestRegisterAndDispatch:
    def test_register_and_dispatch(self):
        reg = ToolRegistry()
        reg.register(
            name="alpha",
            toolset="core",
            schema=_make_schema("alpha"),
            handler=_dummy_handler,
        )
        result = json.loads(reg.dispatch("alpha", {}))
        assert result == {"ok": True}

    def test_dispatch_passes_args(self):
        reg = ToolRegistry()

        def echo_handler(args, **kw):
            return json.dumps(args)

        reg.register(
            name="echo",
            toolset="core",
            schema=_make_schema("echo"),
            handler=echo_handler,
        )
        result = json.loads(reg.dispatch("echo", {"msg": "hi"}))
        assert result == {"msg": "hi"}


class TestGetDefinitions:
    def test_returns_openai_format(self):
        reg = ToolRegistry()
        reg.register(
            name="t1", toolset="s1", schema=_make_schema("t1"), handler=_dummy_handler
        )
        reg.register(
            name="t2", toolset="s1", schema=_make_schema("t2"), handler=_dummy_handler
        )

        defs = reg.get_definitions({"t1", "t2"})
        assert len(defs) == 2
        # Each definition is wrapped as {"type": "function", "function": {...}}.
        assert all(d["type"] == "function" for d in defs)
        names = {d["function"]["name"] for d in defs}
        assert names == {"t1", "t2"}

    def test_skips_unavailable_tools(self):
        reg = ToolRegistry()
        reg.register(
            name="available",
            toolset="s",
            schema=_make_schema("available"),
            handler=_dummy_handler,
            check_fn=lambda: True,
        )
        reg.register(
            name="unavailable",
            toolset="s",
            schema=_make_schema("unavailable"),
            handler=_dummy_handler,
            check_fn=lambda: False,
        )
        defs = reg.get_definitions({"available", "unavailable"})
        assert len(defs) == 1
        assert defs[0]["function"]["name"] == "available"

    def test_reuses_shared_check_fn_once_per_call(self):
        """A check_fn shared by several tools runs once per get_definitions call."""
        reg = ToolRegistry()
        calls = {"count": 0}

        def shared_check():
            calls["count"] += 1
            return True

        reg.register(
            name="first",
            toolset="shared",
            schema=_make_schema("first"),
            handler=_dummy_handler,
            check_fn=shared_check,
        )
        reg.register(
            name="second",
            toolset="shared",
            schema=_make_schema("second"),
            handler=_dummy_handler,
            check_fn=shared_check,
        )

        defs = reg.get_definitions({"first", "second"})
        assert len(defs) == 2
        assert calls["count"] == 1


class TestUnknownToolDispatch:
    def test_returns_error_json(self):
        reg = ToolRegistry()
        result = json.loads(reg.dispatch("nonexistent", {}))
        assert "error" in result
        assert "Unknown tool" in result["error"]


class TestToolsetAvailability:
    def test_no_check_fn_is_available(self):
        reg = ToolRegistry()
        reg.register(
            name="t", toolset="free", schema=_make_schema(), handler=_dummy_handler
        )
        assert reg.is_toolset_available("free") is True

    def test_check_fn_controls_availability(self):
        reg = ToolRegistry()
        reg.register(
            name="t",
            toolset="locked",
            schema=_make_schema(),
            handler=_dummy_handler,
            check_fn=lambda: False,
        )
        assert reg.is_toolset_available("locked") is False

    def test_check_toolset_requirements(self):
        reg = ToolRegistry()
        reg.register(
            name="a",
            toolset="ok",
            schema=_make_schema(),
            handler=_dummy_handler,
            check_fn=lambda: True,
        )
        reg.register(
            name="b",
            toolset="nope",
            schema=_make_schema(),
            handler=_dummy_handler,
            check_fn=lambda: False,
        )

        reqs = reg.check_toolset_requirements()
        assert reqs["ok"] is True
        assert reqs["nope"] is False

    def test_get_all_tool_names(self):
        reg = ToolRegistry()
        reg.register(
            name="z_tool", toolset="s", schema=_make_schema(), handler=_dummy_handler
        )
        reg.register(
            name="a_tool", toolset="s", schema=_make_schema(), handler=_dummy_handler
        )
        # Names come back sorted, not in registration order.
        assert reg.get_all_tool_names() == ["a_tool", "z_tool"]

    def test_get_registered_toolset_names(self):
        reg = ToolRegistry()
        reg.register(
            name="first", toolset="zeta", schema=_make_schema(), handler=_dummy_handler
        )
        reg.register(
            name="second", toolset="alpha", schema=_make_schema(), handler=_dummy_handler
        )
        reg.register(
            name="third", toolset="alpha", schema=_make_schema(), handler=_dummy_handler
        )
        # Sorted and de-duplicated.
        assert reg.get_registered_toolset_names() == ["alpha", "zeta"]

    def test_get_tool_names_for_toolset(self):
        reg = ToolRegistry()
        reg.register(
            name="z_tool", toolset="grouped", schema=_make_schema(), handler=_dummy_handler
        )
        reg.register(
            name="a_tool", toolset="grouped", schema=_make_schema(), handler=_dummy_handler
        )
        reg.register(
            name="other_tool", toolset="other", schema=_make_schema(), handler=_dummy_handler
        )
        assert reg.get_tool_names_for_toolset("grouped") == ["a_tool", "z_tool"]

    def test_handler_exception_returns_error(self):
        reg = ToolRegistry()

        def bad_handler(args, **kw):
            raise RuntimeError("boom")

        reg.register(
            name="bad", toolset="s", schema=_make_schema(), handler=bad_handler
        )
        result = json.loads(reg.dispatch("bad", {}))
        assert "error" in result
        assert "RuntimeError" in result["error"]


class TestCheckFnExceptionHandling:
    """Verify that a raising check_fn is caught rather than crashing."""

    def test_is_toolset_available_catches_exception(self):
        reg = ToolRegistry()
        reg.register(
            name="t",
            toolset="broken",
            schema=_make_schema(),
            handler=_dummy_handler,
            check_fn=lambda: 1 / 0,  # ZeroDivisionError
        )
        # Should return False, not raise
        assert reg.is_toolset_available("broken") is False

    def test_check_toolset_requirements_survives_raising_check(self):
        reg = ToolRegistry()
        reg.register(
            name="a",
            toolset="good",
            schema=_make_schema(),
            handler=_dummy_handler,
            check_fn=lambda: True,
        )
        reg.register(
            name="b",
            toolset="bad",
            schema=_make_schema(),
            handler=_dummy_handler,
            check_fn=_raising_check(ImportError, "no module"),
        )

        reqs = reg.check_toolset_requirements()
        assert reqs["good"] is True
        assert reqs["bad"] is False

    def test_get_definitions_skips_raising_check(self):
        reg = ToolRegistry()
        reg.register(
            name="ok_tool",
            toolset="s",
            schema=_make_schema("ok_tool"),
            handler=_dummy_handler,
            check_fn=lambda: True,
        )
        reg.register(
            name="bad_tool",
            toolset="s2",
            schema=_make_schema("bad_tool"),
            handler=_dummy_handler,
            check_fn=_raising_check(OSError, "network down"),
        )
        defs = reg.get_definitions({"ok_tool", "bad_tool"})
        assert len(defs) == 1
        assert defs[0]["function"]["name"] == "ok_tool"

    def test_check_tool_availability_survives_raising_check(self):
        reg = ToolRegistry()
        reg.register(
            name="a",
            toolset="works",
            schema=_make_schema(),
            handler=_dummy_handler,
            check_fn=lambda: True,
        )
        reg.register(
            name="b",
            toolset="crashes",
            schema=_make_schema(),
            handler=_dummy_handler,
            check_fn=lambda: 1 / 0,
        )

        available, unavailable = reg.check_tool_availability()
        assert "works" in available
        assert any(u["name"] == "crashes" for u in unavailable)


class TestBuiltinDiscovery:
    def test_matches_previous_manual_builtin_tool_set(self):
        expected = {
            "tools.browser_cdp_tool",
            "tools.browser_dialog_tool",
            "tools.browser_tool",
            "tools.clarify_tool",
            "tools.code_execution_tool",
            "tools.cronjob_tools",
            "tools.delegate_tool",
            "tools.discord_tool",
            "tools.feishu_doc_tool",
            "tools.feishu_drive_tool",
            "tools.file_tools",
            "tools.homeassistant_tool",
            "tools.image_generation_tool",
            "tools.kanban_tools",
            "tools.memory_tool",
            "tools.mixture_of_agents_tool",
            "tools.process_registry",
            "tools.rl_training_tool",
            "tools.send_message_tool",
            "tools.session_search_tool",
            "tools.skill_manager_tool",
            "tools.skills_tool",
            "tools.terminal_tool",
            "tools.todo_tool",
            "tools.tts_tool",
            "tools.vision_tools",
            "tools.web_tools",
            "tools.yuanbao_tools",
        }

        # import_module is patched so discovery reports module names without
        # actually importing them. NOTE(review): parents[2] assumes this file
        # sits two directories below the repo root (e.g. tests/tools/) —
        # confirm if the test is moved.
        with patch("tools.registry.importlib.import_module"):
            imported = discover_builtin_tools(Path(__file__).resolve().parents[2] / "tools")

        assert set(imported) == expected

    def test_imports_only_self_registering_modules(self, tmp_path):
        tools_dir = tmp_path / "tools"
        tools_dir.mkdir()
        (tools_dir / "__init__.py").write_text("", encoding="utf-8")
        (tools_dir / "registry.py").write_text("", encoding="utf-8")
        (tools_dir / "alpha.py").write_text(
            "from tools.registry import registry\nregistry.register(name='alpha', toolset='x', schema={}, handler=lambda *_a, **_k: '{}')\n",
            encoding="utf-8",
        )
        # beta.py defines a value but never calls registry.register.
        (tools_dir / "beta.py").write_text("VALUE = 1\n", encoding="utf-8")

        with patch("tools.registry.importlib.import_module") as mock_import:
            imported = discover_builtin_tools(tools_dir)

        assert imported == ["tools.alpha"]
        mock_import.assert_called_once_with("tools.alpha")

    def test_skips_mcp_tool_even_if_it_registers(self, tmp_path):
        tools_dir = tmp_path / "tools"
        tools_dir.mkdir()
        (tools_dir / "__init__.py").write_text("", encoding="utf-8")
        (tools_dir / "mcp_tool.py").write_text(
            "from tools.registry import registry\nregistry.register(name='mcp_alpha', toolset='mcp-test', schema={}, handler=lambda *_a, **_k: '{}')\n",
            encoding="utf-8",
        )
        (tools_dir / "alpha.py").write_text(
            "from tools.registry import registry\nregistry.register(name='alpha', toolset='x', schema={}, handler=lambda *_a, **_k: '{}')\n",
            encoding="utf-8",
        )

        with patch("tools.registry.importlib.import_module") as mock_import:
            imported = discover_builtin_tools(tools_dir)

        assert imported == ["tools.alpha"]
        mock_import.assert_called_once_with("tools.alpha")


class TestEmojiMetadata:
    """Verify per-tool emoji registration and lookup."""

    def test_emoji_stored_on_entry(self):
        reg = ToolRegistry()
        reg.register(
            name="t", toolset="s", schema=_make_schema(),
            handler=_dummy_handler, emoji="🔥",
        )
        assert reg._tools["t"].emoji == "🔥"

    def test_get_emoji_returns_registered(self):
        reg = ToolRegistry()
        reg.register(
            name="t", toolset="s", schema=_make_schema(),
            handler=_dummy_handler, emoji="🎯",
        )
        assert reg.get_emoji("t") == "🎯"

    def test_get_emoji_returns_default_when_unset(self):
        reg = ToolRegistry()
        reg.register(
            name="t", toolset="s", schema=_make_schema(),
            handler=_dummy_handler,
        )
        assert reg.get_emoji("t") == "⚡"
        assert reg.get_emoji("t", default="🔧") == "🔧"

    def test_get_emoji_returns_default_for_unknown_tool(self):
        reg = ToolRegistry()
        assert reg.get_emoji("nonexistent") == "⚡"
        assert reg.get_emoji("nonexistent", default="❓") == "❓"

    def test_emoji_empty_string_treated_as_unset(self):
        reg = ToolRegistry()
        reg.register(
            name="t", toolset="s", schema=_make_schema(),
            handler=_dummy_handler, emoji="",
        )
        assert reg.get_emoji("t") == "⚡"


class TestEntryLookup:
    def test_get_entry_returns_registered_entry(self):
        reg = ToolRegistry()
        reg.register(
            name="alpha", toolset="core", schema=_make_schema("alpha"), handler=_dummy_handler
        )
        entry = reg.get_entry("alpha")
        assert entry is not None
        assert entry.name == "alpha"
        assert entry.toolset == "core"

    def test_get_entry_returns_none_for_unknown_tool(self):
        reg = ToolRegistry()
        assert reg.get_entry("missing") is None


class TestSecretCaptureResultContract:
    def test_secret_request_result_does_not_include_secret_value(self):
        # NOTE(review): this asserts against a dict literal built right here,
        # so it exercises no registry or secret-capture code. Presumably it
        # documents the intended result shape. TODO: point it at the real
        # producer of this result.
        result = {
            "success": True,
            "stored_as": "TENOR_API_KEY",
            "validated": False,
        }
        assert "secret" not in json.dumps(result).lower()


class TestThreadSafety:
    def test_get_available_toolsets_uses_coherent_snapshot(self, monkeypatch):
        """get_available_toolsets builds its result from the snapshot it took.

        The patched ``_snapshot_state`` deregisters ``alpha`` from the live
        registry but still returns the earlier snapshot. The result must
        reflect that snapshot, so ``alpha`` is still listed under ``gated``.
        """
        reg = ToolRegistry()
        reg.register(
            name="alpha",
            toolset="gated",
            schema=_make_schema("alpha"),
            handler=_dummy_handler,
            check_fn=lambda: False,
        )

        entries, toolset_checks = reg._snapshot_state()

        def snapshot_then_mutate():
            reg.deregister("alpha")
            return entries, toolset_checks

        monkeypatch.setattr(reg, "_snapshot_state", snapshot_then_mutate)

        toolsets = reg.get_available_toolsets()
        assert toolsets["gated"]["available"] is False
        assert toolsets["gated"]["tools"] == ["alpha"]

    def test_check_tool_availability_tolerates_concurrent_register(self):
        def register_gamma(reg):
            reg.register(
                name="gamma",
                toolset="new",
                schema=_make_schema("gamma"),
                handler=_dummy_handler,
            )

        result, errors, writer_finished = _race_read_against_write(
            lambda reg: reg.check_tool_availability(), register_gamma
        )

        # Check thread errors first so a failing writer is reported directly.
        assert errors == []
        assert writer_finished is True

        available, unavailable = result
        assert "gated" in available
        assert "plain" in available
        assert unavailable == []

    def test_get_available_toolsets_tolerates_concurrent_deregister(self):
        result, errors, writer_finished = _race_read_against_write(
            lambda reg: reg.get_available_toolsets(),
            lambda reg: reg.deregister("beta"),
        )

        # Check thread errors first so a failing writer is reported directly.
        assert errors == []
        assert writer_finished is True

        toolsets = result
        assert "gated" in toolsets
        assert toolsets["gated"]["available"] is True