/ tests / tools / test_registry.py
test_registry.py
  1  """Tests for the central tool registry."""
  2  
  3  import json
  4  import threading
  5  from pathlib import Path
  6  from unittest.mock import patch
  7  
  8  from tools.registry import ToolRegistry, discover_builtin_tools
  9  
 10  
 11  def _dummy_handler(args, **kwargs):
 12      return json.dumps({"ok": True})
 13  
 14  
 15  def _make_schema(name="test_tool"):
 16      return {
 17          "name": name,
 18          "description": f"A {name}",
 19          "parameters": {"type": "object", "properties": {}},
 20      }
 21  
 22  
 23  class TestRegisterAndDispatch:
 24      def test_register_and_dispatch(self):
 25          reg = ToolRegistry()
 26          reg.register(
 27              name="alpha",
 28              toolset="core",
 29              schema=_make_schema("alpha"),
 30              handler=_dummy_handler,
 31          )
 32          result = json.loads(reg.dispatch("alpha", {}))
 33          assert result == {"ok": True}
 34  
 35      def test_dispatch_passes_args(self):
 36          reg = ToolRegistry()
 37  
 38          def echo_handler(args, **kw):
 39              return json.dumps(args)
 40  
 41          reg.register(
 42              name="echo",
 43              toolset="core",
 44              schema=_make_schema("echo"),
 45              handler=echo_handler,
 46          )
 47          result = json.loads(reg.dispatch("echo", {"msg": "hi"}))
 48          assert result == {"msg": "hi"}
 49  
 50  
 51  class TestGetDefinitions:
 52      def test_returns_openai_format(self):
 53          reg = ToolRegistry()
 54          reg.register(
 55              name="t1", toolset="s1", schema=_make_schema("t1"), handler=_dummy_handler
 56          )
 57          reg.register(
 58              name="t2", toolset="s1", schema=_make_schema("t2"), handler=_dummy_handler
 59          )
 60  
 61          defs = reg.get_definitions({"t1", "t2"})
 62          assert len(defs) == 2
 63          assert all(d["type"] == "function" for d in defs)
 64          names = {d["function"]["name"] for d in defs}
 65          assert names == {"t1", "t2"}
 66  
 67      def test_skips_unavailable_tools(self):
 68          reg = ToolRegistry()
 69          reg.register(
 70              name="available",
 71              toolset="s",
 72              schema=_make_schema("available"),
 73              handler=_dummy_handler,
 74              check_fn=lambda: True,
 75          )
 76          reg.register(
 77              name="unavailable",
 78              toolset="s",
 79              schema=_make_schema("unavailable"),
 80              handler=_dummy_handler,
 81              check_fn=lambda: False,
 82          )
 83          defs = reg.get_definitions({"available", "unavailable"})
 84          assert len(defs) == 1
 85          assert defs[0]["function"]["name"] == "available"
 86  
 87      def test_reuses_shared_check_fn_once_per_call(self):
 88          reg = ToolRegistry()
 89          calls = {"count": 0}
 90  
 91          def shared_check():
 92              calls["count"] += 1
 93              return True
 94  
 95          reg.register(
 96              name="first",
 97              toolset="shared",
 98              schema=_make_schema("first"),
 99              handler=_dummy_handler,
100              check_fn=shared_check,
101          )
102          reg.register(
103              name="second",
104              toolset="shared",
105              schema=_make_schema("second"),
106              handler=_dummy_handler,
107              check_fn=shared_check,
108          )
109  
110          defs = reg.get_definitions({"first", "second"})
111          assert len(defs) == 2
112          assert calls["count"] == 1
113  
114  
class TestUnknownToolDispatch:
    """Dispatching a never-registered name yields a JSON error payload."""

    def test_returns_error_json(self):
        empty_registry = ToolRegistry()

        payload = json.loads(empty_registry.dispatch("nonexistent", {}))

        assert "error" in payload
        message = payload["error"]
        assert "Unknown tool" in message
