# tests/tools/test_mcp_dynamic_discovery.py
  1  """Tests for MCP dynamic tool discovery (notifications/tools/list_changed)."""
  2  
  3  import asyncio
  4  from types import SimpleNamespace
  5  from unittest.mock import AsyncMock, MagicMock, patch
  6  
  7  import pytest
  8  
  9  from tools.mcp_tool import MCPServerTask, _register_server_tools
 10  from tools.registry import ToolRegistry
 11  
 12  
 13  def _make_mcp_tool(name: str, desc: str = ""):
 14      return SimpleNamespace(name=name, description=desc, inputSchema=None)
 15  
 16  
 17  class TestRegisterServerTools:
 18      """Tests for the extracted _register_server_tools helper."""
 19  
 20      @pytest.fixture
 21      def mock_registry(self):
 22          return ToolRegistry()
 23  
 24      def test_exposes_live_server_aliases(self, mock_registry):
 25          """Registered MCP tools are reachable via live raw-server aliases."""
 26          server = MCPServerTask("my_srv")
 27          server._tools = [_make_mcp_tool("my_tool", "desc")]
 28          server.session = MagicMock()
 29          from toolsets import resolve_toolset, validate_toolset
 30  
 31          with patch("tools.registry.registry", mock_registry):
 32              registered = _register_server_tools("my_srv", server, {})
 33              assert "mcp_my_srv_my_tool" in registered
 34              assert "mcp_my_srv_my_tool" in mock_registry.get_all_tool_names()
 35              assert validate_toolset("my_srv") is True
 36              assert "mcp_my_srv_my_tool" in resolve_toolset("my_srv")
 37  
 38  
 39  class TestRefreshTools:
 40      """Tests for MCPServerTask._refresh_tools nuke-and-repave cycle."""
 41  
 42      @pytest.fixture
 43      def mock_registry(self):
 44          return ToolRegistry()
 45  
 46      @pytest.mark.asyncio
 47      async def test_nuke_and_repave(self, mock_registry):
 48          """Old tools are removed and new tools registered on refresh."""
 49          server = MCPServerTask("live_srv")
 50          server._refresh_lock = asyncio.Lock()
 51          server._config = {}
 52          from toolsets import resolve_toolset
 53  
 54          # Seed initial state: one old tool registered
 55          mock_registry.register(
 56              name="mcp_live_srv_old_tool", toolset="mcp-live_srv", schema={},
 57              handler=lambda x: x, check_fn=lambda: True, is_async=False,
 58              description="", emoji="",
 59          )
 60          server._registered_tool_names = ["mcp_live_srv_old_tool"]
 61  
 62          # New tool list from server
 63          new_tool = _make_mcp_tool("new_tool", "new behavior")
 64          server.session = SimpleNamespace(
 65              list_tools=AsyncMock(
 66                  return_value=SimpleNamespace(tools=[new_tool])
 67              )
 68          )
 69  
 70          with patch("tools.registry.registry", mock_registry):
 71              await server._refresh_tools()
 72              assert "mcp_live_srv_old_tool" not in mock_registry.get_all_tool_names()
 73              assert "mcp_live_srv_old_tool" not in resolve_toolset("live_srv")
 74              assert "mcp_live_srv_new_tool" in mock_registry.get_all_tool_names()
 75              assert "mcp_live_srv_new_tool" in resolve_toolset("live_srv")
 76              assert server._registered_tool_names == ["mcp_live_srv_new_tool"]
 77  
 78  
 79  class TestMessageHandler:
 80      """Tests for MCPServerTask._make_message_handler dispatch."""
 81  
 82      @pytest.mark.asyncio
 83      async def test_dispatches_tool_list_changed(self):
 84          from tools.mcp_tool import _MCP_NOTIFICATION_TYPES
 85          if not _MCP_NOTIFICATION_TYPES:
 86              pytest.skip("MCP SDK ToolListChangedNotification not available")
 87  
 88          from mcp.types import ServerNotification, ToolListChangedNotification
 89  
 90          server = MCPServerTask("notif_srv")
 91          # Product now schedules the refresh as a background task (see
 92          # _schedule_tools_refresh in mcp_tool.py ~L918) rather than awaiting
 93          # it directly, to avoid wedging the stdio JSON-RPC stream. Patch at
 94          # the scheduler seam so we can still assert dispatch happened without
 95          # reaching into asyncio.create_task internals.
 96          with patch.object(MCPServerTask, "_schedule_tools_refresh") as mock_schedule:
 97              handler = server._make_message_handler()
 98              notification = ServerNotification(
 99                  root=ToolListChangedNotification(method="notifications/tools/list_changed")
100              )
101              await handler(notification)
102              mock_schedule.assert_called_once()
103  
104      @pytest.mark.asyncio
105      async def test_ignores_exceptions_and_other_messages(self):
106          server = MCPServerTask("notif_srv")
107          with patch.object(MCPServerTask, "_schedule_tools_refresh") as mock_schedule:
108              handler = server._make_message_handler()
109              # Exceptions should not trigger refresh
110              await handler(RuntimeError("connection dead"))
111              # Unknown message types should not trigger refresh
112              await handler({"jsonrpc": "2.0", "result": "ok"})
113              mock_schedule.assert_not_called()
114  
115  
class TestDeregister:
    """Tests for ToolRegistry.deregister."""

    def test_removes_tool(self):
        """A registered tool disappears from the registry once deregistered."""
        reg = ToolRegistry()
        reg.register(name="foo", toolset="ts1", schema={}, handler=lambda x: x)
        assert "foo" in reg.get_all_tool_names()
        reg.deregister("foo")
        assert "foo" not in reg.get_all_tool_names()

    def test_cleans_up_toolset_check(self):
        """Removing the last tool of a toolset also drops that toolset's check."""
        reg = ToolRegistry()
        check = lambda: True  # noqa: E731
        reg.register(name="foo", toolset="ts1", schema={}, handler=lambda x: x, check_fn=check)
        assert reg.is_toolset_available("ts1")
        reg.deregister("foo")
        # Toolset check should be gone since no tools remain
        assert "ts1" not in reg._toolset_checks

    def test_preserves_toolset_check_if_other_tools_remain(self):
        """The toolset check survives while other tools still belong to the toolset."""
        reg = ToolRegistry()
        check = lambda: True  # noqa: E731
        reg.register(name="foo", toolset="ts1", schema={}, handler=lambda x: x, check_fn=check)
        reg.register(name="bar", toolset="ts1", schema={}, handler=lambda x: x)
        reg.deregister("foo")
        # bar still in ts1, so check should remain
        assert "ts1" in reg._toolset_checks

    def test_removes_toolset_alias_when_last_tool_is_removed(self):
        """An alias whose target toolset becomes empty is removed."""
        reg = ToolRegistry()
        reg.register(name="foo", toolset="mcp-srv", schema={}, handler=lambda x: x)
        reg.register_toolset_alias("srv", "mcp-srv")

        reg.deregister("foo")

        assert reg.get_toolset_alias_target("srv") is None

    def test_preserves_toolset_alias_while_toolset_still_exists(self):
        """An alias keeps pointing at its target while that toolset still has tools."""
        reg = ToolRegistry()
        reg.register(name="foo", toolset="mcp-srv", schema={}, handler=lambda x: x)
        reg.register(name="bar", toolset="mcp-srv", schema={}, handler=lambda x: x)
        reg.register_toolset_alias("srv", "mcp-srv")

        reg.deregister("foo")

        assert reg.get_toolset_alias_target("srv") == "mcp-srv"

    def test_noop_for_unknown_tool(self):
        """Deregistering an unknown name neither raises nor disturbs other tools."""
        reg = ToolRegistry()
        reg.deregister("nonexistent")  # Should not raise on an empty registry

        # Nor on a populated registry -- and existing registrations must survive,
        # otherwise a "no-op" that silently removed tools would go unnoticed.
        reg.register(name="foo", toolset="ts1", schema={}, handler=lambda x: x)
        reg.deregister("nonexistent")
        assert "foo" in reg.get_all_tool_names()