121  
122  
class TestToolsetAvailability:
    """Toolset availability checks plus the tool/toolset name listings."""

    def test_no_check_fn_is_available(self):
        registry = ToolRegistry()
        registry.register(
            name="t", toolset="free", schema=_make_schema(), handler=_dummy_handler
        )

        available = registry.is_toolset_available("free")

        assert available is True

    def test_check_fn_controls_availability(self):
        registry = ToolRegistry()
        registry.register(
            name="t",
            toolset="locked",
            schema=_make_schema(),
            handler=_dummy_handler,
            check_fn=lambda: False,
        )

        available = registry.is_toolset_available("locked")

        assert available is False

    def test_check_toolset_requirements(self):
        """The requirements map reports each toolset's check_fn verdict."""
        registry = ToolRegistry()
        specs = (
            ("a", "ok", lambda: True),
            ("b", "nope", lambda: False),
        )
        for tool_name, toolset, check in specs:
            registry.register(
                name=tool_name,
                toolset=toolset,
                schema=_make_schema(),
                handler=_dummy_handler,
                check_fn=check,
            )

        requirements = registry.check_toolset_requirements()

        assert requirements["ok"] is True
        assert requirements["nope"] is False

    def test_get_all_tool_names(self):
        """Tool names come back sorted, regardless of registration order."""
        registry = ToolRegistry()
        for tool_name in ("z_tool", "a_tool"):
            registry.register(
                name=tool_name, toolset="s", schema=_make_schema(), handler=_dummy_handler
            )

        assert registry.get_all_tool_names() == ["a_tool", "z_tool"]

    def test_get_registered_toolset_names(self):
        """Toolset names are de-duplicated and sorted."""
        registry = ToolRegistry()
        for tool_name, toolset in (("first", "zeta"), ("second", "alpha"), ("third", "alpha")):
            registry.register(
                name=tool_name, toolset=toolset, schema=_make_schema(), handler=_dummy_handler
            )

        assert registry.get_registered_toolset_names() == ["alpha", "zeta"]

    def test_get_tool_names_for_toolset(self):
        """Only tools in the requested toolset are listed, sorted."""
        registry = ToolRegistry()
        for tool_name, toolset in (
            ("z_tool", "grouped"),
            ("a_tool", "grouped"),
            ("other_tool", "other"),
        ):
            registry.register(
                name=tool_name, toolset=toolset, schema=_make_schema(), handler=_dummy_handler
            )

        assert registry.get_tool_names_for_toolset("grouped") == ["a_tool", "z_tool"]

    def test_handler_exception_returns_error(self):
        """A handler that raises produces an error payload naming the exception type."""
        registry = ToolRegistry()

        def bad_handler(args, **kw):
            raise RuntimeError("boom")

        registry.register(
            name="bad", toolset="s", schema=_make_schema(), handler=bad_handler
        )

        payload = json.loads(registry.dispatch("bad", {}))

        assert "error" in payload
        assert "RuntimeError" in payload["error"]
211  
212  
class TestCheckFnExceptionHandling:
    """Verify that a raising check_fn is caught rather than crashing."""

    @staticmethod
    def _raising(exc_type, *args):
        """Return a zero-argument check_fn that raises ``exc_type(*args)``.

        A lambda cannot contain a ``raise`` statement, so this builds a small
        named function instead. A fresh exception instance is created on
        every call, so repeated checks never share traceback state.
        """

        def check():
            raise exc_type(*args)

        return check

    def test_is_toolset_available_catches_exception(self):
        reg = ToolRegistry()
        reg.register(
            name="t",
            toolset="broken",
            schema=_make_schema(),
            handler=_dummy_handler,
            check_fn=lambda: 1 / 0,  # ZeroDivisionError
        )
        # Should return False, not raise
        assert reg.is_toolset_available("broken") is False

    def test_check_toolset_requirements_survives_raising_check(self):
        reg = ToolRegistry()
        reg.register(
            name="a",
            toolset="good",
            schema=_make_schema(),
            handler=_dummy_handler,
            check_fn=lambda: True,
        )
        reg.register(
            name="b",
            toolset="bad",
            schema=_make_schema(),
            handler=_dummy_handler,
            check_fn=self._raising(ImportError, "no module"),
        )

        # The raising toolset is reported unavailable; the healthy one is unaffected.
        reqs = reg.check_toolset_requirements()
        assert reqs["good"] is True
        assert reqs["bad"] is False

    def test_get_definitions_skips_raising_check(self):
        reg = ToolRegistry()
        reg.register(
            name="ok_tool",
            toolset="s",
            schema=_make_schema("ok_tool"),
            handler=_dummy_handler,
            check_fn=lambda: True,
        )
        reg.register(
            name="bad_tool",
            toolset="s2",
            schema=_make_schema("bad_tool"),
            handler=_dummy_handler,
            check_fn=self._raising(OSError, "network down"),
        )
        defs = reg.get_definitions({"ok_tool", "bad_tool"})
        assert len(defs) == 1
        assert defs[0]["function"]["name"] == "ok_tool"

    def test_check_tool_availability_survives_raising_check(self):
        reg = ToolRegistry()
        reg.register(
            name="a",
            toolset="works",
            schema=_make_schema(),
            handler=_dummy_handler,
            check_fn=lambda: True,
        )
        reg.register(
            name="b",
            toolset="crashes",
            schema=_make_schema(),
            handler=_dummy_handler,
            check_fn=lambda: 1 / 0,  # ZeroDivisionError
        )

        # ``unavailable`` is a list of dicts keyed by toolset "name".
        available, unavailable = reg.check_tool_availability()
        assert "works" in available
        assert any(u["name"] == "crashes" for u in unavailable)
289  
290  
class TestBuiltinDiscovery:
    """``discover_builtin_tools`` selects exactly the self-registering modules."""

    # Snapshot of the previously hand-maintained built-in tool module list.
    _EXPECTED_BUILTIN_MODULES = frozenset(
        {
            "tools.browser_cdp_tool",
            "tools.browser_dialog_tool",
            "tools.browser_tool",
            "tools.clarify_tool",
            "tools.code_execution_tool",
            "tools.cronjob_tools",
            "tools.delegate_tool",
            "tools.discord_tool",
            "tools.feishu_doc_tool",
            "tools.feishu_drive_tool",
            "tools.file_tools",
            "tools.homeassistant_tool",
            "tools.image_generation_tool",
            "tools.kanban_tools",
            "tools.memory_tool",
            "tools.mixture_of_agents_tool",
            "tools.process_registry",
            "tools.rl_training_tool",
            "tools.send_message_tool",
            "tools.session_search_tool",
            "tools.skill_manager_tool",
            "tools.skills_tool",
            "tools.terminal_tool",
            "tools.todo_tool",
            "tools.tts_tool",
            "tools.vision_tools",
            "tools.web_tools",
            "tools.yuanbao_tools",
        }
    )

    # Module source that calls registry.register(...) at import time.
    _ALPHA_SOURCE = (
        "from tools.registry import registry\n"
        "registry.register(name='alpha', toolset='x', schema={}, "
        "handler=lambda *_a, **_k: '{}')\n"
    )
    _MCP_SOURCE = (
        "from tools.registry import registry\n"
        "registry.register(name='mcp_alpha', toolset='mcp-test', schema={}, "
        "handler=lambda *_a, **_k: '{}')\n"
    )

    @staticmethod
    def _make_tools_package(root, files):
        """Create ``root/tools`` holding *files* (filename -> source text)."""
        package_dir = root / "tools"
        package_dir.mkdir()
        for filename, text in files.items():
            (package_dir / filename).write_text(text, encoding="utf-8")
        return package_dir

    def test_matches_previous_manual_builtin_tool_set(self):
        repo_tools_dir = Path(__file__).resolve().parents[2] / "tools"

        # Patch out the real import; only the discovered module names matter here.
        with patch("tools.registry.importlib.import_module"):
            imported = discover_builtin_tools(repo_tools_dir)

        assert set(imported) == self._EXPECTED_BUILTIN_MODULES

    def test_imports_only_self_registering_modules(self, tmp_path):
        tools_dir = self._make_tools_package(
            tmp_path,
            {
                "__init__.py": "",
                "registry.py": "",
                "alpha.py": self._ALPHA_SOURCE,
                "beta.py": "VALUE = 1\n",
            },
        )

        with patch("tools.registry.importlib.import_module") as mock_import:
            imported = discover_builtin_tools(tools_dir)

        assert imported == ["tools.alpha"]
        mock_import.assert_called_once_with("tools.alpha")

    def test_skips_mcp_tool_even_if_it_registers(self, tmp_path):
        tools_dir = self._make_tools_package(
            tmp_path,
            {
                "__init__.py": "",
                "mcp_tool.py": self._MCP_SOURCE,
                "alpha.py": self._ALPHA_SOURCE,
            },
        )

        with patch("tools.registry.importlib.import_module") as mock_import:
            imported = discover_builtin_tools(tools_dir)

        assert imported == ["tools.alpha"]
        mock_import.assert_called_once_with("tools.alpha")
364  
365  
class TestEmojiMetadata:
    """Per-tool emoji: stored at registration and read back via ``get_emoji``."""

    @staticmethod
    def _registry_with_tool(**extra):
        """Return a registry holding one tool "t"; *extra* goes to ``register``."""
        registry = ToolRegistry()
        registry.register(
            name="t",
            toolset="s",
            schema=_make_schema(),
            handler=_dummy_handler,
            **extra,
        )
        return registry

    def test_emoji_stored_on_entry(self):
        registry = self._registry_with_tool(emoji="🔥")
        assert registry._tools["t"].emoji == "🔥"

    def test_get_emoji_returns_registered(self):
        registry = self._registry_with_tool(emoji="🎯")
        assert registry.get_emoji("t") == "🎯"

    def test_get_emoji_returns_default_when_unset(self):
        registry = self._registry_with_tool()
        assert registry.get_emoji("t") == "⚡"
        assert registry.get_emoji("t", default="🔧") == "🔧"

    def test_get_emoji_returns_default_for_unknown_tool(self):
        registry = ToolRegistry()
        assert registry.get_emoji("nonexistent") == "⚡"
        assert registry.get_emoji("nonexistent", default="❓") == "❓"

    def test_emoji_empty_string_treated_as_unset(self):
        registry = self._registry_with_tool(emoji="")
        assert registry.get_emoji("t") == "⚡"
406  
407  
class TestEntryLookup:
    """``get_entry`` returns the stored entry, or None for unknown names."""

    def test_get_entry_returns_registered_entry(self):
        registry = ToolRegistry()
        registry.register(
            name="alpha",
            toolset="core",
            schema=_make_schema("alpha"),
            handler=_dummy_handler,
        )

        found = registry.get_entry("alpha")

        assert found is not None
        assert (found.name, found.toolset) == ("alpha", "core")

    def test_get_entry_returns_none_for_unknown_tool(self):
        assert ToolRegistry().get_entry("missing") is None
422  
423  
class TestSecretCaptureResultContract:
    """Illustrates the intended shape of a secret-capture result payload.

    NOTE(review): the payload is a literal built inside the test itself, so
    no registry or tool code is exercised here. Consider asserting on a real
    secret-request result instead.
    """

    def test_secret_request_result_does_not_include_secret_value(self):
        serialized = json.dumps(
            {"success": True, "stored_as": "TENOR_API_KEY", "validated": False}
        )
        assert "secret" not in serialized.lower()
432  
433  
class TestThreadSafety:
    """Registry read paths must stay coherent while other threads mutate it.

    The two concurrent tests park the reader inside a ``check_fn`` until a
    writer thread has registered/deregistered a tool. If ``register`` or
    ``deregister`` were blocked while a ``check_fn`` runs (e.g. the reader
    holding a registry lock), the writer could not finish before the 1s wait
    expires and ``writer_completed_during_check`` would be ``False``. So these
    tests also pin that checks run without blocking writers.
    """

    @staticmethod
    def _run_reader_and_writer(read, write, errors):
        """Run *read* and *write* on two daemon threads and join each (2s timeout).

        Any exception raised by either callable, including a failed ``assert``,
        is appended to *errors*. Otherwise an exception escaping a worker thread
        is only handed to ``threading.excepthook`` and does not fail the test
        directly. Daemon threads keep a deadlocked registry from blocking
        interpreter shutdown. Returns ``(reader_thread, writer_thread)`` so the
        caller can assert both finished.
        """

        def guarded(fn):
            def run():
                try:
                    fn()
                except Exception as exc:  # pragma: no cover - exercised on failure only
                    errors.append(exc)

            return run

        reader_thread = threading.Thread(target=guarded(read), daemon=True)
        writer_thread = threading.Thread(target=guarded(write), daemon=True)
        reader_thread.start()
        writer_thread.start()
        reader_thread.join(timeout=2)
        writer_thread.join(timeout=2)
        return reader_thread, writer_thread

    def test_get_available_toolsets_uses_coherent_snapshot(self, monkeypatch):
        """get_available_toolsets must answer from a single ``_snapshot_state()``.

        The patched snapshot deregisters "alpha" as a side effect but returns
        the state captured beforehand. The result must therefore still list
        "alpha" under the gated toolset, rather than mixing pre- and
        post-mutation views.
        """
        reg = ToolRegistry()
        reg.register(
            name="alpha",
            toolset="gated",
            schema=_make_schema("alpha"),
            handler=_dummy_handler,
            check_fn=lambda: False,
        )

        entries, toolset_checks = reg._snapshot_state()

        def snapshot_then_mutate():
            reg.deregister("alpha")
            return entries, toolset_checks

        monkeypatch.setattr(reg, "_snapshot_state", snapshot_then_mutate)

        toolsets = reg.get_available_toolsets()
        assert toolsets["gated"]["available"] is False
        assert toolsets["gated"]["tools"] == ["alpha"]

    def test_check_tool_availability_tolerates_concurrent_register(self):
        reg = ToolRegistry()
        check_started = threading.Event()
        writer_done = threading.Event()
        errors = []
        result_holder = {}
        writer_completed_during_check = {}

        def blocking_check():
            check_started.set()
            # Stay inside check_fn until the writer has registered "gamma".
            writer_completed_during_check["value"] = writer_done.wait(timeout=1)
            return True

        reg.register(
            name="alpha",
            toolset="gated",
            schema=_make_schema("alpha"),
            handler=_dummy_handler,
            check_fn=blocking_check,
        )
        reg.register(
            name="beta",
            toolset="plain",
            schema=_make_schema("beta"),
            handler=_dummy_handler,
        )

        def reader():
            result_holder["value"] = reg.check_tool_availability()

        def writer():
            assert check_started.wait(timeout=1)
            reg.register(
                name="gamma",
                toolset="new",
                schema=_make_schema("gamma"),
                handler=_dummy_handler,
            )
            writer_done.set()

        reader_thread, writer_thread = self._run_reader_and_writer(
            reader, writer, errors
        )

        assert not reader_thread.is_alive()
        assert not writer_thread.is_alive()
        assert errors == []
        # .get(): a missing key means blocking_check never ran at all.
        assert writer_completed_during_check.get("value") is True

        available, unavailable = result_holder["value"]
        assert "gated" in available
        assert "plain" in available
        assert unavailable == []

    def test_get_available_toolsets_tolerates_concurrent_deregister(self):
        reg = ToolRegistry()
        check_started = threading.Event()
        writer_done = threading.Event()
        errors = []
        result_holder = {}
        writer_completed_during_check = {}

        def blocking_check():
            check_started.set()
            # Stay inside check_fn until the writer has deregistered "beta".
            writer_completed_during_check["value"] = writer_done.wait(timeout=1)
            return True

        reg.register(
            name="alpha",
            toolset="gated",
            schema=_make_schema("alpha"),
            handler=_dummy_handler,
            check_fn=blocking_check,
        )
        reg.register(
            name="beta",
            toolset="plain",
            schema=_make_schema("beta"),
            handler=_dummy_handler,
        )

        def reader():
            result_holder["value"] = reg.get_available_toolsets()

        def writer():
            assert check_started.wait(timeout=1)
            reg.deregister("beta")
            writer_done.set()

        reader_thread, writer_thread = self._run_reader_and_writer(
            reader, writer, errors
        )

        assert not reader_thread.is_alive()
        assert not writer_thread.is_alive()
        assert errors == []
        # .get(): a missing key means blocking_check never ran at all.
        assert writer_completed_during_check.get("value") is True

        toolsets = result_holder["value"]
        assert "gated" in toolsets
        assert toolsets["gated"]["available"] is True