tests/tools/test_mcp_tool.py
   1  """Tests for the MCP (Model Context Protocol) client support.
   2  
   3  All tests use mocks -- no real MCP servers or subprocesses are started.
   4  """
   5  
   6  import asyncio
   7  import json
   8  import os
   9  import threading
  10  import time
  11  from types import SimpleNamespace
  12  from unittest.mock import AsyncMock, MagicMock, patch
  13  
  14  import pytest
  15  
  16  
  17  # ---------------------------------------------------------------------------
  18  # Helpers
  19  # ---------------------------------------------------------------------------
  20  
  21  def _make_mcp_tool(name="read_file", description="Read a file", input_schema=None):
  22      """Create a fake MCP Tool object matching the SDK interface."""
  23      tool = SimpleNamespace()
  24      tool.name = name
  25      tool.description = description
  26      tool.inputSchema = input_schema or {
  27          "type": "object",
  28          "properties": {
  29              "path": {"type": "string", "description": "File path"},
  30          },
  31          "required": ["path"],
  32      }
  33      return tool
  34  
  35  
  36  def _make_call_result(text="file contents here", is_error=False):
  37      """Create a fake MCP CallToolResult."""
  38      block = SimpleNamespace(text=text)
  39      return SimpleNamespace(content=[block], isError=is_error)
  40  
  41  
  42  def _make_mock_server(name, session=None, tools=None):
  43      """Create an MCPServerTask with mock attributes for testing."""
  44      from tools.mcp_tool import MCPServerTask
  45      server = MCPServerTask(name)
  46      server.session = session
  47      server._tools = tools or []
  48      return server
  49  
  50  
  51  # ---------------------------------------------------------------------------
  52  # Config loading
  53  # ---------------------------------------------------------------------------
  54  
  55  class TestLoadMCPConfig:
  56      def test_no_config_returns_empty(self):
  57          """No mcp_servers key in config -> empty dict."""
  58          with patch("hermes_cli.config.load_config", return_value={"model": "test"}):
  59              from tools.mcp_tool import _load_mcp_config
  60              result = _load_mcp_config()
  61              assert result == {}
  62  
  63      def test_valid_config_parsed(self):
  64          """Valid mcp_servers config is returned as-is."""
  65          servers = {
  66              "filesystem": {
  67                  "command": "npx",
  68                  "args": ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"],
  69                  "env": {},
  70              }
  71          }
  72          with patch("hermes_cli.config.load_config", return_value={"mcp_servers": servers}):
  73              from tools.mcp_tool import _load_mcp_config
  74              result = _load_mcp_config()
  75              assert "filesystem" in result
  76              assert result["filesystem"]["command"] == "npx"
  77  
  78      def test_mcp_servers_not_dict_returns_empty(self):
  79          """mcp_servers set to non-dict value -> empty dict."""
  80          with patch("hermes_cli.config.load_config", return_value={"mcp_servers": "invalid"}):
  81              from tools.mcp_tool import _load_mcp_config
  82              result = _load_mcp_config()
  83              assert result == {}
  84  
  85  
  86  # ---------------------------------------------------------------------------
  87  # Schema conversion
  88  # ---------------------------------------------------------------------------
  89  
class TestSchemaConversion:
    """Conversion of MCP Tool metadata into Hermes tool schemas.

    Exercises name prefixing/sanitization and inputSchema normalization:
    missing or None schemas, ``definitions`` -> ``$defs`` ref rewriting,
    dangling ``required`` entries, and anyOf/oneOf nullable collapsing.
    """

    def test_converts_mcp_tool_to_hermes_schema(self):
        """Happy path: name gets prefixed, description/params carried over."""
        from tools.mcp_tool import _convert_mcp_schema

        mcp_tool = _make_mcp_tool(name="read_file", description="Read a file")
        schema = _convert_mcp_schema("filesystem", mcp_tool)

        # Tool names are namespaced as mcp_<server>_<tool>.
        assert schema["name"] == "mcp_filesystem_read_file"
        assert schema["description"] == "Read a file"
        assert "properties" in schema["parameters"]

    def test_empty_input_schema_gets_default(self):
        """inputSchema=None falls back to an empty object schema."""
        from tools.mcp_tool import _convert_mcp_schema

        mcp_tool = _make_mcp_tool(name="ping", description="Ping", input_schema=None)
        # _make_mcp_tool substitutes a default for input_schema=None,
        # so force the attribute to None after construction.
        mcp_tool.inputSchema = None
        schema = _convert_mcp_schema("test", mcp_tool)

        assert schema["parameters"]["type"] == "object"
        assert schema["parameters"]["properties"] == {}

    def test_object_schema_without_properties_gets_normalized(self):
        """A bare {'type': 'object'} gains an empty 'properties' dict."""
        from tools.mcp_tool import _convert_mcp_schema

        mcp_tool = _make_mcp_tool(
            name="ask",
            description="Ask Crawl4AI",
            input_schema={"type": "object"},
        )
        schema = _convert_mcp_schema("crawl4ai", mcp_tool)

        assert schema["parameters"] == {"type": "object", "properties": {}}

    def test_definitions_refs_are_rewritten_to_defs(self):
        """Draft-04 '#/definitions/...' refs become modern '#/$defs/...'."""
        from tools.mcp_tool import _convert_mcp_schema

        mcp_tool = _make_mcp_tool(
            name="submit",
            description="Submit a payload",
            input_schema={
                "type": "object",
                "properties": {
                    "input": {"$ref": "#/definitions/Payload"},
                },
                "required": ["input"],
                "definitions": {
                    "Payload": {
                        "type": "object",
                        "properties": {
                            "query": {"type": "string"},
                        },
                        "required": ["query"],
                    }
                },
            },
        )

        schema = _convert_mcp_schema("forms", mcp_tool)

        # Refs are rewritten and the container key itself is renamed.
        assert schema["parameters"]["properties"]["input"]["$ref"] == "#/$defs/Payload"
        assert "$defs" in schema["parameters"]
        assert "definitions" not in schema["parameters"]

    def test_nested_definition_refs_are_rewritten_recursively(self):
        """Refs inside array items and inside $defs entries are rewritten too."""
        from tools.mcp_tool import _convert_mcp_schema

        mcp_tool = _make_mcp_tool(
            name="nested",
            description="Nested schema",
            input_schema={
                "type": "object",
                "properties": {
                    "items": {
                        "type": "array",
                        "items": {"$ref": "#/definitions/Entry"},
                    },
                },
                "definitions": {
                    "Entry": {
                        "type": "object",
                        "properties": {
                            # Entry itself references a second definition,
                            # forcing recursion into the $defs bodies.
                            "child": {"$ref": "#/definitions/Child"},
                        },
                    },
                    "Child": {
                        "type": "object",
                        "properties": {
                            "value": {"type": "string"},
                        },
                    },
                },
            },
        )

        schema = _convert_mcp_schema("forms", mcp_tool)

        assert schema["parameters"]["properties"]["items"]["items"]["$ref"] == "#/$defs/Entry"
        assert schema["parameters"]["$defs"]["Entry"]["properties"]["child"]["$ref"] == "#/$defs/Child"

    def test_missing_type_on_object_is_coerced(self):
        """Schemas that describe an object but omit ``type`` get type='object'."""
        from tools.mcp_tool import _normalize_mcp_input_schema

        schema = _normalize_mcp_input_schema({
            "properties": {"q": {"type": "string"}},
            "required": ["q"],
        })

        assert schema["type"] == "object"
        assert schema["properties"]["q"]["type"] == "string"
        assert schema["required"] == ["q"]

    def test_null_type_on_object_is_coerced(self):
        """type: None should be treated like missing type (common MCP server bug)."""
        from tools.mcp_tool import _normalize_mcp_input_schema

        schema = _normalize_mcp_input_schema({
            "type": None,
            "properties": {"x": {"type": "integer"}},
        })

        assert schema["type"] == "object"

    def test_required_pruned_when_property_missing(self):
        """Gemini 400s on required names that don't exist in properties."""
        from tools.mcp_tool import _normalize_mcp_input_schema

        schema = _normalize_mcp_input_schema({
            "type": "object",
            "properties": {"a": {"type": "string"}},
            "required": ["a", "ghost", "phantom"],
        })

        # Only names actually present in 'properties' survive.
        assert schema["required"] == ["a"]

    def test_required_removed_when_all_names_dangle(self):
        """When every required name dangles, the key is dropped entirely."""
        from tools.mcp_tool import _normalize_mcp_input_schema

        schema = _normalize_mcp_input_schema({
            "type": "object",
            "properties": {},
            "required": ["ghost"],
        })

        assert "required" not in schema

    def test_required_pruning_applies_recursively_inside_nested_objects(self):
        """Nested object schemas also get required pruning."""
        from tools.mcp_tool import _normalize_mcp_input_schema

        schema = _normalize_mcp_input_schema({
            "type": "object",
            "properties": {
                "filter": {
                    "type": "object",
                    "properties": {"field": {"type": "string"}},
                    "required": ["field", "missing"],
                },
            },
        })

        assert schema["properties"]["filter"]["required"] == ["field"]

    def test_object_in_array_items_gets_properties_filled(self):
        """Array-item object schemas without properties get an empty dict."""
        from tools.mcp_tool import _normalize_mcp_input_schema

        schema = _normalize_mcp_input_schema({
            "type": "object",
            "properties": {
                "items": {
                    "type": "array",
                    "items": {"type": "object"},
                },
            },
        })

        assert schema["properties"]["items"]["items"]["properties"] == {}

    def test_optional_nullable_field_is_collapsed_to_non_null_schema(self):
        """Anthropic rejects MCP/Pydantic anyOf-null optional parameter schemas."""
        from tools.mcp_tool import _normalize_mcp_input_schema

        schema = _normalize_mcp_input_schema({
            "type": "object",
            "properties": {
                "command": {"type": "string"},
                "workdir": {
                    "anyOf": [{"type": "string"}, {"type": "null"}],
                    "default": None,
                    "description": "Optional working directory",
                },
            },
            "required": ["command"],
        })

        # The null branch collapses into nullable=True on the non-null
        # schema; sibling keys (default, description) are preserved.
        assert schema["properties"]["workdir"] == {
            "type": "string",
            "nullable": True,
            "default": None,
            "description": "Optional working directory",
        }
        assert schema["required"] == ["command"]

    def test_nested_nullable_array_items_are_collapsed(self):
        """oneOf-with-null inside array items collapses like anyOf does."""
        from tools.mcp_tool import _normalize_mcp_input_schema

        schema = _normalize_mcp_input_schema({
            "type": "object",
            "properties": {
                "filters": {
                    "type": "array",
                    "items": {
                        "oneOf": [
                            {
                                "type": "object",
                                "properties": {"field": {"type": "string"}},
                            },
                            {"type": "null"},
                        ]
                    },
                }
            },
        })

        assert schema["properties"]["filters"]["items"] == {
            "type": "object",
            "properties": {"field": {"type": "string"}},
            "nullable": True,
        }

    def test_convert_mcp_schema_survives_missing_inputschema_attribute(self):
        """A Tool object without .inputSchema must not crash registration."""
        import types

        from tools.mcp_tool import _convert_mcp_schema

        bare_tool = types.SimpleNamespace(name="probe", description="Probe")
        schema = _convert_mcp_schema("srv", bare_tool)

        assert schema["name"] == "mcp_srv_probe"
        assert schema["parameters"] == {"type": "object", "properties": {}}

    def test_convert_mcp_schema_with_none_inputschema(self):
        """Tool with inputSchema=None produces a valid empty object schema."""
        import types

        from tools.mcp_tool import _convert_mcp_schema

        # Note: _make_mcp_tool(input_schema=None) falls back to a default —
        # build the namespace directly so .inputSchema really is None.
        mcp_tool = types.SimpleNamespace(name="probe", description="Probe", inputSchema=None)
        schema = _convert_mcp_schema("srv", mcp_tool)

        assert schema["parameters"] == {"type": "object", "properties": {}}

    def test_tool_name_prefix_format(self):
        """Registered names follow the mcp_<server>_<tool> convention."""
        from tools.mcp_tool import _convert_mcp_schema

        mcp_tool = _make_mcp_tool(name="list_dir")
        schema = _convert_mcp_schema("my_server", mcp_tool)

        assert schema["name"] == "mcp_my_server_list_dir"

    def test_hyphens_sanitized_to_underscores(self):
        """Hyphens in tool/server names are replaced with underscores for LLM compat."""
        from tools.mcp_tool import _convert_mcp_schema

        mcp_tool = _make_mcp_tool(name="get-sum")
        schema = _convert_mcp_schema("my-server", mcp_tool)

        assert schema["name"] == "mcp_my_server_get_sum"
        assert "-" not in schema["name"]
 363  
 364  
 365  # ---------------------------------------------------------------------------
 366  # Check function
 367  # ---------------------------------------------------------------------------
 368  
 369  class TestCheckFunction:
 370      def test_disconnected_returns_false(self):
 371          from tools.mcp_tool import _make_check_fn, _servers
 372  
 373          _servers.pop("test_server", None)
 374          check = _make_check_fn("test_server")
 375          assert check() is False
 376  
 377      def test_connected_returns_true(self):
 378          from tools.mcp_tool import _make_check_fn, _servers
 379  
 380          server = _make_mock_server("test_server", session=MagicMock())
 381          _servers["test_server"] = server
 382          try:
 383              check = _make_check_fn("test_server")
 384              assert check() is True
 385          finally:
 386              _servers.pop("test_server", None)
 387  
 388      def test_session_none_returns_false(self):
 389          from tools.mcp_tool import _make_check_fn, _servers
 390  
 391          server = _make_mock_server("test_server", session=None)
 392          _servers["test_server"] = server
 393          try:
 394              check = _make_check_fn("test_server")
 395              assert check() is False
 396          finally:
 397              _servers.pop("test_server", None)
 398  
 399  
 400  # ---------------------------------------------------------------------------
 401  # Tool handler
 402  # ---------------------------------------------------------------------------
 403  
 404  class TestToolHandler:
 405      """Tool handlers are sync functions that schedule work on the MCP loop."""
 406  
 407      def _patch_mcp_loop(self, coro_side_effect=None):
 408          """Return a patch for _run_on_mcp_loop that runs the coroutine directly."""
 409          def fake_run(coro, timeout=30):
 410              return asyncio.run(coro)
 411          if coro_side_effect:
 412              return patch("tools.mcp_tool._run_on_mcp_loop", side_effect=coro_side_effect)
 413          return patch("tools.mcp_tool._run_on_mcp_loop", side_effect=fake_run)
 414  
 415      def test_successful_call(self):
 416          from tools.mcp_tool import _make_tool_handler, _servers
 417  
 418          mock_session = MagicMock()
 419          mock_session.call_tool = AsyncMock(
 420              return_value=_make_call_result("hello world", is_error=False)
 421          )
 422          server = _make_mock_server("test_srv", session=mock_session)
 423          _servers["test_srv"] = server
 424  
 425          try:
 426              handler = _make_tool_handler("test_srv", "greet", 120)
 427              with self._patch_mcp_loop():
 428                  result = json.loads(handler({"name": "world"}))
 429              assert result["result"] == "hello world"
 430              mock_session.call_tool.assert_called_once_with("greet", arguments={"name": "world"})
 431          finally:
 432              _servers.pop("test_srv", None)
 433  
 434      def test_mcp_error_result(self):
 435          from tools.mcp_tool import _make_tool_handler, _servers
 436  
 437          mock_session = MagicMock()
 438          mock_session.call_tool = AsyncMock(
 439              return_value=_make_call_result("something went wrong", is_error=True)
 440          )
 441          server = _make_mock_server("test_srv", session=mock_session)
 442          _servers["test_srv"] = server
 443  
 444          try:
 445              handler = _make_tool_handler("test_srv", "fail_tool", 120)
 446              with self._patch_mcp_loop():
 447                  result = json.loads(handler({}))
 448              assert "error" in result
 449              assert "something went wrong" in result["error"]
 450          finally:
 451              _servers.pop("test_srv", None)
 452  
 453      def test_disconnected_server(self):
 454          from tools.mcp_tool import _make_tool_handler, _servers
 455  
 456          _servers.pop("ghost", None)
 457          handler = _make_tool_handler("ghost", "any_tool", 120)
 458          result = json.loads(handler({}))
 459          assert "error" in result
 460          assert "not connected" in result["error"]
 461  
 462      def test_exception_during_call(self):
 463          from tools.mcp_tool import _make_tool_handler, _servers
 464  
 465          mock_session = MagicMock()
 466          mock_session.call_tool = AsyncMock(side_effect=RuntimeError("connection lost"))
 467          server = _make_mock_server("test_srv", session=mock_session)
 468          _servers["test_srv"] = server
 469  
 470          try:
 471              handler = _make_tool_handler("test_srv", "broken_tool", 120)
 472              with self._patch_mcp_loop():
 473                  result = json.loads(handler({}))
 474              assert "error" in result
 475              assert "connection lost" in result["error"]
 476          finally:
 477              _servers.pop("test_srv", None)
 478  
 479      def test_interrupted_call_returns_interrupted_error(self):
 480          from tools.mcp_tool import _make_tool_handler, _servers
 481  
 482          mock_session = MagicMock()
 483          server = _make_mock_server("test_srv", session=mock_session)
 484          _servers["test_srv"] = server
 485  
 486          try:
 487              handler = _make_tool_handler("test_srv", "greet", 120)
 488              def _interrupting_run(coro, timeout=30):
 489                  coro.close()
 490                  raise InterruptedError("User sent a new message")
 491              with patch(
 492                  "tools.mcp_tool._run_on_mcp_loop",
 493                  side_effect=_interrupting_run,
 494              ):
 495                  result = json.loads(handler({}))
 496              assert result == {"error": "MCP call interrupted: user sent a new message"}
 497          finally:
 498              _servers.pop("test_srv", None)
 499  
 500  
class TestRunOnMCPLoopInterrupts:
    """Interrupt handling in _run_on_mcp_loop, using a real loop and threads."""

    def test_interrupt_cancels_waiting_mcp_call(self):
        """Setting the interrupt flag raises InterruptedError in the waiter
        and cancels the coroutine still pending on the MCP loop."""
        import tools.mcp_tool as mcp_mod
        from tools.interrupt import set_interrupt

        # A dedicated loop on a daemon thread stands in for the MCP loop.
        loop = asyncio.new_event_loop()
        thread = threading.Thread(target=loop.run_forever, daemon=True)
        thread.start()

        cancelled = threading.Event()

        async def _slow_call():
            # Sleeps long enough that the 0.2s interrupt always wins.
            try:
                await asyncio.sleep(5)
                return "done"
            except asyncio.CancelledError:
                cancelled.set()
                raise

        # Swap in our loop/thread as the module-level MCP loop state;
        # restored in the finally block below.
        old_loop = mcp_mod._mcp_loop
        old_thread = mcp_mod._mcp_thread
        mcp_mod._mcp_loop = loop
        mcp_mod._mcp_thread = thread

        waiter_tid = threading.current_thread().ident

        def _interrupt_soon():
            time.sleep(0.2)
            set_interrupt(True, waiter_tid)

        interrupter = threading.Thread(target=_interrupt_soon, daemon=True)
        interrupter.start()

        try:
            with pytest.raises(InterruptedError, match="User sent a new message"):
                mcp_mod._run_on_mcp_loop(_slow_call(), timeout=2)

            # Cancellation propagates asynchronously on the loop thread,
            # so poll briefly instead of asserting immediately.
            deadline = time.time() + 2
            while time.time() < deadline and not cancelled.is_set():
                time.sleep(0.05)
            assert cancelled.is_set()
        finally:
            # Undo everything: interrupt flag, loop, thread, module globals.
            set_interrupt(False, waiter_tid)
            loop.call_soon_threadsafe(loop.stop)
            thread.join(timeout=2)
            loop.close()
            mcp_mod._mcp_loop = old_loop
            mcp_mod._mcp_thread = old_thread
 549  
 550  
 551  # ---------------------------------------------------------------------------
 552  # Tool registration (discovery + register)
 553  # ---------------------------------------------------------------------------
 554  
class TestDiscoverAndRegister:
    """Discovery and registration of MCP tools into the Hermes tool registry."""

    def test_tools_registered_in_registry(self):
        """_discover_and_register_server registers tools with correct names."""
        from tools.registry import ToolRegistry
        from tools.mcp_tool import _discover_and_register_server, _servers, MCPServerTask

        mock_registry = ToolRegistry()
        mock_tools = [
            _make_mcp_tool("read_file", "Read a file"),
            _make_mcp_tool("write_file", "Write a file"),
        ]
        mock_session = MagicMock()

        # Stand-in for _connect_server: returns a pre-wired task so no
        # subprocess or real MCP handshake happens.
        async def fake_connect(name, config):
            server = MCPServerTask(name)
            server.session = mock_session
            server._tools = mock_tools
            return server

        with patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \
             patch("tools.registry.registry", mock_registry):
            registered = asyncio.run(
                _discover_and_register_server("fs", {"command": "npx", "args": []})
            )

        # Both the returned name list and the registry see the new tools.
        assert "mcp_fs_read_file" in registered
        assert "mcp_fs_write_file" in registered
        assert "mcp_fs_read_file" in mock_registry.get_all_tool_names()
        assert "mcp_fs_write_file" in mock_registry.get_all_tool_names()

        _servers.pop("fs", None)

    def test_toolset_resolves_live_from_registry(self):
        """MCP toolsets resolve through the live registry without TOOLSETS mutation."""
        from tools.registry import ToolRegistry
        from tools.mcp_tool import _discover_and_register_server, _servers, MCPServerTask
        from toolsets import resolve_toolset, validate_toolset

        mock_registry = ToolRegistry()
        mock_tools = [_make_mcp_tool("ping", "Ping")]
        mock_session = MagicMock()

        async def fake_connect(name, config):
            server = MCPServerTask(name)
            server.session = mock_session
            server._tools = mock_tools
            return server

        with patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \
             patch("tools.registry.registry", mock_registry):
            asyncio.run(
                _discover_and_register_server("myserver", {"command": "test"})
            )

            # Both the bare server name and the mcp- prefixed alias resolve.
            assert validate_toolset("myserver") is True
            assert validate_toolset("mcp-myserver") is True
            assert "mcp_myserver_ping" in resolve_toolset("myserver")
            assert "mcp_myserver_ping" in resolve_toolset("mcp-myserver")

        _servers.pop("myserver", None)

    def test_schema_format_correct(self):
        """Registered schemas have the correct format."""
        from tools.registry import ToolRegistry
        from tools.mcp_tool import _discover_and_register_server, _servers, MCPServerTask

        mock_registry = ToolRegistry()
        mock_tools = [_make_mcp_tool("do_thing", "Do something")]
        mock_session = MagicMock()

        async def fake_connect(name, config):
            server = MCPServerTask(name)
            server.session = mock_session
            server._tools = mock_tools
            return server

        with patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \
             patch("tools.registry.registry", mock_registry):
            asyncio.run(
                _discover_and_register_server("srv", {"command": "test"})
            )

        # Inspect the raw registry entry for schema shape and metadata.
        entry = mock_registry._tools.get("mcp_srv_do_thing")
        assert entry is not None
        assert entry.schema["name"] == "mcp_srv_do_thing"
        assert "parameters" in entry.schema
        assert entry.is_async is False
        assert entry.toolset == "mcp-srv"

        _servers.pop("srv", None)
 645  
 646  
 647  # ---------------------------------------------------------------------------
 648  # MCPServerTask (run / start / shutdown)
 649  # ---------------------------------------------------------------------------
 650  
 651  class TestMCPServerTask:
 652      """Test the MCPServerTask lifecycle with mocked MCP SDK."""
 653  
 654      def _mock_stdio_and_session(self, session):
 655          """Return patches for stdio_client and ClientSession as async CMs."""
 656          mock_read, mock_write = MagicMock(), MagicMock()
 657  
 658          mock_stdio_cm = MagicMock()
 659          mock_stdio_cm.__aenter__ = AsyncMock(return_value=(mock_read, mock_write))
 660          mock_stdio_cm.__aexit__ = AsyncMock(return_value=False)
 661  
 662          mock_cs_cm = MagicMock()
 663          mock_cs_cm.__aenter__ = AsyncMock(return_value=session)
 664          mock_cs_cm.__aexit__ = AsyncMock(return_value=False)
 665  
 666          return (
 667              patch("tools.mcp_tool.stdio_client", return_value=mock_stdio_cm),
 668              patch("tools.mcp_tool.ClientSession", return_value=mock_cs_cm),
 669              mock_read, mock_write,
 670          )
 671  
    def test_start_connects_and_discovers_tools(self):
        """start() creates a Task that connects, discovers tools, and waits."""
        from tools.mcp_tool import MCPServerTask

        mock_tools = [_make_mcp_tool("echo")]
        mock_session = MagicMock()
        mock_session.initialize = AsyncMock()
        mock_session.list_tools = AsyncMock(
            return_value=SimpleNamespace(tools=mock_tools)
        )

        p_stdio, p_cs, _, _ = self._mock_stdio_and_session(mock_session)

        async def _test():
            # StdioServerParameters is patched too, so no subprocess config
            # is ever built or launched.
            with patch("tools.mcp_tool.StdioServerParameters"), p_stdio, p_cs:
                server = MCPServerTask("test_srv")
                await server.start({"command": "npx", "args": ["-y", "test"]})

                # Connected: session live, tools discovered, handshake done.
                assert server.session is mock_session
                assert len(server._tools) == 1
                assert server._tools[0].name == "echo"
                mock_session.initialize.assert_called_once()

                # Clean shutdown clears the session reference.
                await server.shutdown()
                assert server.session is None

        asyncio.run(_test())
 699  
 700      def test_no_command_raises(self):
 701          """Missing 'command' in config raises ValueError."""
 702          from tools.mcp_tool import MCPServerTask
 703  
 704          async def _test():
 705              server = MCPServerTask("bad")
 706              with pytest.raises(ValueError, match="no 'command'"):
 707                  await server.start({"args": []})
 708  
 709          asyncio.run(_test())
 710  
    def test_refresh_tools_deregisters_removed_tools(self):
        """Dynamic refresh removes stale registry entries for deleted tools."""
        from tools.registry import ToolRegistry
        from tools.mcp_tool import MCPServerTask

        mock_registry = ToolRegistry()
        server = MCPServerTask("srv")
        server._config = {"command": "test"}
        # Previously discovered state: 'old' and 'keep' are registered.
        server._tools = [_make_mcp_tool("old"), _make_mcp_tool("keep")]
        server._registered_tool_names = ["mcp_srv_old", "mcp_srv_keep"]
        server.session = MagicMock()
        # The server now reports 'keep' and 'new' -- 'old' has disappeared.
        server.session.list_tools = AsyncMock(
            return_value=SimpleNamespace(tools=[_make_mcp_tool("keep"), _make_mcp_tool("new")])
        )

        with patch("tools.registry.registry", mock_registry):
            # Seed the registry to match the pre-refresh state.
            mock_registry.register(
                name="mcp_srv_old",
                toolset="mcp-srv",
                schema={"name": "mcp_srv_old", "description": "Old"},
                handler=lambda *_args, **_kwargs: "{}",
            )
            mock_registry.register(
                name="mcp_srv_keep",
                toolset="mcp-srv",
                schema={"name": "mcp_srv_keep", "description": "Keep"},
                handler=lambda *_args, **_kwargs: "{}",
            )

            asyncio.run(server._refresh_tools())

            names = mock_registry.get_all_tool_names()
            assert "mcp_srv_old" not in names
            assert "mcp_srv_keep" in names
            assert "mcp_srv_new" in names
            # Per this expectation, refresh also (re)registers the server's
            # resource/prompt helper tools alongside the discovered ones.
            assert set(server._registered_tool_names) == {
                "mcp_srv_keep",
                "mcp_srv_new",
                "mcp_srv_list_resources",
                "mcp_srv_read_resource",
                "mcp_srv_list_prompts",
                "mcp_srv_get_prompt",
            }
 754  
 755      def test_schedule_tools_refresh_keeps_task_until_done(self):
 756          """Background refresh tasks are strongly referenced and then discarded."""
 757          from tools.mcp_tool import MCPServerTask
 758  
 759          async def _test():
 760              started = asyncio.Event()
 761              finish = asyncio.Event()
 762              server = MCPServerTask("srv")
 763  
 764              async def fake_refresh(_server):
 765                  started.set()
 766                  await finish.wait()
 767  
 768              with patch.object(MCPServerTask, "_refresh_tools", new=fake_refresh):
 769                  server._schedule_tools_refresh()
 770  
 771                  await started.wait()
 772                  assert len(server._pending_refresh_tasks) == 1
 773                  task = next(iter(server._pending_refresh_tasks))
 774                  assert not task.done()
 775  
 776                  finish.set()
 777                  await task
 778                  await asyncio.sleep(0)
 779                  assert server._pending_refresh_tasks == set()
 780  
 781          asyncio.run(_test())
 782  
 783      def test_shutdown_cancels_pending_refresh_tasks(self):
 784          """shutdown() cancels in-flight background refresh tasks."""
 785          from tools.mcp_tool import MCPServerTask
 786  
 787          async def _test():
 788              started = asyncio.Event()
 789              cancelled = asyncio.Event()
 790              server = MCPServerTask("srv")
 791  
 792              async def fake_refresh(_server):
 793                  started.set()
 794                  try:
 795                      await asyncio.sleep(3600)
 796                  except asyncio.CancelledError:
 797                      cancelled.set()
 798                      raise
 799  
 800              with patch.object(MCPServerTask, "_refresh_tools", new=fake_refresh):
 801                  server._schedule_tools_refresh()
 802                  await started.wait()
 803  
 804                  await server.shutdown()
 805  
 806              assert cancelled.is_set()
 807              assert server._pending_refresh_tasks == set()
 808  
 809          asyncio.run(_test())
 810  
 811      def test_empty_env_gets_safe_defaults(self):
 812          """Empty env dict gets safe default env vars (PATH, HOME, etc.)."""
 813          from tools.mcp_tool import MCPServerTask
 814  
 815          mock_session = MagicMock()
 816          mock_session.initialize = AsyncMock()
 817          mock_session.list_tools = AsyncMock(
 818              return_value=SimpleNamespace(tools=[])
 819          )
 820  
 821          p_stdio, p_cs, _, _ = self._mock_stdio_and_session(mock_session)
 822  
 823          async def _test():
 824              with patch("tools.mcp_tool.StdioServerParameters") as mock_params, \
 825                   p_stdio, p_cs, \
 826                   patch.dict("os.environ", {"PATH": "/usr/bin", "HOME": "/home/test"}, clear=False):
 827                  server = MCPServerTask("srv")
 828                  await server.start({"command": "node", "env": {}})
 829  
 830                  # Empty dict -> safe env vars (not None)
 831                  call_kwargs = mock_params.call_args
 832                  env_arg = call_kwargs.kwargs.get("env")
 833                  assert env_arg is not None
 834                  assert isinstance(env_arg, dict)
 835                  assert "PATH" in env_arg
 836                  assert "HOME" in env_arg
 837  
 838                  await server.shutdown()
 839  
 840          asyncio.run(_test())
 841  
 842      def test_shutdown_signals_task_exit(self):
 843          """shutdown() signals the event and waits for task completion."""
 844          from tools.mcp_tool import MCPServerTask
 845  
 846          mock_session = MagicMock()
 847          mock_session.initialize = AsyncMock()
 848          mock_session.list_tools = AsyncMock(
 849              return_value=SimpleNamespace(tools=[])
 850          )
 851  
 852          p_stdio, p_cs, _, _ = self._mock_stdio_and_session(mock_session)
 853  
 854          async def _test():
 855              with patch("tools.mcp_tool.StdioServerParameters"), p_stdio, p_cs:
 856                  server = MCPServerTask("srv")
 857                  await server.start({"command": "npx"})
 858  
 859                  assert server.session is not None
 860                  assert not server._task.done()
 861  
 862                  await server.shutdown()
 863  
 864                  assert server.session is None
 865                  assert server._task.done()
 866  
 867          asyncio.run(_test())
 868  
 869  
 870  # ---------------------------------------------------------------------------
 871  # discover_mcp_tools toolset injection
 872  # ---------------------------------------------------------------------------
 873  
class TestToolsetInjection:
    # Tests covering how discover_mcp_tools() injects discovered MCP tools
    # into the toolset/alias system and handles per-server failures.

    def test_mcp_tools_resolve_through_server_aliases(self):
        """Discovered MCP tools resolve through raw server-name aliases."""
        from tools.mcp_tool import MCPServerTask
        from tools.registry import ToolRegistry
        from toolsets import resolve_toolset, validate_toolset

        mock_tools = [_make_mcp_tool("list_files", "List files")]
        mock_session = MagicMock()
        mock_registry = ToolRegistry()

        fresh_servers = {}

        # Stand-in for _connect_server: returns an already-"connected"
        # server object without spawning a subprocess.
        async def fake_connect(name, config):
            server = MCPServerTask(name)
            server.session = mock_session
            server._tools = mock_tools
            return server

        fake_config = {"fs": {"command": "npx", "args": []}}

        with patch("tools.mcp_tool._MCP_AVAILABLE", True), \
             patch("tools.mcp_tool._servers", fresh_servers), \
             patch("tools.mcp_tool._load_mcp_config", return_value=fake_config), \
             patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \
             patch("tools.registry.registry", mock_registry):
            from tools.mcp_tool import discover_mcp_tools
            result = discover_mcp_tools()

            assert "mcp_fs_list_files" in result
            # Both the raw server name ("fs") and the prefixed toolset
            # ("mcp-fs") must validate and resolve to the same tool.
            assert validate_toolset("fs") is True
            assert validate_toolset("mcp-fs") is True
            assert "mcp_fs_list_files" in resolve_toolset("fs")
            assert "mcp_fs_list_files" in resolve_toolset("mcp-fs")

    def test_server_toolset_skips_builtin_collision(self):
        """MCP raw aliases never overwrite a built-in toolset name."""
        from tools.mcp_tool import MCPServerTask
        from tools.registry import ToolRegistry
        from toolsets import resolve_toolset, validate_toolset

        mock_tools = [_make_mcp_tool("run", "Run command")]
        mock_session = MagicMock()
        fresh_servers = {}
        mock_registry = ToolRegistry()

        async def fake_connect(name, config):
            server = MCPServerTask(name)
            server.session = mock_session
            server._tools = mock_tools
            return server

        fake_toolsets = {
            "hermes-cli": {"tools": ["terminal"], "description": "CLI", "includes": []},
            # Built-in toolset named "terminal" — must not be overwritten
            "terminal": {"tools": ["terminal"], "description": "Terminal tools", "includes": []},
        }
        # The MCP server deliberately shares its name with the built-in.
        fake_config = {"terminal": {"command": "npx", "args": []}}

        with patch("tools.mcp_tool._MCP_AVAILABLE", True), \
             patch("tools.mcp_tool._servers", fresh_servers), \
             patch("tools.mcp_tool._load_mcp_config", return_value=fake_config), \
             patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \
             patch("tools.registry.registry", mock_registry), \
             patch("toolsets.TOOLSETS", fake_toolsets):
            from tools.mcp_tool import discover_mcp_tools
            discover_mcp_tools()

            # Built-in entry untouched; only the mcp- prefixed toolset
            # exposes the MCP tool.
            assert fake_toolsets["terminal"]["description"] == "Terminal tools"
            assert "mcp_terminal_run" not in resolve_toolset("terminal")
            assert validate_toolset("mcp-terminal") is True
            assert "mcp_terminal_run" in resolve_toolset("mcp-terminal")

    def test_server_connection_failure_skipped(self):
        """If one server fails to connect, others still proceed."""
        from tools.mcp_tool import MCPServerTask

        mock_tools = [_make_mcp_tool("ping", "Ping")]
        mock_session = MagicMock()

        fresh_servers = {}
        call_count = 0

        # Connector that fails for the "broken" server but succeeds for
        # everything else; counts attempts.
        async def flaky_connect(name, config):
            nonlocal call_count
            call_count += 1
            if name == "broken":
                raise ConnectionError("cannot reach server")
            server = MCPServerTask(name)
            server.session = mock_session
            server._tools = mock_tools
            return server

        fake_config = {
            "broken": {"command": "bad"},
            "good": {"command": "npx", "args": []},
        }
        fake_toolsets = {
            "hermes-cli": {"tools": [], "description": "CLI", "includes": []},
        }

        with patch("tools.mcp_tool._MCP_AVAILABLE", True), \
             patch("tools.mcp_tool._servers", fresh_servers), \
             patch("tools.mcp_tool._load_mcp_config", return_value=fake_config), \
             patch("tools.mcp_tool._connect_server", side_effect=flaky_connect), \
             patch("toolsets.TOOLSETS", fake_toolsets):
            from tools.mcp_tool import discover_mcp_tools
            result = discover_mcp_tools()

        # Only the reachable server contributes tools; both were attempted.
        assert "mcp_good_ping" in result
        assert "mcp_broken_ping" not in result
        assert call_count == 2

    def test_partial_failure_retry_on_second_call(self):
        """Failed servers are retried on subsequent discover_mcp_tools() calls."""
        from tools.mcp_tool import MCPServerTask

        mock_tools = [_make_mcp_tool("ping", "Ping")]
        mock_session = MagicMock()

        # Use a real dict so idempotency logic works correctly
        fresh_servers = {}
        call_count = 0
        broken_fixed = False

        # Fails for "broken" until the broken_fixed flag flips.
        async def flaky_connect(name, config):
            nonlocal call_count
            call_count += 1
            if name == "broken" and not broken_fixed:
                raise ConnectionError("cannot reach server")
            server = MCPServerTask(name)
            server.session = mock_session
            server._tools = mock_tools
            return server

        fake_config = {
            "broken": {"command": "bad"},
            "good": {"command": "npx", "args": []},
        }
        fake_toolsets = {
            "hermes-cli": {"tools": [], "description": "CLI", "includes": []},
        }

        with patch("tools.mcp_tool._MCP_AVAILABLE", True), \
             patch("tools.mcp_tool._servers", fresh_servers), \
             patch("tools.mcp_tool._load_mcp_config", return_value=fake_config), \
             patch("tools.mcp_tool._connect_server", side_effect=flaky_connect), \
             patch("toolsets.TOOLSETS", fake_toolsets):
            from tools.mcp_tool import discover_mcp_tools

            # First call: good connects, broken fails
            result1 = discover_mcp_tools()
            assert "mcp_good_ping" in result1
            assert "mcp_broken_ping" not in result1
            first_attempts = call_count

            # "Fix" the broken server
            broken_fixed = True
            call_count = 0

            # Second call: should retry broken, skip good
            result2 = discover_mcp_tools()
            assert "mcp_good_ping" in result2
            assert "mcp_broken_ping" in result2
            assert call_count == 1  # Only broken retried
1039  
1040  
1041  # ---------------------------------------------------------------------------
1042  # Graceful fallback
1043  # ---------------------------------------------------------------------------
1044  
class TestGracefulFallback:
    """discover_mcp_tools() degrades gracefully when MCP is unusable."""

    def test_mcp_unavailable_returns_empty(self):
        """When _MCP_AVAILABLE is False, discover_mcp_tools is a no-op."""
        from tools.mcp_tool import discover_mcp_tools

        with patch("tools.mcp_tool._MCP_AVAILABLE", False):
            assert discover_mcp_tools() == []

    def test_no_servers_returns_empty(self):
        """No MCP servers configured -> empty list."""
        from tools.mcp_tool import discover_mcp_tools

        with patch("tools.mcp_tool._MCP_AVAILABLE", True), \
             patch("tools.mcp_tool._servers", {}), \
             patch("tools.mcp_tool._load_mcp_config", return_value={}):
            assert discover_mcp_tools() == []
1061  
1062  
1063  # ---------------------------------------------------------------------------
1064  # Shutdown (public API)
1065  # ---------------------------------------------------------------------------
1066  
class TestShutdown:
    """Tests for the public shutdown_mcp_servers() entry point.

    Each test that needs the background MCP event loop calls
    _ensure_mcp_loop() and resets the module-level loop/thread in a
    ``finally`` so later tests never inherit stale global state.
    """

    def test_no_servers_safe(self):
        """shutdown_mcp_servers with no servers does nothing."""
        from tools.mcp_tool import shutdown_mcp_servers, _servers

        _servers.clear()
        shutdown_mcp_servers()  # Should not raise

    def test_shutdown_clears_servers(self):
        """shutdown_mcp_servers calls shutdown() on each server and clears dict."""
        import tools.mcp_tool as mcp_mod
        from tools.mcp_tool import shutdown_mcp_servers, _servers

        _servers.clear()
        mock_server = MagicMock()
        mock_server.name = "test"
        mock_server.shutdown = AsyncMock()
        _servers["test"] = mock_server

        mcp_mod._ensure_mcp_loop()
        try:
            shutdown_mcp_servers()
        finally:
            # Reset module globals so other tests start clean.
            mcp_mod._mcp_loop = None
            mcp_mod._mcp_thread = None

        assert len(_servers) == 0
        mock_server.shutdown.assert_called_once()

    def test_shutdown_deregisters_registered_tools(self):
        """shutdown_mcp_servers removes MCP tools and their raw alias."""
        import tools.mcp_tool as mcp_mod
        from tools.mcp_tool import MCPServerTask, shutdown_mcp_servers, _servers
        from tools.registry import registry
        from toolsets import resolve_toolset, validate_toolset

        _servers.clear()
        # Register one MCP tool plus a raw server-name alias, mimicking
        # what discovery would have done.
        registry.register(
            name="mcp_test_ping",
            toolset="mcp-test",
            schema={
                "name": "mcp_test_ping",
                "description": "Ping",
                "parameters": {"type": "object", "properties": {}},
            },
            handler=lambda *_args, **_kwargs: "{}",
        )
        registry.register_toolset_alias("test", "mcp-test")

        server = MCPServerTask("test")
        server._registered_tool_names = ["mcp_test_ping"]
        _servers["test"] = server

        mcp_mod._ensure_mcp_loop()
        try:
            # Sanity-check the pre-shutdown state before tearing down.
            assert validate_toolset("test") is True
            assert "mcp_test_ping" in resolve_toolset("test")
            shutdown_mcp_servers()
        finally:
            mcp_mod._mcp_loop = None
            mcp_mod._mcp_thread = None

        # Tool and alias are both gone after shutdown.
        assert "mcp_test_ping" not in registry.get_all_tool_names()
        assert validate_toolset("test") is False

    def test_shutdown_handles_errors(self):
        """shutdown_mcp_servers handles errors during close gracefully."""
        import tools.mcp_tool as mcp_mod
        from tools.mcp_tool import shutdown_mcp_servers, _servers

        _servers.clear()
        mock_server = MagicMock()
        mock_server.name = "broken"
        mock_server.shutdown = AsyncMock(side_effect=RuntimeError("close failed"))
        _servers["broken"] = mock_server

        mcp_mod._ensure_mcp_loop()
        try:
            shutdown_mcp_servers()  # Should not raise
        finally:
            mcp_mod._mcp_loop = None
            mcp_mod._mcp_thread = None

        # Even a failing server is removed from the table.
        assert len(_servers) == 0

    def test_shutdown_is_parallel(self):
        """Multiple servers are shut down in parallel via asyncio.gather."""
        import tools.mcp_tool as mcp_mod
        from tools.mcp_tool import shutdown_mcp_servers, _servers
        # NOTE: uses the module-level `time` import; the previous local
        # `import time` was redundant shadowing and has been removed.

        _servers.clear()

        # 3 servers each taking 1s to shut down
        for i in range(3):
            mock_server = MagicMock()
            mock_server.name = f"srv_{i}"
            async def slow_shutdown():
                await asyncio.sleep(1)
            mock_server.shutdown = slow_shutdown
            _servers[f"srv_{i}"] = mock_server

        mcp_mod._ensure_mcp_loop()
        try:
            start = time.monotonic()
            shutdown_mcp_servers()
            elapsed = time.monotonic() - start
        finally:
            mcp_mod._mcp_loop = None
            mcp_mod._mcp_thread = None

        assert len(_servers) == 0
        # Parallel: ~1s, not ~3s. Allow some margin.
        assert elapsed < 2.5, f"Shutdown took {elapsed:.1f}s, expected ~1s (parallel)"
1181  
1182  
1183  # ---------------------------------------------------------------------------
1184  # _build_safe_env
1185  # ---------------------------------------------------------------------------
1186  
class TestBuildSafeEnv:
    """Tests for _build_safe_env() environment filtering."""

    def test_only_safe_vars_passed(self):
        """Only safe baseline vars and XDG_* from os.environ are included."""
        from tools.mcp_tool import _build_safe_env

        simulated_environ = {
            "PATH": "/usr/bin",
            "HOME": "/home/test",
            "USER": "test",
            "LANG": "en_US.UTF-8",
            "LC_ALL": "C",
            "TERM": "xterm",
            "SHELL": "/bin/bash",
            "TMPDIR": "/tmp",
            "XDG_DATA_HOME": "/home/test/.local/share",
            "SECRET_KEY": "should_not_appear",
            "AWS_ACCESS_KEY_ID": "AKIAIOSFODNN7EXAMPLE",
        }
        with patch.dict("os.environ", simulated_environ, clear=True):
            safe = _build_safe_env(None)

        # The safe baseline (and XDG_*) pass through with their values.
        expected_passthrough = {
            "PATH": "/usr/bin",
            "HOME": "/home/test",
            "USER": "test",
            "LANG": "en_US.UTF-8",
            "XDG_DATA_HOME": "/home/test/.local/share",
        }
        for key, value in expected_passthrough.items():
            assert safe[key] == value
        # Anything resembling a credential is filtered out.
        for blocked in ("SECRET_KEY", "AWS_ACCESS_KEY_ID"):
            assert blocked not in safe

    def test_user_env_merged(self):
        """User-specified env vars are merged into the safe env."""
        from tools.mcp_tool import _build_safe_env

        with patch.dict("os.environ", {"PATH": "/usr/bin"}, clear=True):
            merged = _build_safe_env({"MY_CUSTOM_VAR": "hello"})

        assert merged["MY_CUSTOM_VAR"] == "hello"
        assert merged["PATH"] == "/usr/bin"

    def test_user_env_overrides_safe(self):
        """User env can override safe defaults."""
        from tools.mcp_tool import _build_safe_env

        with patch.dict("os.environ", {"PATH": "/usr/bin"}, clear=True):
            overridden = _build_safe_env({"PATH": "/custom/bin"})

        assert overridden["PATH"] == "/custom/bin"

    def test_none_user_env(self):
        """None user_env still returns safe vars from os.environ."""
        from tools.mcp_tool import _build_safe_env

        with patch.dict("os.environ", {"PATH": "/usr/bin", "HOME": "/root"}, clear=True):
            safe = _build_safe_env(None)

        assert isinstance(safe, dict)
        assert safe["PATH"] == "/usr/bin"
        assert safe["HOME"] == "/root"

    def test_secret_vars_excluded(self):
        """Sensitive env vars from os.environ are NOT passed through."""
        from tools.mcp_tool import _build_safe_env

        simulated_environ = {
            "PATH": "/usr/bin",
            "AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY",
            "GITHUB_TOKEN": "ghp_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
            "OPENAI_API_KEY": "sk-proj-abc123",
            "DATABASE_URL": "postgres://user:pass@localhost/db",
            "API_SECRET": "supersecret",
        }
        with patch.dict("os.environ", simulated_environ, clear=True):
            safe = _build_safe_env(None)

        assert "PATH" in safe
        for secret in (
            "AWS_SECRET_ACCESS_KEY",
            "GITHUB_TOKEN",
            "OPENAI_API_KEY",
            "DATABASE_URL",
            "API_SECRET",
        ):
            assert secret not in safe
1271  
1272  
1273  # ---------------------------------------------------------------------------
1274  # _sanitize_error
1275  # ---------------------------------------------------------------------------
1276  
class TestSanitizeError:
    """Tests for _sanitize_error() credential stripping."""

    def test_strips_github_pat(self):
        from tools.mcp_tool import _sanitize_error
        assert _sanitize_error("Error with ghp_abc123def456") == "Error with [REDACTED]"

    def test_strips_openai_key(self):
        from tools.mcp_tool import _sanitize_error
        assert _sanitize_error("key sk-projABC123xyz") == "key [REDACTED]"

    def test_strips_bearer_token(self):
        from tools.mcp_tool import _sanitize_error
        assert _sanitize_error("Authorization: Bearer eyJabc123def") == "Authorization: [REDACTED]"

    def test_strips_token_param(self):
        from tools.mcp_tool import _sanitize_error
        assert _sanitize_error("url?token=secret123") == "url?[REDACTED]"

    def test_no_credentials_unchanged(self):
        from tools.mcp_tool import _sanitize_error
        assert _sanitize_error("normal error message") == "normal error message"

    def test_multiple_credentials(self):
        from tools.mcp_tool import _sanitize_error
        cleaned = _sanitize_error("ghp_abc123 and sk-projXyz789 and token=foo")
        # None of the credential prefixes survive, and each occurrence
        # was replaced by exactly one placeholder.
        for leaked in ("ghp_", "sk-", "token="):
            assert leaked not in cleaned
        assert cleaned.count("[REDACTED]") == 3
1312  
1313  
1314  # ---------------------------------------------------------------------------
1315  # HTTP config
1316  # ---------------------------------------------------------------------------
1317  
class TestHTTPConfig:
    """Tests for HTTP transport detection and handling."""

    def test_is_http_with_url(self):
        from tools.mcp_tool import MCPServerTask
        server = MCPServerTask("remote")
        server._config = {"url": "https://example.com/mcp"}
        assert server._is_http() is True

    def test_is_stdio_with_command(self):
        from tools.mcp_tool import MCPServerTask
        server = MCPServerTask("local")
        server._config = {"command": "npx", "args": []}
        assert server._is_http() is False

    def test_conflicting_url_and_command_warns(self):
        """Config with both url and command logs a warning and uses HTTP."""
        from tools.mcp_tool import MCPServerTask
        server = MCPServerTask("conflict")
        config = {"url": "https://example.com/mcp", "command": "npx", "args": []}
        # url takes precedence
        server._config = config
        assert server._is_http() is True

    def test_http_unavailable_raises(self):
        from tools.mcp_tool import MCPServerTask

        server = MCPServerTask("remote")
        config = {"url": "https://example.com/mcp"}

        async def _test():
            # With the HTTP transport flagged unavailable, _run_http must
            # fail fast with an ImportError mentioning the transport.
            with patch("tools.mcp_tool._MCP_HTTP_AVAILABLE", False):
                with pytest.raises(ImportError, match="HTTP transport"):
                    await server._run_http(config)

        asyncio.run(_test())

    def test_http_seeds_initial_protocol_header(self):
        # Verifies that _run_http seeds the mcp-protocol-version header on
        # both the new (httpx-based) and legacy HTTP transports, and that a
        # user-supplied header of any casing suppresses the default.
        from tools.mcp_tool import LATEST_PROTOCOL_VERSION, MCPServerTask

        server = MCPServerTask("remote")
        # Dummies below record the kwargs they receive into `captured`
        # instead of doing any real network work.
        captured = {}

        class DummyAsyncClient:
            def __init__(self, **kwargs):
                captured.update(kwargs)

            async def __aenter__(self):
                return self

            async def __aexit__(self, exc_type, exc, tb):
                return False

        class DummyTransportCtx:
            # Mimics the (read_stream, write_stream, close) triple the
            # streamable HTTP client yields.
            async def __aenter__(self):
                return MagicMock(), MagicMock(), (lambda: None)

            async def __aexit__(self, exc_type, exc, tb):
                return False

        class DummySession:
            def __init__(self, *args, **kwargs):
                pass

            async def __aenter__(self):
                return self

            async def __aexit__(self, exc_type, exc, tb):
                return False

            async def initialize(self):
                return None

        class DummyLegacyTransportCtx:
            # Legacy transport receives headers directly as a kwarg.
            def __init__(self, **kwargs):
                captured["legacy_headers"] = kwargs.get("headers")

            async def __aenter__(self):
                return MagicMock(), MagicMock(), (lambda: None)

            async def __aexit__(self, exc_type, exc, tb):
                return False

        async def _discover_tools(self):
            # Replaces the real discovery step: immediately request shutdown
            # so _run_http returns instead of serving forever.
            self._shutdown_event.set()

        async def _run(config, *, new_http):
            captured.clear()
            with patch("tools.mcp_tool._MCP_HTTP_AVAILABLE", True), \
                 patch("tools.mcp_tool._MCP_NEW_HTTP", new_http), \
                 patch("httpx.AsyncClient", DummyAsyncClient), \
                 patch("tools.mcp_tool.streamable_http_client", return_value=DummyTransportCtx()), \
                 patch("tools.mcp_tool.streamablehttp_client", side_effect=lambda url, **kwargs: DummyLegacyTransportCtx(**kwargs)), \
                 patch("tools.mcp_tool.ClientSession", DummySession), \
                 patch.object(MCPServerTask, "_discover_tools", _discover_tools):
                await server._run_http(config)

        # New transport, no user headers: default version header seeded.
        asyncio.run(_run({"url": "https://example.com/mcp"}, new_http=True))
        assert captured["headers"]["mcp-protocol-version"] == LATEST_PROTOCOL_VERSION

        # User-supplied lowercase header wins over the default.
        asyncio.run(_run({
            "url": "https://example.com/mcp",
            "headers": {"mcp-protocol-version": "custom-version"},
        }, new_http=True))
        assert captured["headers"]["mcp-protocol-version"] == "custom-version"

        # Header matching is case-insensitive: a differently-cased user
        # header must suppress the default rather than duplicate it.
        asyncio.run(_run({
            "url": "https://example.com/mcp",
            "headers": {"MCP-Protocol-Version": "custom-version"},
        }, new_http=True))
        assert captured["headers"]["MCP-Protocol-Version"] == "custom-version"
        assert "mcp-protocol-version" not in captured["headers"]

        # Legacy transport: same default-seeding behavior.
        asyncio.run(_run({"url": "https://example.com/mcp"}, new_http=False))
        assert captured["legacy_headers"]["mcp-protocol-version"] == LATEST_PROTOCOL_VERSION

        # Legacy transport: case-insensitive user override, no duplicate.
        asyncio.run(_run({
            "url": "https://example.com/mcp",
            "headers": {"MCP-Protocol-Version": "custom-version"},
        }, new_http=False))
        assert captured["legacy_headers"]["MCP-Protocol-Version"] == "custom-version"
        assert "mcp-protocol-version" not in captured["legacy_headers"]
1440  
1441  
1442  # ---------------------------------------------------------------------------
1443  # Reconnection logic
1444  # ---------------------------------------------------------------------------
1445  
1446  class TestReconnection:
1447      """Tests for automatic reconnection behavior in MCPServerTask.run()."""
1448  
    def test_reconnect_on_disconnect(self):
        """After initial success, a connection drop triggers reconnection."""
        from tools.mcp_tool import MCPServerTask

        run_count = 0
        target_server = None

        # Keep the original so unrelated MCPServerTask instances (if any are
        # created while the class is patched) behave normally.
        original_run_stdio = MCPServerTask._run_stdio

        async def patched_run_stdio(self_srv, config):
            nonlocal run_count, target_server
            run_count += 1
            if target_server is not self_srv:
                return await original_run_stdio(self_srv, config)
            if run_count == 1:
                # First connection succeeds, then simulate disconnect
                self_srv.session = MagicMock()
                self_srv._tools = []
                self_srv._ready.set()
                raise ConnectionError("connection dropped")
            else:
                # Reconnection succeeds; signal shutdown so run() exits
                self_srv.session = MagicMock()
                self_srv._shutdown_event.set()
                await self_srv._shutdown_event.wait()

        async def _test():
            nonlocal target_server
            server = MCPServerTask("test_srv")
            target_server = server

            # Patch asyncio.sleep so the reconnect backoff doesn't slow
            # the test down.
            with patch.object(MCPServerTask, "_run_stdio", patched_run_stdio), \
                 patch("asyncio.sleep", new_callable=AsyncMock):
                await server.run({"command": "test"})

            assert run_count >= 2  # At least one reconnection attempt

        asyncio.run(_test())
1487  
1488      def test_no_reconnect_on_shutdown(self):
1489          """If shutdown is requested, don't attempt reconnection."""
1490          from tools.mcp_tool import MCPServerTask
1491  
1492          run_count = 0
1493          target_server = None
1494  
1495          original_run_stdio = MCPServerTask._run_stdio
1496  
1497          async def patched_run_stdio(self_srv, config):
1498              nonlocal run_count, target_server
1499              run_count += 1
1500              if target_server is not self_srv:
1501                  return await original_run_stdio(self_srv, config)
1502              self_srv.session = MagicMock()
1503              self_srv._tools = []
1504              self_srv._ready.set()
1505              raise ConnectionError("connection dropped")
1506  
1507          async def _test():
1508              nonlocal target_server
1509              server = MCPServerTask("test_srv")
1510              target_server = server
1511              server._shutdown_event.set()  # Shutdown already requested
1512  
1513              with patch.object(MCPServerTask, "_run_stdio", patched_run_stdio), \
1514                   patch("asyncio.sleep", new_callable=AsyncMock):
1515                  await server.run({"command": "test"})
1516  
1517              # Should not retry because shutdown was set
1518              assert run_count == 1
1519  
1520          asyncio.run(_test())
1521  
1522      def test_no_reconnect_on_initial_failure(self):
1523          """First connection failure retries up to _MAX_INITIAL_CONNECT_RETRIES times.
1524  
1525          Before the MCP resilience fix, initial failures gave up immediately.
1526          Now they retry with backoff to handle transient DNS/network blips.
1527          """
1528          from tools.mcp_tool import MCPServerTask, _MAX_INITIAL_CONNECT_RETRIES
1529  
1530          run_count = 0
1531          target_server = None
1532  
1533          original_run_stdio = MCPServerTask._run_stdio
1534  
1535          async def patched_run_stdio(self_srv, config):
1536              nonlocal run_count, target_server
1537              run_count += 1
1538              if target_server is not self_srv:
1539                  return await original_run_stdio(self_srv, config)
1540              raise ConnectionError("cannot connect")
1541  
1542          async def _test():
1543              nonlocal target_server
1544              server = MCPServerTask("test_srv")
1545              target_server = server
1546  
1547              with patch.object(MCPServerTask, "_run_stdio", patched_run_stdio), \
1548                   patch("asyncio.sleep", new_callable=AsyncMock):
1549                  await server.run({"command": "test"})
1550  
1551              # Now retries up to _MAX_INITIAL_CONNECT_RETRIES before giving up
1552              assert run_count == _MAX_INITIAL_CONNECT_RETRIES + 1
1553              assert server._error is not None
1554              assert "cannot connect" in str(server._error)
1555  
1556          asyncio.run(_test())
1557  
1558  
1559  # ---------------------------------------------------------------------------
1560  # Configurable timeouts
1561  # ---------------------------------------------------------------------------
1562  
class TestConfigurableTimeouts:
    """Tests for configurable per-server timeouts."""

    def test_default_timeout(self):
        """Server with no timeout config gets _DEFAULT_TOOL_TIMEOUT."""
        from tools.mcp_tool import MCPServerTask, _DEFAULT_TOOL_TIMEOUT

        server = MCPServerTask("test_srv")
        assert server.tool_timeout == _DEFAULT_TOOL_TIMEOUT
        # Pin the constant's value so an accidental change is caught here too.
        assert server.tool_timeout == 120

    def test_custom_timeout(self):
        """Server with timeout=180 in config gets 180."""
        from tools.mcp_tool import MCPServerTask

        target_server = None

        # Class-wide patch below; keep the original for other instances.
        original_run_stdio = MCPServerTask._run_stdio

        async def patched_run_stdio(self_srv, config):
            if target_server is not self_srv:
                return await original_run_stdio(self_srv, config)
            self_srv.session = MagicMock()
            self_srv._tools = []
            self_srv._ready.set()
            await self_srv._shutdown_event.wait()

        async def _test():
            nonlocal target_server
            server = MCPServerTask("test_srv")
            target_server = server

            with patch.object(MCPServerTask, "_run_stdio", patched_run_stdio):
                # Run the server in the background so tool_timeout can be
                # inspected once _ready fires, then shut it down cleanly.
                task = asyncio.ensure_future(
                    server.run({"command": "test", "timeout": 180})
                )
                await server._ready.wait()
                assert server.tool_timeout == 180
                server._shutdown_event.set()
                await task

        asyncio.run(_test())

    def test_timeout_passed_to_handler(self):
        """The tool handler uses the server's configured timeout."""
        from tools.mcp_tool import _make_tool_handler, _servers, MCPServerTask

        mock_session = MagicMock()
        mock_session.call_tool = AsyncMock(
            return_value=_make_call_result("ok", is_error=False)
        )
        server = _make_mock_server("test_srv", session=mock_session)
        server.tool_timeout = 180
        _servers["test_srv"] = server

        try:
            handler = _make_tool_handler("test_srv", "my_tool", 180)
            with patch("tools.mcp_tool._run_on_mcp_loop") as mock_run:
                def fake_run(coro, timeout=30):
                    # Close the unawaited coroutine to avoid a RuntimeWarning.
                    coro.close()
                    return json.dumps({"result": "ok"})

                mock_run.side_effect = fake_run
                handler({})
                # Verify timeout=180 was passed
                # (accept keyword or positional form of the call).
                call_kwargs = mock_run.call_args
                assert call_kwargs.kwargs.get("timeout") == 180 or \
                       (len(call_kwargs.args) > 1 and call_kwargs.args[1] == 180) or \
                       call_kwargs[1].get("timeout") == 180
        finally:
            _servers.pop("test_srv", None)
1634  
1635  
1636  # ---------------------------------------------------------------------------
1637  # Utility tool schemas (Resources & Prompts)
1638  # ---------------------------------------------------------------------------
1639  
class TestUtilitySchemas:
    """Tests for _build_utility_schemas() and the schema format of utility tools."""

    @staticmethod
    def _params(handler_key):
        """Return the parameters schema for one utility handler of server 'srv'."""
        from tools.mcp_tool import _build_utility_schemas

        entry = next(
            s for s in _build_utility_schemas("srv")
            if s["handler_key"] == handler_key
        )
        return entry["schema"]["parameters"]

    def test_builds_four_utility_schemas(self):
        from tools.mcp_tool import _build_utility_schemas

        entries = _build_utility_schemas("myserver")
        assert len(entries) == 4
        produced = {entry["schema"]["name"] for entry in entries}
        assert produced >= {
            "mcp_myserver_list_resources",
            "mcp_myserver_read_resource",
            "mcp_myserver_list_prompts",
            "mcp_myserver_get_prompt",
        }

    def test_hyphens_sanitized_in_utility_names(self):
        from tools.mcp_tool import _build_utility_schemas

        produced = [
            entry["schema"]["name"]
            for entry in _build_utility_schemas("my-server")
        ]
        assert all("-" not in name for name in produced)
        assert "mcp_my_server_list_resources" in produced

    def test_list_resources_schema_no_required_params(self):
        params = self._params("list_resources")
        assert params["type"] == "object"
        assert params["properties"] == {}
        assert "required" not in params

    def test_read_resource_schema_requires_uri(self):
        params = self._params("read_resource")
        assert "uri" in params["properties"]
        assert params["properties"]["uri"]["type"] == "string"
        assert params["required"] == ["uri"]

    def test_list_prompts_schema_no_required_params(self):
        params = self._params("list_prompts")
        assert params["type"] == "object"
        assert params["properties"] == {}
        assert "required" not in params

    def test_get_prompt_schema_requires_name(self):
        params = self._params("get_prompt")
        assert "name" in params["properties"]
        assert params["properties"]["name"]["type"] == "string"
        assert "arguments" in params["properties"]
        assert params["properties"]["arguments"]["type"] == "object"
        assert params["required"] == ["name"]

    def test_schemas_have_descriptions(self):
        from tools.mcp_tool import _build_utility_schemas

        for entry in _build_utility_schemas("test_srv"):
            description = entry["schema"]["description"]
            assert description and len(description) > 0
            assert "test_srv" in description
1713  
1714  
1715  # ---------------------------------------------------------------------------
1716  # Utility tool handlers (Resources & Prompts)
1717  # ---------------------------------------------------------------------------
1718  
class TestUtilityHandlers:
    """Tests for the MCP Resources & Prompts handler functions."""

    def _patch_mcp_loop(self):
        """Return a patch for _run_on_mcp_loop that runs the coroutine directly."""
        def fake_run(coro, timeout=30):
            # Execute synchronously on a fresh event loop instead of the
            # background MCP loop used in production.
            return asyncio.run(coro)
        return patch("tools.mcp_tool._run_on_mcp_loop", side_effect=fake_run)

    # -- list_resources --

    def test_list_resources_success(self):
        """Handler returns the server's resources serialized as JSON."""
        from tools.mcp_tool import _make_list_resources_handler, _servers

        mock_resource = SimpleNamespace(
            uri="file:///tmp/test.txt", name="test.txt",
            description="A test file", mimeType="text/plain",
        )
        mock_session = MagicMock()
        mock_session.list_resources = AsyncMock(
            return_value=SimpleNamespace(resources=[mock_resource])
        )
        server = _make_mock_server("srv", session=mock_session)
        _servers["srv"] = server

        try:
            handler = _make_list_resources_handler("srv", 120)
            with self._patch_mcp_loop():
                result = json.loads(handler({}))
            assert "resources" in result
            assert len(result["resources"]) == 1
            assert result["resources"][0]["uri"] == "file:///tmp/test.txt"
            assert result["resources"][0]["name"] == "test.txt"
        finally:
            # Always remove the fake server so other tests see a clean registry.
            _servers.pop("srv", None)

    def test_list_resources_empty(self):
        """A server with no resources yields an empty list, not an error."""
        from tools.mcp_tool import _make_list_resources_handler, _servers

        mock_session = MagicMock()
        mock_session.list_resources = AsyncMock(
            return_value=SimpleNamespace(resources=[])
        )
        server = _make_mock_server("srv", session=mock_session)
        _servers["srv"] = server

        try:
            handler = _make_list_resources_handler("srv", 120)
            with self._patch_mcp_loop():
                result = json.loads(handler({}))
            assert result["resources"] == []
        finally:
            _servers.pop("srv", None)

    def test_list_resources_disconnected(self):
        """A server name absent from _servers yields a 'not connected' error."""
        from tools.mcp_tool import _make_list_resources_handler, _servers
        _servers.pop("ghost", None)
        handler = _make_list_resources_handler("ghost", 120)
        result = json.loads(handler({}))
        assert "error" in result
        assert "not connected" in result["error"]

    # -- read_resource --

    def test_read_resource_success(self):
        """Handler reads a resource by URI and returns its text content."""
        from tools.mcp_tool import _make_read_resource_handler, _servers

        content_block = SimpleNamespace(text="Hello from resource")
        mock_session = MagicMock()
        mock_session.read_resource = AsyncMock(
            return_value=SimpleNamespace(contents=[content_block])
        )
        server = _make_mock_server("srv", session=mock_session)
        _servers["srv"] = server

        try:
            handler = _make_read_resource_handler("srv", 120)
            with self._patch_mcp_loop():
                result = json.loads(handler({"uri": "file:///tmp/test.txt"}))
            assert result["result"] == "Hello from resource"
            mock_session.read_resource.assert_called_once_with("file:///tmp/test.txt")
        finally:
            _servers.pop("srv", None)

    def test_read_resource_missing_uri(self):
        """Omitting the required 'uri' argument yields an error payload."""
        from tools.mcp_tool import _make_read_resource_handler, _servers

        server = _make_mock_server("srv", session=MagicMock())
        _servers["srv"] = server

        try:
            handler = _make_read_resource_handler("srv", 120)
            result = json.loads(handler({}))
            assert "error" in result
            assert "uri" in result["error"].lower()
        finally:
            _servers.pop("srv", None)

    def test_read_resource_disconnected(self):
        """A server name absent from _servers yields a 'not connected' error."""
        from tools.mcp_tool import _make_read_resource_handler, _servers
        _servers.pop("ghost", None)
        handler = _make_read_resource_handler("ghost", 120)
        result = json.loads(handler({"uri": "test://x"}))
        assert "error" in result
        assert "not connected" in result["error"]

    # -- list_prompts --

    def test_list_prompts_success(self):
        """Handler returns prompts with their argument specs as JSON."""
        from tools.mcp_tool import _make_list_prompts_handler, _servers

        mock_prompt = SimpleNamespace(
            name="summarize", description="Summarize text",
            arguments=[
                SimpleNamespace(name="text", description="Text to summarize", required=True),
            ],
        )
        mock_session = MagicMock()
        mock_session.list_prompts = AsyncMock(
            return_value=SimpleNamespace(prompts=[mock_prompt])
        )
        server = _make_mock_server("srv", session=mock_session)
        _servers["srv"] = server

        try:
            handler = _make_list_prompts_handler("srv", 120)
            with self._patch_mcp_loop():
                result = json.loads(handler({}))
            assert "prompts" in result
            assert len(result["prompts"]) == 1
            assert result["prompts"][0]["name"] == "summarize"
            assert result["prompts"][0]["arguments"][0]["name"] == "text"
        finally:
            _servers.pop("srv", None)

    def test_list_prompts_empty(self):
        """A server with no prompts yields an empty list, not an error."""
        from tools.mcp_tool import _make_list_prompts_handler, _servers

        mock_session = MagicMock()
        mock_session.list_prompts = AsyncMock(
            return_value=SimpleNamespace(prompts=[])
        )
        server = _make_mock_server("srv", session=mock_session)
        _servers["srv"] = server

        try:
            handler = _make_list_prompts_handler("srv", 120)
            with self._patch_mcp_loop():
                result = json.loads(handler({}))
            assert result["prompts"] == []
        finally:
            _servers.pop("srv", None)

    def test_list_prompts_disconnected(self):
        """A server name absent from _servers yields a 'not connected' error."""
        from tools.mcp_tool import _make_list_prompts_handler, _servers
        _servers.pop("ghost", None)
        handler = _make_list_prompts_handler("ghost", 120)
        result = json.loads(handler({}))
        assert "error" in result
        assert "not connected" in result["error"]

    # -- get_prompt --

    def test_get_prompt_success(self):
        """Handler fetches a prompt and returns its rendered messages."""
        from tools.mcp_tool import _make_get_prompt_handler, _servers

        mock_msg = SimpleNamespace(
            role="assistant",
            content=SimpleNamespace(text="Here is a summary of your text."),
        )
        mock_session = MagicMock()
        mock_session.get_prompt = AsyncMock(
            return_value=SimpleNamespace(messages=[mock_msg], description=None)
        )
        server = _make_mock_server("srv", session=mock_session)
        _servers["srv"] = server

        try:
            handler = _make_get_prompt_handler("srv", 120)
            with self._patch_mcp_loop():
                result = json.loads(handler({"name": "summarize", "arguments": {"text": "hello"}}))
            assert "messages" in result
            assert len(result["messages"]) == 1
            assert result["messages"][0]["role"] == "assistant"
            assert "summary" in result["messages"][0]["content"].lower()
            mock_session.get_prompt.assert_called_once_with(
                "summarize", arguments={"text": "hello"}
            )
        finally:
            _servers.pop("srv", None)

    def test_get_prompt_missing_name(self):
        """Omitting the required 'name' argument yields an error payload."""
        from tools.mcp_tool import _make_get_prompt_handler, _servers

        server = _make_mock_server("srv", session=MagicMock())
        _servers["srv"] = server

        try:
            handler = _make_get_prompt_handler("srv", 120)
            result = json.loads(handler({}))
            assert "error" in result
            assert "name" in result["error"].lower()
        finally:
            _servers.pop("srv", None)

    def test_get_prompt_disconnected(self):
        """A server name absent from _servers yields a 'not connected' error."""
        from tools.mcp_tool import _make_get_prompt_handler, _servers
        _servers.pop("ghost", None)
        handler = _make_get_prompt_handler("ghost", 120)
        result = json.loads(handler({"name": "test"}))
        assert "error" in result
        assert "not connected" in result["error"]

    def test_get_prompt_default_arguments(self):
        """The 'arguments' payload defaults to {} when the caller omits it."""
        from tools.mcp_tool import _make_get_prompt_handler, _servers

        mock_session = MagicMock()
        mock_session.get_prompt = AsyncMock(
            return_value=SimpleNamespace(messages=[], description=None)
        )
        server = _make_mock_server("srv", session=mock_session)
        _servers["srv"] = server

        try:
            handler = _make_get_prompt_handler("srv", 120)
            with self._patch_mcp_loop():
                handler({"name": "test_prompt"})
            # arguments defaults to {} when not provided
            mock_session.get_prompt.assert_called_once_with(
                "test_prompt", arguments={}
            )
        finally:
            _servers.pop("srv", None)
1952  
1953  
1954  # ---------------------------------------------------------------------------
1955  # Utility tools registration in _discover_and_register_server
1956  # ---------------------------------------------------------------------------
1957  
class TestUtilityToolRegistration:
    """Verify utility tools are registered alongside regular MCP tools."""

    def test_utility_tools_registered(self):
        """_discover_and_register_server registers all 4 utility tools."""
        from tools.registry import ToolRegistry
        from tools.mcp_tool import _discover_and_register_server, _servers, MCPServerTask

        mock_registry = ToolRegistry()
        mock_tools = [_make_mcp_tool("read_file", "Read a file")]
        mock_session = MagicMock()

        # Stand-in for _connect_server: returns a pre-wired server without
        # spawning any real subprocess.
        async def fake_connect(name, config):
            server = MCPServerTask(name)
            server.session = mock_session
            server._tools = mock_tools
            return server

        with patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \
             patch("tools.registry.registry", mock_registry):
            registered = asyncio.run(
                _discover_and_register_server("fs", {"command": "npx", "args": []})
            )

        # Regular tool + 4 utility tools
        assert "mcp_fs_read_file" in registered
        assert "mcp_fs_list_resources" in registered
        assert "mcp_fs_read_resource" in registered
        assert "mcp_fs_list_prompts" in registered
        assert "mcp_fs_get_prompt" in registered
        assert len(registered) == 5

        # All in the registry
        all_names = mock_registry.get_all_tool_names()
        for name in registered:
            assert name in all_names

        _servers.pop("fs", None)

    def test_utility_tools_in_same_toolset(self):
        """Utility tools belong to the same mcp-{server} toolset."""
        from tools.registry import ToolRegistry
        from tools.mcp_tool import _discover_and_register_server, _servers, MCPServerTask

        mock_registry = ToolRegistry()
        mock_session = MagicMock()

        async def fake_connect(name, config):
            server = MCPServerTask(name)
            server.session = mock_session
            server._tools = []
            return server

        with patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \
             patch("tools.registry.registry", mock_registry):
            asyncio.run(
                _discover_and_register_server("myserv", {"command": "test"})
            )

        # Check that utility tools are in the right toolset
        for tool_name in ["mcp_myserv_list_resources", "mcp_myserv_read_resource",
                          "mcp_myserv_list_prompts", "mcp_myserv_get_prompt"]:
            entry = mock_registry._tools.get(tool_name)
            assert entry is not None, f"{tool_name} not found in registry"
            assert entry.toolset == "mcp-myserv"

        _servers.pop("myserv", None)

    def test_utility_tools_have_check_fn(self):
        """Utility tools have a working check_fn."""
        from tools.registry import ToolRegistry
        from tools.mcp_tool import _discover_and_register_server, _servers, MCPServerTask

        mock_registry = ToolRegistry()
        mock_session = MagicMock()

        async def fake_connect(name, config):
            server = MCPServerTask(name)
            server.session = mock_session
            server._tools = []
            return server

        with patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \
             patch("tools.registry.registry", mock_registry):
            asyncio.run(
                _discover_and_register_server("chk", {"command": "test"})
            )

        entry = mock_registry._tools.get("mcp_chk_list_resources")
        assert entry is not None
        # Server is connected, check_fn should return True
        assert entry.check_fn() is True

        # Disconnect the server
        _servers["chk"].session = None
        assert entry.check_fn() is False

        _servers.pop("chk", None)
2056  
2057  
2058  # ===========================================================================
2059  # SamplingHandler tests
2060  # ===========================================================================
2061  
2062  import math
2063  import time
2064  
2065  class _CompatType:
2066      def __init__(self, **kwargs):
2067          self.__dict__.update(kwargs)
2068  
2069  
# Import the real MCP SDK types when available; otherwise substitute
# _CompatType attribute bags so the sampling tests can still construct them.
try:
    from mcp.types import (
        CreateMessageResult,
        ErrorData,
        SamplingCapability,
        TextContent,
    )
except ImportError:
    CreateMessageResult = _CompatType
    ErrorData = _CompatType
    SamplingCapability = _CompatType
    TextContent = _CompatType

# Newer SDK-only types get their own guards so an older SDK still imports.
try:
    from mcp.types import CreateMessageResultWithTools
except ImportError:
    CreateMessageResultWithTools = _CompatType

try:
    from mcp.types import SamplingToolsCapability
except ImportError:
    SamplingToolsCapability = _CompatType

try:
    from mcp.types import ToolUseContent
except ImportError:
    ToolUseContent = _CompatType

# NOTE(review): this import rebinds CreateMessageResultWithTools,
# SamplingToolsCapability, and ToolUseContent, shadowing the three guarded
# bindings just above -- confirm those try/except blocks are still needed.
from tools.mcp_tool import (
    CreateMessageResultWithTools,
    SamplingHandler,
    SamplingToolsCapability,
    ToolUseContent,
    _safe_numeric,
)
2105  
2106  
2107  # ---------------------------------------------------------------------------
2108  # Helpers for sampling tests
2109  # ---------------------------------------------------------------------------
2110  
2111  def _make_sampling_params(
2112      messages=None,
2113      max_tokens=100,
2114      system_prompt=None,
2115      model_preferences=None,
2116      temperature=None,
2117      stop_sequences=None,
2118      tools=None,
2119      tool_choice=None,
2120  ):
2121      """Create a fake CreateMessageRequestParams using SimpleNamespace.
2122  
2123      Each message must have a ``content_as_list`` attribute that mirrors
2124      the SDK helper so that ``_convert_messages`` works correctly.
2125      """
2126      if messages is None:
2127          content = SimpleNamespace(text="Hello")
2128          msg = SimpleNamespace(role="user", content=content, content_as_list=[content])
2129          messages = [msg]
2130  
2131      params = SimpleNamespace(
2132          messages=messages,
2133          maxTokens=max_tokens,
2134          modelPreferences=model_preferences,
2135          temperature=temperature,
2136          stopSequences=stop_sequences,
2137          tools=tools,
2138          toolChoice=tool_choice,
2139      )
2140      if system_prompt is not None:
2141          params.systemPrompt = system_prompt
2142      return params
2143  
2144  
2145  def _make_llm_response(
2146      content="LLM response",
2147      model="test-model",
2148      finish_reason="stop",
2149      tool_calls=None,
2150  ):
2151      """Create a fake OpenAI chat completion response (text)."""
2152      message = SimpleNamespace(content=content, tool_calls=tool_calls)
2153      choice = SimpleNamespace(
2154          finish_reason=finish_reason,
2155          message=message,
2156      )
2157      usage = SimpleNamespace(total_tokens=42)
2158      return SimpleNamespace(choices=[choice], model=model, usage=usage)
2159  
2160  
def _make_llm_tool_response(tool_calls_data=None, model="test-model"):
    """Create a fake response with tool_calls.

    ``tool_calls_data``: list of (id, name, arguments_json) tuples.
    """
    if tool_calls_data is None:
        tool_calls_data = [("call_1", "get_weather", '{"city": "London"}')]

    calls = []
    for tc_id, name, args in tool_calls_data:
        calls.append(
            SimpleNamespace(
                id=tc_id,
                function=SimpleNamespace(name=name, arguments=args),
            )
        )
    # Delegate to the text-response builder: no content, tool_calls finish.
    return _make_llm_response(
        content=None,
        model=model,
        finish_reason="tool_calls",
        tool_calls=calls,
    )
2182  
2183  
2184  # ---------------------------------------------------------------------------
2185  # 1. _safe_numeric helper
2186  # ---------------------------------------------------------------------------
2187  
class TestSafeNumeric:
    """Tests for the _safe_numeric coercion/clamping helper."""

    def test_int_passthrough(self):
        value = _safe_numeric(10, 5, int)
        assert value == 10

    def test_string_coercion(self):
        value = _safe_numeric("20", 5, int)
        assert value == 20

    def test_none_returns_default(self):
        value = _safe_numeric(None, 7, int)
        assert value == 7

    def test_inf_returns_default(self):
        value = _safe_numeric(float("inf"), 3.0, float)
        assert value == 3.0

    def test_nan_returns_default(self):
        value = _safe_numeric(float("nan"), 4.0, float)
        assert value == 4.0

    def test_below_minimum_clamps(self):
        value = _safe_numeric(-5, 10, int, minimum=1)
        assert value == 1

    def test_minimum_zero_allowed(self):
        value = _safe_numeric(0, 10, int, minimum=0)
        assert value == 0

    def test_non_numeric_string_returns_default(self):
        value = _safe_numeric("abc", 42, int)
        assert value == 42

    def test_float_coercion(self):
        value = _safe_numeric("3.5", 1.0, float)
        assert value == 3.5
2215  
2216  
2217  # ---------------------------------------------------------------------------
2218  # 2. SamplingHandler initialization and config parsing
2219  # ---------------------------------------------------------------------------
2220  
class TestSamplingHandlerInit:
    """SamplingHandler construction and per-server config parsing."""

    def test_defaults(self):
        handler = SamplingHandler("srv", {})
        assert handler.server_name == "srv"
        assert handler.max_rpm == 10
        assert handler.timeout == 30
        assert handler.max_tokens_cap == 4096
        assert handler.max_tool_rounds == 5
        assert handler.model_override is None
        assert handler.allowed_models == []
        expected_metrics = {
            "requests": 0,
            "errors": 0,
            "tokens_used": 0,
            "tool_use_count": 0,
        }
        assert handler.metrics == expected_metrics

    def test_custom_config(self):
        handler = SamplingHandler(
            "custom",
            {
                "max_rpm": 20,
                "timeout": 60,
                "max_tokens_cap": 2048,
                "max_tool_rounds": 3,
                "model": "gpt-4o",
                "allowed_models": ["gpt-4o", "gpt-3.5-turbo"],
                "log_level": "debug",
            },
        )
        assert handler.max_rpm == 20
        assert handler.timeout == 60.0
        assert handler.max_tokens_cap == 2048
        assert handler.max_tool_rounds == 3
        assert handler.model_override == "gpt-4o"
        assert handler.allowed_models == ["gpt-4o", "gpt-3.5-turbo"]

    def test_string_numeric_config_values(self):
        """YAML sometimes delivers numeric values as strings."""
        handler = SamplingHandler(
            "s", {"max_rpm": "15", "timeout": "45.5", "max_tokens_cap": "1024"}
        )
        assert handler.max_rpm == 15
        assert handler.timeout == 45.5
        assert handler.max_tokens_cap == 1024
2258  
2259  
2260  # ---------------------------------------------------------------------------
2261  # 3. Rate limiting
2262  # ---------------------------------------------------------------------------
2263  
class TestRateLimit:
    """Sliding-window RPM limiting in SamplingHandler."""

    def setup_method(self):
        self.handler = SamplingHandler("rl", {"max_rpm": 3})

    def test_allows_under_limit(self):
        # The first max_rpm checks within a window all pass.
        outcomes = [self.handler._check_rate_limit() for _ in range(3)]
        assert outcomes == [True, True, True]

    def test_rejects_over_limit(self):
        for _ in range(3):
            self.handler._check_rate_limit()
        # The 4th request inside the window is rejected.
        assert self.handler._check_rate_limit() is False

    def test_window_expiry(self):
        """Old timestamps should be purged from the sliding window."""
        for _ in range(3):
            self.handler._check_rate_limit()
        # Rewrite the recorded timestamps to 61 seconds in the past so they
        # fall outside the rate window and get purged on the next check.
        stale = time.time() - 61
        self.handler._rate_timestamps[:] = [stale, stale, stale]
        assert self.handler._check_rate_limit() is True
2285  
2286  
2287  # ---------------------------------------------------------------------------
2288  # 4. Model resolution
2289  # ---------------------------------------------------------------------------
2290  
class TestResolveModel:
    """Model resolution: config override vs. caller-provided hints."""

    def setup_method(self):
        self.handler = SamplingHandler("mr", {})

    def test_no_preference_no_override(self):
        assert self.handler._resolve_model(None) is None

    def test_config_override_wins(self):
        # A config-level override beats any hint from the caller.
        self.handler.model_override = "override-model"
        preferences = SimpleNamespace(hints=[SimpleNamespace(name="hint-model")])
        assert self.handler._resolve_model(preferences) == "override-model"

    def test_hint_used_when_no_override(self):
        preferences = SimpleNamespace(hints=[SimpleNamespace(name="hint-model")])
        assert self.handler._resolve_model(preferences) == "hint-model"

    def test_empty_hints(self):
        assert self.handler._resolve_model(SimpleNamespace(hints=[])) is None

    def test_hint_without_name(self):
        preferences = SimpleNamespace(hints=[SimpleNamespace(name=None)])
        assert self.handler._resolve_model(preferences) is None
2314  
2315  
2316  # ---------------------------------------------------------------------------
2317  # 5. Message conversion
2318  # ---------------------------------------------------------------------------
2319  
class TestConvertMessages:
    """Conversion of MCP sampling messages into OpenAI-style chat messages."""

    def setup_method(self):
        self.handler = SamplingHandler("mc", {})

    def _convert(self, *messages):
        # Shorthand: wrap messages in sampling params and run the converter.
        params = _make_sampling_params(messages=list(messages))
        return self.handler._convert_messages(params)

    def test_single_text_message(self):
        block = SimpleNamespace(text="Hello world")
        msg = SimpleNamespace(role="user", content=block, content_as_list=[block])
        converted = self._convert(msg)
        assert len(converted) == 1
        assert converted[0] == {"role": "user", "content": "Hello world"}

    def test_image_message(self):
        text = SimpleNamespace(text="Look at this")
        image = SimpleNamespace(data="abc123", mimeType="image/png")
        msg = SimpleNamespace(
            role="user",
            content=[text, image],
            content_as_list=[text, image],
        )
        converted = self._convert(msg)
        assert len(converted) == 1
        parts = converted[0]["content"]
        assert len(parts) == 2
        assert parts[0] == {"type": "text", "text": "Look at this"}
        assert parts[1]["type"] == "image_url"
        assert "data:image/png;base64,abc123" in parts[1]["image_url"]["url"]

    def test_tool_result_message(self):
        nested = SimpleNamespace(text="42 degrees")
        tool_result = SimpleNamespace(toolUseId="call_1", content=[nested])
        msg = SimpleNamespace(
            role="user",
            content=[tool_result],
            content_as_list=[tool_result],
        )
        converted = self._convert(msg)
        assert len(converted) == 1
        assert converted[0]["role"] == "tool"
        assert converted[0]["tool_call_id"] == "call_1"
        assert converted[0]["content"] == "42 degrees"

    def test_tool_use_message(self):
        tool_use = SimpleNamespace(
            id="call_2", name="get_weather", input={"city": "London"}
        )
        msg = SimpleNamespace(
            role="assistant",
            content=[tool_use],
            content_as_list=[tool_use],
        )
        converted = self._convert(msg)
        assert len(converted) == 1
        entry = converted[0]
        assert entry["role"] == "assistant"
        assert len(entry["tool_calls"]) == 1
        function = entry["tool_calls"][0]["function"]
        assert function["name"] == "get_weather"
        assert json.loads(function["arguments"]) == {"city": "London"}

    def test_mixed_text_and_tool_use(self):
        """Assistant message with both text and tool_calls."""
        text = SimpleNamespace(text="Let me check the weather")
        tool_use = SimpleNamespace(
            id="call_3", name="get_weather", input={"city": "Paris"}
        )
        msg = SimpleNamespace(
            role="assistant",
            content=[text, tool_use],
            content_as_list=[text, tool_use],
        )
        converted = self._convert(msg)
        assert len(converted) == 1
        assert converted[0]["content"] == "Let me check the weather"
        assert len(converted[0]["tool_calls"]) == 1

    def test_fallback_without_content_as_list(self):
        """When content_as_list is absent, falls back to content."""
        block = SimpleNamespace(text="Fallback text")
        msg = SimpleNamespace(role="user", content=block)
        converted = self._convert(msg)
        assert len(converted) == 1
        assert converted[0]["content"] == "Fallback text"
2406  
2407  
2408  # ---------------------------------------------------------------------------
2409  # 6. Text-only sampling callback (full flow)
2410  # ---------------------------------------------------------------------------
2411  
class TestSamplingCallbackText:
    """End-to-end sampling callback flow for text-only LLM responses.

    The LLM is stubbed by patching ``agent.auxiliary_client.call_llm``;
    no real provider is contacted.  (The previous version routed each
    canned response through an otherwise-unused MagicMock, which only
    obscured what was being patched.)
    """

    def setup_method(self):
        self.handler = SamplingHandler("txt", {})

    @staticmethod
    def _patched_call(response):
        # Patch call_llm to return the canned response directly.
        return patch("agent.auxiliary_client.call_llm", return_value=response)

    def test_text_response(self):
        """Full flow: text response returns CreateMessageResult."""
        with self._patched_call(_make_llm_response(content="Hello from LLM")):
            params = _make_sampling_params()
            result = asyncio.run(self.handler(None, params))

        assert isinstance(result, CreateMessageResult)
        assert isinstance(result.content, TextContent)
        assert result.content.text == "Hello from LLM"
        assert result.model == "test-model"
        assert result.role == "assistant"
        assert result.stopReason == "endTurn"

    def test_system_prompt_prepended(self):
        """System prompt is inserted as the first message."""
        with self._patched_call(_make_llm_response()) as mock_call:
            params = _make_sampling_params(system_prompt="Be helpful")
            asyncio.run(self.handler(None, params))

        messages = mock_call.call_args.kwargs["messages"]
        assert messages[0] == {"role": "system", "content": "Be helpful"}

    def test_server_tools_with_object_schema_are_normalized(self):
        """Server-provided tools should gain empty properties for object schemas."""
        server_tool = SimpleNamespace(
            name="ask",
            description="Ask Crawl4AI",
            inputSchema={"type": "object"},
        )

        with self._patched_call(_make_llm_response()) as mock_call:
            params = _make_sampling_params(tools=[server_tool])
            asyncio.run(self.handler(None, params))

        tools = mock_call.call_args.kwargs["tools"]
        assert tools == [{
            "type": "function",
            "function": {
                "name": "ask",
                "description": "Ask Crawl4AI",
                "parameters": {"type": "object", "properties": {}},
            },
        }]

    def test_length_stop_reason(self):
        """finish_reason='length' maps to stopReason='maxTokens'."""
        with self._patched_call(_make_llm_response(finish_reason="length")):
            result = asyncio.run(self.handler(None, _make_sampling_params()))

        assert isinstance(result, CreateMessageResult)
        assert result.stopReason == "maxTokens"
2496  
2497  
2498  # ---------------------------------------------------------------------------
2499  # 7. Tool use sampling callback
2500  # ---------------------------------------------------------------------------
2501  
class TestSamplingCallbackToolUse:
    """Sampling callback flow when the LLM responds with tool_calls.

    call_llm is patched directly with the canned response; the previous
    version built an unused MagicMock just to hold the return value.
    """

    def setup_method(self):
        self.handler = SamplingHandler("tu", {})

    def test_tool_use_response(self):
        """LLM tool_calls response returns CreateMessageResultWithTools."""
        with patch(
            "agent.auxiliary_client.call_llm",
            return_value=_make_llm_tool_response(),
        ):
            result = asyncio.run(self.handler(None, _make_sampling_params()))

        assert isinstance(result, CreateMessageResultWithTools)
        assert result.stopReason == "toolUse"
        assert result.model == "test-model"
        assert len(result.content) == 1
        tc = result.content[0]
        assert isinstance(tc, ToolUseContent)
        assert tc.name == "get_weather"
        assert tc.id == "call_1"
        assert tc.input == {"city": "London"}

    def test_multiple_tool_calls(self):
        """Multiple tool_calls in a single response are all surfaced."""
        response = _make_llm_tool_response(
            tool_calls_data=[
                ("call_a", "func_a", '{"x": 1}'),
                ("call_b", "func_b", '{"y": 2}'),
            ]
        )

        with patch("agent.auxiliary_client.call_llm", return_value=response):
            result = asyncio.run(self.handler(None, _make_sampling_params()))

        assert isinstance(result, CreateMessageResultWithTools)
        assert len(result.content) == 2
        assert result.content[0].name == "func_a"
        assert result.content[1].name == "func_b"
2548  
2549  
2550  # ---------------------------------------------------------------------------
2551  # 8. Tool loop governance
2552  # ---------------------------------------------------------------------------
2553  
class TestToolLoopGovernance:
    """Governance of consecutive tool-use rounds (max_tool_rounds).

    call_llm is patched directly with canned responses; the previous
    version wrapped each one in an unused MagicMock.
    """

    def test_max_tool_rounds_enforcement(self):
        """After max_tool_rounds consecutive tool responses, an error is returned."""
        handler = SamplingHandler("tl", {"max_tool_rounds": 2})

        with patch(
            "agent.auxiliary_client.call_llm",
            return_value=_make_llm_tool_response(),
        ):
            params = _make_sampling_params()
            # Rounds 1 and 2 are within the limit.
            r1 = asyncio.run(handler(None, params))
            assert isinstance(r1, CreateMessageResultWithTools)
            r2 = asyncio.run(handler(None, params))
            assert isinstance(r2, CreateMessageResultWithTools)
            # Round 3 exceeds the limit.
            r3 = asyncio.run(handler(None, params))
            assert isinstance(r3, ErrorData)
            assert "Tool loop limit exceeded" in r3.message

    def test_text_response_resets_counter(self):
        """A text response resets the tool loop counter."""
        handler = SamplingHandler("tl2", {"max_tool_rounds": 1})

        # Hold the "current" response in a one-element list so the
        # side_effect picks up changes between calls.
        responses = [_make_llm_tool_response()]

        with patch(
            "agent.auxiliary_client.call_llm",
            side_effect=lambda **kw: responses[0],
        ):
            # Tool response (round 1 of 1 allowed).
            r1 = asyncio.run(handler(None, _make_sampling_params()))
            assert isinstance(r1, CreateMessageResultWithTools)

            # Text response resets the counter.
            responses[0] = _make_llm_response()
            r2 = asyncio.run(handler(None, _make_sampling_params()))
            assert isinstance(r2, CreateMessageResult)

            # Tool response again; succeeds because the counter was reset.
            responses[0] = _make_llm_tool_response()
            r3 = asyncio.run(handler(None, _make_sampling_params()))
            assert isinstance(r3, CreateMessageResultWithTools)

    def test_max_tool_rounds_zero_disables(self):
        """max_tool_rounds=0 means tool loops are disabled entirely."""
        handler = SamplingHandler("tl3", {"max_tool_rounds": 0})

        with patch(
            "agent.auxiliary_client.call_llm",
            return_value=_make_llm_tool_response(),
        ):
            result = asyncio.run(handler(None, _make_sampling_params()))
            assert isinstance(result, ErrorData)
            assert "Tool loops disabled" in result.message
2615  
2616  
2617  # ---------------------------------------------------------------------------
2618  # 9. Error paths: rate limit, timeout, no provider
2619  # ---------------------------------------------------------------------------
2620  
class TestSamplingErrors:
    """Error paths: rate limiting, timeouts, missing provider, bad responses.

    Changes from the previous version: the dead MagicMock indirection
    around patched return values is gone, the redundant function-local
    ``import threading`` is removed (threading is imported at module
    scope), and the three near-identical "empty response" tests share a
    helper.
    """

    def test_rate_limit_error(self):
        handler = SamplingHandler("rle", {"max_rpm": 1})

        with patch(
            "agent.auxiliary_client.call_llm",
            return_value=_make_llm_response(),
        ):
            # First call succeeds.
            r1 = asyncio.run(handler(None, _make_sampling_params()))
            assert isinstance(r1, CreateMessageResult)
            # Second call is rate limited.
            r2 = asyncio.run(handler(None, _make_sampling_params()))
            assert isinstance(r2, ErrorData)
            assert "rate limit" in r2.message.lower()
            assert handler.metrics["errors"] == 1

    def test_timeout_error(self):
        handler = SamplingHandler("to", {"timeout": 0.05})

        def slow_call(**kwargs):
            # Block for up to 5 seconds; the handler's 0.05s timeout
            # fires long before this returns.
            evt = threading.Event()
            evt.wait(5)
            return _make_llm_response()

        with patch("agent.auxiliary_client.call_llm", side_effect=slow_call):
            result = asyncio.run(handler(None, _make_sampling_params()))
            assert isinstance(result, ErrorData)
            assert "timed out" in result.message.lower()
            assert handler.metrics["errors"] == 1

    def test_no_provider_error(self):
        handler = SamplingHandler("np", {})

        with patch(
            "agent.auxiliary_client.call_llm",
            side_effect=RuntimeError("No LLM provider configured"),
        ):
            result = asyncio.run(handler(None, _make_sampling_params()))
            assert isinstance(result, ErrorData)
            assert handler.metrics["errors"] == 1

    def _assert_empty_response_error(self, handler, response):
        # Shared tail for the malformed-response tests: patch call_llm
        # with `response` and expect a graceful "empty response" error
        # rather than an IndexError/TypeError/AttributeError.
        with patch("agent.auxiliary_client.call_llm", return_value=response):
            result = asyncio.run(handler(None, _make_sampling_params()))

        assert isinstance(result, ErrorData)
        assert "empty response" in result.message.lower()
        assert handler.metrics["errors"] == 1

    def test_empty_choices_returns_error(self):
        """LLM returning choices=[] is handled gracefully, not IndexError."""
        response = SimpleNamespace(
            choices=[],
            model="test-model",
            usage=SimpleNamespace(total_tokens=0),
        )
        self._assert_empty_response_error(SamplingHandler("ec", {}), response)

    def test_none_choices_returns_error(self):
        """LLM returning choices=None is handled gracefully, not TypeError."""
        response = SimpleNamespace(
            choices=None,
            model="test-model",
            usage=SimpleNamespace(total_tokens=0),
        )
        self._assert_empty_response_error(SamplingHandler("nc", {}), response)

    def test_missing_choices_attr_returns_error(self):
        """LLM response without choices attribute is handled gracefully."""
        response = SimpleNamespace(
            model="test-model",
            usage=SimpleNamespace(total_tokens=0),
        )
        self._assert_empty_response_error(SamplingHandler("mc", {}), response)
2727  
2728  
2729  # ---------------------------------------------------------------------------
2730  # 10. Model whitelist
2731  # ---------------------------------------------------------------------------
2732  
class TestModelWhitelist:
    """Enforcement of the allowed_models whitelist.

    The rejected-model test previously patched call_llm with the
    return_value of an unconfigured MagicMock — pure noise, since the
    whitelist check must fire before call_llm is ever reached.  It now
    patches plainly and asserts call_llm was not called.
    """

    def test_allowed_model_passes(self):
        handler = SamplingHandler("wl", {"allowed_models": ["gpt-4o", "test-model"]})

        with patch(
            "agent.auxiliary_client.call_llm",
            return_value=_make_llm_response(),
        ):
            result = asyncio.run(handler(None, _make_sampling_params()))
            assert isinstance(result, CreateMessageResult)

    def test_disallowed_model_rejected(self):
        handler = SamplingHandler("wl2", {"allowed_models": ["gpt-4o"], "model": "test-model"})

        with patch("agent.auxiliary_client.call_llm") as mock_call:
            result = asyncio.run(handler(None, _make_sampling_params()))
            assert isinstance(result, ErrorData)
            assert "not allowed" in result.message
            assert handler.metrics["errors"] == 1
            # The whitelist check short-circuits before any LLM call.
            mock_call.assert_not_called()

    def test_empty_whitelist_allows_all(self):
        handler = SamplingHandler("wl3", {"allowed_models": []})

        with patch(
            "agent.auxiliary_client.call_llm",
            return_value=_make_llm_response(),
        ):
            result = asyncio.run(handler(None, _make_sampling_params()))
            assert isinstance(result, CreateMessageResult)
2770  
2771  
2772  # ---------------------------------------------------------------------------
2773  # 11. Malformed tool_call arguments
2774  # ---------------------------------------------------------------------------
2775  
class TestMalformedToolCallArgs:
    """Handling of tool_call arguments that are not valid JSON strings.

    call_llm is patched directly with the canned response; the dead
    MagicMock indirection of the previous version is removed.
    """

    def test_invalid_json_wrapped_as_raw(self):
        """Malformed JSON arguments get wrapped in {"_raw": ...}."""
        handler = SamplingHandler("mf", {})
        response = _make_llm_tool_response(
            tool_calls_data=[("call_x", "some_tool", "not valid json {{{")]
        )

        with patch("agent.auxiliary_client.call_llm", return_value=response):
            result = asyncio.run(handler(None, _make_sampling_params()))

        assert isinstance(result, CreateMessageResultWithTools)
        tc = result.content[0]
        assert isinstance(tc, ToolUseContent)
        assert tc.input == {"_raw": "not valid json {{{"}

    def test_dict_args_pass_through(self):
        """When arguments are already a dict, they pass through directly."""
        handler = SamplingHandler("mf2", {})

        # Build a response whose tool-call arguments are already a dict.
        tc_obj = SimpleNamespace(
            id="call_d",
            function=SimpleNamespace(name="do_stuff", arguments={"key": "val"}),
        )
        message = SimpleNamespace(content=None, tool_calls=[tc_obj])
        choice = SimpleNamespace(finish_reason="tool_calls", message=message)
        response = SimpleNamespace(
            choices=[choice],
            model="m",
            usage=SimpleNamespace(total_tokens=10),
        )

        with patch("agent.auxiliary_client.call_llm", return_value=response):
            result = asyncio.run(handler(None, _make_sampling_params()))

        assert isinstance(result, CreateMessageResultWithTools)
        assert result.content[0].input == {"key": "val"}
2821  
2822  
2823  # ---------------------------------------------------------------------------
2824  # 12. Metrics tracking
2825  # ---------------------------------------------------------------------------
2826  
class TestMetricsTracking:
    """Handler metrics counters: requests, tokens_used, tool_use_count, errors.

    call_llm is patched directly with the canned response; the dead
    MagicMock indirection of the previous version is removed.
    """

    def test_request_and_token_metrics(self):
        handler = SamplingHandler("met", {})

        with patch(
            "agent.auxiliary_client.call_llm",
            return_value=_make_llm_response(),
        ):
            asyncio.run(handler(None, _make_sampling_params()))

        assert handler.metrics["requests"] == 1
        assert handler.metrics["tokens_used"] == 42
        assert handler.metrics["errors"] == 0

    def test_tool_use_count_metric(self):
        handler = SamplingHandler("met2", {})

        with patch(
            "agent.auxiliary_client.call_llm",
            return_value=_make_llm_tool_response(),
        ):
            asyncio.run(handler(None, _make_sampling_params()))

        assert handler.metrics["tool_use_count"] == 1
        assert handler.metrics["requests"] == 1

    def test_error_metric_incremented(self):
        handler = SamplingHandler("met3", {})

        with patch(
            "agent.auxiliary_client.call_llm",
            side_effect=RuntimeError("No LLM provider configured"),
        ):
            asyncio.run(handler(None, _make_sampling_params()))

        # A failed call counts as an error, not as a completed request.
        assert handler.metrics["errors"] == 1
        assert handler.metrics["requests"] == 0
2868  
2869  
2870  # ---------------------------------------------------------------------------
2871  # 13. session_kwargs()
2872  # ---------------------------------------------------------------------------
2873  
class TestSessionKwargs:
    """Shape of the kwargs SamplingHandler supplies for session construction."""

    def test_returns_correct_keys(self):
        handler = SamplingHandler("sk", {})
        session_args = handler.session_kwargs()
        assert "sampling_callback" in session_args
        assert "sampling_capabilities" in session_args
        # The handler instance itself acts as the sampling callback.
        assert session_args["sampling_callback"] is handler

    def test_sampling_capabilities_type(self):
        handler = SamplingHandler("sk2", {})
        capability = handler.session_kwargs()["sampling_capabilities"]
        assert isinstance(capability, SamplingCapability)
        assert isinstance(capability.tools, SamplingToolsCapability)
2888  
2889  
2890  # ---------------------------------------------------------------------------
2891  # 14. MCPServerTask integration
2892  # ---------------------------------------------------------------------------
2893  
class TestMCPServerTaskSamplingIntegration:
    """Wiring between MCPServerTask and SamplingHandler.

    The enabled/disabled tests previously copy-pasted the same mirror of
    the sampling-setup logic; it now lives in one helper.
    """

    @staticmethod
    def _apply_sampling_setup(server, config):
        # Mirror of the sampling-setup portion of MCPServerTask.run().
        # Calling run() itself would attempt a real connection, so only
        # the setup logic is exercised here.
        from tools.mcp_tool import _MCP_SAMPLING_TYPES

        server._config = config
        sampling_config = config.get("sampling", {})
        if sampling_config.get("enabled", True) and _MCP_SAMPLING_TYPES:
            server._sampling = SamplingHandler(server.name, sampling_config)
        else:
            server._sampling = None

    def test_sampling_handler_created_when_enabled(self):
        """MCPServerTask.run() creates a SamplingHandler when sampling is enabled."""
        from tools.mcp_tool import MCPServerTask

        server = MCPServerTask("int_test")
        self._apply_sampling_setup(server, {
            "command": "fake",
            "sampling": {"enabled": True, "max_rpm": 5},
        })

        assert server._sampling is not None
        assert isinstance(server._sampling, SamplingHandler)
        assert server._sampling.server_name == "int_test"
        assert server._sampling.max_rpm == 5

    def test_sampling_handler_none_when_disabled(self):
        """MCPServerTask._sampling is None when sampling is disabled."""
        from tools.mcp_tool import MCPServerTask

        server = MCPServerTask("int_test2")
        self._apply_sampling_setup(server, {
            "command": "fake",
            "sampling": {"enabled": False},
        })

        assert server._sampling is None

    def test_session_kwargs_used_in_stdio(self):
        """When sampling is set, session_kwargs() are passed to ClientSession."""
        from tools.mcp_tool import MCPServerTask

        server = MCPServerTask("sk_test")
        server._sampling = SamplingHandler("sk_test", {"max_rpm": 7})
        kwargs = server._sampling.session_kwargs()
        assert "sampling_callback" in kwargs
        assert "sampling_capabilities" in kwargs
2946  
2947  
2948  # ---------------------------------------------------------------------------
2949  # Discovery failed_count tracking
2950  # ---------------------------------------------------------------------------
2951  
class TestDiscoveryFailedCount:
    """Verify discover_mcp_tools() correctly tracks failed server connections.

    Cleanup of the module-global ``_servers`` dict runs in ``finally`` blocks
    so that a failing assertion cannot leak fake servers into other tests
    (matching the try/finally convention used elsewhere in this file).
    """

    def test_failed_server_increments_failed_count(self):
        """When _discover_and_register_server raises, failed_count increments."""
        from tools.mcp_tool import discover_mcp_tools, _servers, _ensure_mcp_loop

        fake_config = {
            "good_server": {"command": "npx", "args": ["good"]},
            "bad_server": {"command": "npx", "args": ["bad"]},
        }

        async def fake_register(name, cfg):
            if name == "bad_server":
                raise ConnectionError("Connection refused")
            # Simulate successful registration
            from tools.mcp_tool import MCPServerTask
            server = MCPServerTask(name)
            server.session = MagicMock()
            server._tools = [_make_mcp_tool("tool_a")]
            _servers[name] = server
            return [f"mcp_{name}_tool_a"]

        try:
            with patch("tools.mcp_tool._load_mcp_config", return_value=fake_config), \
                 patch("tools.mcp_tool._discover_and_register_server", side_effect=fake_register), \
                 patch("tools.mcp_tool._MCP_AVAILABLE", True), \
                 patch("tools.mcp_tool._existing_tool_names", return_value=["mcp_good_server_tool_a"]):
                _ensure_mcp_loop()

                # Capture the logger to verify failed_count in summary
                with patch("tools.mcp_tool.logger") as mock_logger:
                    discover_mcp_tools()

                    # Find the summary info call
                    info_calls = [
                        str(call)
                        for call in mock_logger.info.call_args_list
                        if "failed" in str(call).lower() or "MCP:" in str(call)
                    ]
                    # The summary should mention the failure
                    assert any("1 failed" in str(c) for c in info_calls), (
                        f"Summary should report 1 failed server, got: {info_calls}"
                    )
        finally:
            # Must run even when an assertion above fails, or the fake
            # server leaks into the global _servers dict.
            _servers.pop("good_server", None)
            _servers.pop("bad_server", None)

    def test_all_servers_fail_still_prints_summary(self):
        """When all servers fail, a summary with failure count is still printed."""
        from tools.mcp_tool import discover_mcp_tools, _servers, _ensure_mcp_loop

        fake_config = {
            "srv1": {"command": "npx", "args": ["a"]},
            "srv2": {"command": "npx", "args": ["b"]},
        }

        async def always_fail(name, cfg):
            raise ConnectionError(f"Server {name} refused")

        try:
            with patch("tools.mcp_tool._load_mcp_config", return_value=fake_config), \
                 patch("tools.mcp_tool._discover_and_register_server", side_effect=always_fail), \
                 patch("tools.mcp_tool._MCP_AVAILABLE", True), \
                 patch("tools.mcp_tool._existing_tool_names", return_value=[]):
                _ensure_mcp_loop()

                with patch("tools.mcp_tool.logger") as mock_logger:
                    discover_mcp_tools()

                    # Summary must be printed even when all servers fail
                    info_calls = [str(call) for call in mock_logger.info.call_args_list]
                    assert any("2 failed" in str(c) for c in info_calls), (
                        f"Summary should report 2 failed servers, got: {info_calls}"
                    )
        finally:
            _servers.pop("srv1", None)
            _servers.pop("srv2", None)

    def test_ok_servers_excludes_failures(self):
        """ok_servers count correctly excludes failed servers."""
        from tools.mcp_tool import discover_mcp_tools, _servers, _ensure_mcp_loop

        fake_config = {
            "ok1": {"command": "npx", "args": ["ok1"]},
            "ok2": {"command": "npx", "args": ["ok2"]},
            "fail1": {"command": "npx", "args": ["fail"]},
        }

        async def selective_register(name, cfg):
            if name == "fail1":
                raise ConnectionError("Refused")
            from tools.mcp_tool import MCPServerTask
            server = MCPServerTask(name)
            server.session = MagicMock()
            server._tools = [_make_mcp_tool("t")]
            _servers[name] = server
            return [f"mcp_{name}_t"]

        try:
            with patch("tools.mcp_tool._load_mcp_config", return_value=fake_config), \
                 patch("tools.mcp_tool._discover_and_register_server", side_effect=selective_register), \
                 patch("tools.mcp_tool._MCP_AVAILABLE", True), \
                 patch("tools.mcp_tool._existing_tool_names", return_value=["mcp_ok1_t", "mcp_ok2_t"]):
                _ensure_mcp_loop()

                with patch("tools.mcp_tool.logger") as mock_logger:
                    discover_mcp_tools()

                    info_calls = [str(call) for call in mock_logger.info.call_args_list]
                    # Should say "2 server(s)" not "3 server(s)"
                    assert any("2 server" in str(c) for c in info_calls), (
                        f"Summary should report 2 ok servers, got: {info_calls}"
                    )
                    assert any("1 failed" in str(c) for c in info_calls), (
                        f"Summary should report 1 failed, got: {info_calls}"
                    )
        finally:
            _servers.pop("ok1", None)
            _servers.pop("ok2", None)
            _servers.pop("fail1", None)
3070  
3071  
class TestMCPSelectiveToolLoading:
    """Tests for per-server MCP filtering and utility tool policies."""

    def _make_server(self, name, tool_names, session=None):
        """Build a mock MCPServerTask exposing one fake tool per entry in tool_names."""
        server = _make_mock_server(
            name,
            session=session or SimpleNamespace(),
            tools=[_make_mcp_tool(n, n) for n in tool_names],
        )
        return server

    def _run_discover(self, name, tool_names, config, session=None):
        """Run _discover_and_register_server against a fully mocked environment.

        Patches out the real connection, the global registry, and toolset
        creation, then runs discovery for ``name`` with ``config``. Returns
        (registered_tool_names, registry). The global ``_servers`` entry is
        removed in ``finally`` so a failing test does not leak state.
        """
        from tools.registry import ToolRegistry
        from tools.mcp_tool import _discover_and_register_server, _servers

        mock_registry = ToolRegistry()
        server = self._make_server(name, tool_names, session=session)

        async def fake_connect(_name, _config):
            return server

        async def run():
            with patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \
                 patch("tools.registry.registry", mock_registry), \
                 patch("toolsets.create_custom_toolset"):
                return await _discover_and_register_server(name, config)

        try:
            registered = asyncio.run(run())
        finally:
            _servers.pop(name, None)
        return registered, mock_registry

    def test_include_takes_precedence_over_exclude(self):
        # "create_service" appears in both include and exclude; include wins.
        config = {
            "url": "https://mcp.example.com",
            "tools": {
                "include": ["create_service"],
                "exclude": ["create_service", "delete_service"],
            },
        }
        registered, _ = self._run_discover(
            "ink",
            ["create_service", "delete_service", "list_services"],
            config,
            session=SimpleNamespace(),
        )
        assert registered == ["mcp_ink_create_service"]

    def test_exclude_filter_registers_all_except_listed_tools(self):
        config = {
            "url": "https://mcp.example.com",
            "tools": {"exclude": ["delete_service"]},
        }
        registered, _ = self._run_discover(
            "ink_exclude",
            ["create_service", "delete_service", "list_services"],
            config,
            session=SimpleNamespace(),
        )
        assert registered == [
            "mcp_ink_exclude_create_service",
            "mcp_ink_exclude_list_services",
        ]

    def test_include_filter_skips_utility_tools_without_capabilities(self):
        # The bare SimpleNamespace session has no list_resources/list_prompts
        # attributes, so no utility tools should be registered.
        config = {
            "url": "https://mcp.example.com",
            "tools": {"include": ["create_service"]},
        }
        registered, mock_registry = self._run_discover(
            "ink_no_caps",
            ["create_service", "delete_service"],
            config,
            session=SimpleNamespace(),
        )
        assert registered == ["mcp_ink_no_caps_create_service"]
        assert set(mock_registry.get_all_tool_names()) == {"mcp_ink_no_caps_create_service"}

    def test_no_filter_registers_all_server_tools_when_no_utilities_supported(self):
        registered, _ = self._run_discover(
            "ink_no_filter",
            ["create_service", "delete_service", "list_services"],
            {"url": "https://mcp.example.com"},
            session=SimpleNamespace(),
        )
        assert registered == [
            "mcp_ink_no_filter_create_service",
            "mcp_ink_no_filter_delete_service",
            "mcp_ink_no_filter_list_services",
        ]

    def test_resources_and_prompts_can_be_disabled_explicitly(self):
        # Session advertises every utility capability, but config turns both off.
        session = SimpleNamespace(
            list_resources=AsyncMock(),
            read_resource=AsyncMock(),
            list_prompts=AsyncMock(),
            get_prompt=AsyncMock(),
        )
        config = {
            "url": "https://mcp.example.com",
            "tools": {
                "resources": False,
                "prompts": False,
            },
        }
        registered, _ = self._run_discover(
            "ink_disabled_utils",
            ["create_service"],
            config,
            session=session,
        )
        assert registered == ["mcp_ink_disabled_utils_create_service"]

    def test_registers_only_utility_tools_supported_by_server_capabilities(self):
        # Session supports resources but not prompts; only resource utilities
        # should be registered.
        session = SimpleNamespace(
            list_resources=AsyncMock(return_value=SimpleNamespace(resources=[])),
            read_resource=AsyncMock(return_value=SimpleNamespace(contents=[])),
        )
        registered, _ = self._run_discover(
            "ink_resources_only",
            ["create_service"],
            {"url": "https://mcp.example.com"},
            session=session,
        )
        assert "mcp_ink_resources_only_create_service" in registered
        assert "mcp_ink_resources_only_list_resources" in registered
        assert "mcp_ink_resources_only_read_resource" in registered
        assert "mcp_ink_resources_only_list_prompts" not in registered
        assert "mcp_ink_resources_only_get_prompt" not in registered

    def test_existing_tool_names_reflect_registered_subset(self):
        """_existing_tool_names() reports only the tools that survived filtering."""
        from tools.mcp_tool import _existing_tool_names, _servers, _discover_and_register_server
        from tools.registry import ToolRegistry

        mock_registry = ToolRegistry()
        server = self._make_server(
            "ink_existing",
            ["create_service", "delete_service"],
            session=SimpleNamespace(),
        )

        async def fake_connect(_name, _config):
            return server

        async def run():
            # _existing_tool_names() must be called inside the patched
            # _servers dict, hence it is returned from the coroutine.
            with patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \
                 patch.dict("tools.mcp_tool._servers", {}, clear=True), \
                 patch("tools.registry.registry", mock_registry), \
                 patch("toolsets.create_custom_toolset"):
                registered = await _discover_and_register_server(
                    "ink_existing",
                    {"url": "https://mcp.example.com", "tools": {"include": ["create_service"]}},
                )
                return registered, _existing_tool_names()

        try:
            registered, existing = asyncio.run(run())
            assert registered == ["mcp_ink_existing_create_service"]
            assert existing == ["mcp_ink_existing_create_service"]
        finally:
            _servers.pop("ink_existing", None)

    def test_no_toolset_created_when_everything_is_filtered_out(self):
        """If filters leave zero tools, no toolset is created and nothing registers."""
        from tools.registry import ToolRegistry
        from tools.mcp_tool import _discover_and_register_server, _servers

        mock_registry = ToolRegistry()
        server = self._make_server("ink_none", ["create_service"], session=SimpleNamespace())
        mock_create = MagicMock()

        async def fake_connect(_name, _config):
            return server

        async def run():
            with patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \
                 patch("tools.registry.registry", mock_registry), \
                 patch("toolsets.create_custom_toolset", mock_create):
                return await _discover_and_register_server(
                    "ink_none",
                    {
                        "url": "https://mcp.example.com",
                        "tools": {
                            # "missing_tool" matches none of the server's tools.
                            "include": ["missing_tool"],
                            "resources": False,
                            "prompts": False,
                        },
                    },
                )

        try:
            registered = asyncio.run(run())
            assert registered == []
            mock_create.assert_not_called()
            assert mock_registry.get_all_tool_names() == []
        finally:
            _servers.pop("ink_none", None)

    def test_enabled_false_skips_connection_attempt(self):
        """A server with enabled=False must never reach _connect_server."""
        from tools.mcp_tool import discover_mcp_tools

        connect_called = []

        async def fake_connect(name, config):
            connect_called.append(name)
            return self._make_server(name, ["create_service"])

        fake_config = {
            "ink": {
                "url": "https://mcp.example.com",
                "enabled": False,
            }
        }
        fake_toolsets = {
            "hermes-cli": {"tools": [], "description": "CLI", "includes": []},
        }

        with patch("tools.mcp_tool._MCP_AVAILABLE", True), \
             patch("tools.mcp_tool._servers", {}), \
             patch("tools.mcp_tool._load_mcp_config", return_value=fake_config), \
             patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \
             patch("toolsets.TOOLSETS", fake_toolsets):
            result = discover_mcp_tools()

        assert connect_called == []
        assert result == []
3298  
3299  
3300  # ---------------------------------------------------------------------------
3301  # Tool name collision protection
3302  # ---------------------------------------------------------------------------
3303  
class TestRegistryCollisionWarning:
    """registry.register() warns when a tool name is overwritten by a different toolset."""

    @staticmethod
    def _registry_with_tool(toolset):
        """Return (registry, schema, handler) with 'my_tool' pre-registered under toolset."""
        from tools.registry import ToolRegistry

        reg = ToolRegistry()
        schema = {"name": "my_tool", "description": "test", "parameters": {"type": "object", "properties": {}}}
        handler = lambda args, **kw: "{}"
        reg.register(name="my_tool", toolset=toolset, schema=schema, handler=handler)
        return reg, schema, handler

    def test_overwrite_different_toolset_logs_warning(self, caplog):
        """Overwriting a tool from a different toolset is REJECTED with an error."""
        import logging

        reg, schema, handler = self._registry_with_tool("builtin")

        with caplog.at_level(logging.ERROR, logger="tools.registry"):
            reg.register(name="my_tool", toolset="mcp-ext", schema=schema, handler=handler)

        messages = [r.message for r in caplog.records]
        assert any("rejected" in m.lower() for m in messages)
        assert any("builtin" in m and "mcp-ext" in m for m in messages)
        # The original tool should still be from 'builtin', not overwritten
        assert reg.get_toolset_for_tool("my_tool") == "builtin"

    def test_overwrite_same_toolset_no_warning(self, caplog):
        """Re-registering within the same toolset is silent (e.g. reconnect)."""
        import logging

        reg, schema, handler = self._registry_with_tool("mcp-server")

        with caplog.at_level(logging.WARNING, logger="tools.registry"):
            reg.register(name="my_tool", toolset="mcp-server", schema=schema, handler=handler)

        assert not any("collision" in r.message.lower() for r in caplog.records)
3341  
3342  
class TestMCPBuiltinCollisionGuard:
    """MCP tools that collide with built-in tool names are skipped.

    Each test removes its fake server from the global ``_servers`` dict in a
    ``finally`` block so a failing assertion cannot leak state into other
    tests (matching the try/finally convention used elsewhere in this file).
    """

    def test_mcp_tool_skipped_when_builtin_exists(self):
        """An MCP tool whose prefixed name collides with a built-in is skipped."""
        from tools.registry import ToolRegistry
        from tools.mcp_tool import _discover_and_register_server, _servers, MCPServerTask

        mock_registry = ToolRegistry()

        # Pre-register a "built-in" tool with the name that the MCP tool would produce.
        # Server "abc", tool "search" → mcp_abc_search
        builtin_schema = {
            "name": "mcp_abc_search",
            "description": "A hypothetical built-in",
            "parameters": {"type": "object", "properties": {}},
        }
        mock_registry.register(
            name="mcp_abc_search", toolset="web",
            schema=builtin_schema, handler=lambda a, **k: "{}",
        )

        mock_tools = [_make_mcp_tool("search", "Search the web")]
        mock_session = MagicMock()

        async def fake_connect(name, config):
            server = MCPServerTask(name)
            server.session = mock_session
            server._tools = mock_tools
            return server

        try:
            with patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \
                 patch("tools.registry.registry", mock_registry):
                registered = asyncio.run(
                    _discover_and_register_server("abc", {"command": "test", "args": []})
                )

            # The MCP tool should have been skipped — built-in preserved.
            assert "mcp_abc_search" not in registered
            assert mock_registry.get_toolset_for_tool("mcp_abc_search") == "web"
        finally:
            # Must run even when an assertion fails, or the fake server leaks.
            _servers.pop("abc", None)

    def test_mcp_tool_registered_when_no_builtin_collision(self):
        """MCP tools register normally when there's no collision."""
        from tools.registry import ToolRegistry
        from tools.mcp_tool import _discover_and_register_server, _servers, MCPServerTask

        mock_registry = ToolRegistry()
        mock_tools = [_make_mcp_tool("web_search", "Search the web")]
        mock_session = MagicMock()

        async def fake_connect(name, config):
            server = MCPServerTask(name)
            server.session = mock_session
            server._tools = mock_tools
            return server

        try:
            with patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \
                 patch("tools.registry.registry", mock_registry):
                registered = asyncio.run(
                    _discover_and_register_server("minimax", {"command": "test", "args": []})
                )

            assert "mcp_minimax_web_search" in registered
            assert mock_registry.get_toolset_for_tool("mcp_minimax_web_search") == "mcp-minimax"
        finally:
            _servers.pop("minimax", None)

    def test_mcp_tool_allowed_when_collision_is_another_mcp(self):
        """Collision between two MCP toolsets is allowed (last wins)."""
        from tools.registry import ToolRegistry
        from tools.mcp_tool import _discover_and_register_server, _servers, MCPServerTask

        mock_registry = ToolRegistry()

        # Pre-register an MCP tool from a different server.
        mcp_schema = {
            "name": "mcp_srv_do_thing",
            "description": "From another MCP server",
            "parameters": {"type": "object", "properties": {}},
        }
        mock_registry.register(
            name="mcp_srv_do_thing", toolset="mcp-old",
            schema=mcp_schema, handler=lambda a, **k: "{}",
        )

        mock_tools = [_make_mcp_tool("do_thing", "Do a thing")]
        mock_session = MagicMock()

        async def fake_connect(name, config):
            server = MCPServerTask(name)
            server.session = mock_session
            server._tools = mock_tools
            return server

        try:
            with patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \
                 patch("tools.registry.registry", mock_registry):
                registered = asyncio.run(
                    _discover_and_register_server("srv", {"command": "test", "args": []})
                )

            # MCP-to-MCP collision is allowed — the new server wins.
            assert "mcp_srv_do_thing" in registered
            assert mock_registry.get_toolset_for_tool("mcp_srv_do_thing") == "mcp-srv"
        finally:
            _servers.pop("srv", None)
3450  
3451  
3452  # ---------------------------------------------------------------------------
3453  # sanitize_mcp_name_component
3454  # ---------------------------------------------------------------------------
3455  
3456  
class TestSanitizeMcpNameComponent:
    """Verify sanitize_mcp_name_component handles all edge cases."""

    def test_hyphens_replaced(self):
        from tools.mcp_tool import sanitize_mcp_name_component as sanitize
        assert sanitize("my-server") == "my_server"

    def test_dots_replaced(self):
        from tools.mcp_tool import sanitize_mcp_name_component as sanitize
        assert sanitize("ai.exa") == "ai_exa"

    def test_slashes_replaced(self):
        from tools.mcp_tool import sanitize_mcp_name_component as sanitize
        assert sanitize("ai.exa/exa") == "ai_exa_exa"

    def test_mixed_special_characters(self):
        from tools.mcp_tool import sanitize_mcp_name_component as sanitize
        # '@', '/', '-', and '.' each map to an underscore.
        assert sanitize("@scope/my-pkg.v2") == "_scope_my_pkg_v2"

    def test_alphanumeric_and_underscores_preserved(self):
        from tools.mcp_tool import sanitize_mcp_name_component as sanitize
        # Already-valid names pass through unchanged.
        assert sanitize("my_server_123") == "my_server_123"

    def test_empty_string(self):
        from tools.mcp_tool import sanitize_mcp_name_component as sanitize
        assert sanitize("") == ""

    def test_none_returns_empty(self):
        from tools.mcp_tool import sanitize_mcp_name_component as sanitize
        assert sanitize(None) == ""

    def test_slash_in_convert_mcp_schema(self):
        """Server names with slashes produce valid tool names via _convert_mcp_schema."""
        import re
        from tools.mcp_tool import _convert_mcp_schema

        schema = _convert_mcp_schema("ai.exa/exa", _make_mcp_tool(name="search"))
        assert schema["name"] == "mcp_ai_exa_exa_search"
        # Must match Anthropic's pattern: ^[a-zA-Z0-9_-]{1,128}$
        assert re.match(r"^[a-zA-Z0-9_-]{1,128}$", schema["name"])

    def test_slash_in_build_utility_schemas(self):
        """Server names with slashes produce valid utility tool names."""
        from tools.mcp_tool import _build_utility_schemas

        for entry in _build_utility_schemas("ai.exa/exa"):
            tool_name = entry["schema"]["name"]
            assert "/" not in tool_name
            assert "." not in tool_name

    def test_slash_in_server_alias_resolution(self):
        """Server names with slashes resolve through their live MCP alias."""
        from tools.registry import ToolRegistry
        from toolsets import resolve_toolset, validate_toolset

        registry = ToolRegistry()
        registry.register(
            name="mcp_ai_exa_exa_search",
            toolset="mcp-ai.exa/exa",
            schema={"name": "mcp_ai_exa_exa_search", "description": "Search", "parameters": {"type": "object", "properties": {}}},
            handler=lambda *_args, **_kwargs: "{}",
        )
        registry.register_toolset_alias("ai.exa/exa", "mcp-ai.exa/exa")

        with patch("tools.registry.registry", registry):
            assert validate_toolset("ai.exa/exa") is True
            assert "mcp_ai_exa_exa_search" in resolve_toolset("ai.exa/exa")
3526  
3527  
3528  # ---------------------------------------------------------------------------
3529  # register_mcp_servers public API
3530  # ---------------------------------------------------------------------------
3531  
3532  
class TestRegisterMcpServers:
    """Verify the new register_mcp_servers() public API.

    Tests that put fake servers into the global ``_servers`` dict clean them
    up in ``finally`` blocks so a failing assertion cannot leak state into
    other tests.
    """

    def test_empty_servers_returns_empty(self):
        """An empty server mapping registers nothing."""
        from tools.mcp_tool import register_mcp_servers

        with patch("tools.mcp_tool._MCP_AVAILABLE", True):
            result = register_mcp_servers({})
        assert result == []

    def test_mcp_not_available_returns_empty(self):
        """Without the MCP SDK available, registration is a no-op."""
        from tools.mcp_tool import register_mcp_servers

        with patch("tools.mcp_tool._MCP_AVAILABLE", False):
            result = register_mcp_servers({"srv": {"command": "test"}})
        assert result == []

    def test_skips_already_connected_servers(self):
        """Servers already present in _servers are not reconnected."""
        from tools.mcp_tool import register_mcp_servers, _servers

        mock_server = _make_mock_server("existing")
        _servers["existing"] = mock_server

        try:
            with patch("tools.mcp_tool._MCP_AVAILABLE", True), \
                 patch("tools.mcp_tool._existing_tool_names", return_value=["mcp_existing_tool"]):
                result = register_mcp_servers({"existing": {"command": "test"}})
            assert result == ["mcp_existing_tool"]
        finally:
            _servers.pop("existing", None)

    def test_skips_disabled_servers(self):
        """Servers with enabled=False are never connected."""
        from tools.mcp_tool import register_mcp_servers, _servers

        try:
            with patch("tools.mcp_tool._MCP_AVAILABLE", True), \
                 patch("tools.mcp_tool._existing_tool_names", return_value=[]):
                result = register_mcp_servers({"srv": {"command": "test", "enabled": False}})
            assert result == []
        finally:
            _servers.pop("srv", None)

    def test_connects_new_servers(self):
        """New servers are connected and their registered tools returned."""
        from tools.mcp_tool import register_mcp_servers, _servers, _ensure_mcp_loop

        fake_config = {"my_server": {"command": "npx", "args": ["test"]}}

        async def fake_register(name, cfg):
            server = _make_mock_server(name)
            server._registered_tool_names = ["mcp_my_server_tool1"]
            _servers[name] = server
            return ["mcp_my_server_tool1"]

        try:
            with patch("tools.mcp_tool._MCP_AVAILABLE", True), \
                 patch("tools.mcp_tool._discover_and_register_server", side_effect=fake_register), \
                 patch("tools.mcp_tool._existing_tool_names", return_value=["mcp_my_server_tool1"]):
                _ensure_mcp_loop()
                result = register_mcp_servers(fake_config)

            assert "mcp_my_server_tool1" in result
        finally:
            # Must run even when an assertion fails, or the fake server leaks.
            _servers.pop("my_server", None)

    def test_logs_summary_on_success(self):
        """A summary line reporting tool and server counts is logged."""
        from tools.mcp_tool import register_mcp_servers, _servers, _ensure_mcp_loop

        fake_config = {"srv": {"command": "npx", "args": ["test"]}}

        async def fake_register(name, cfg):
            server = _make_mock_server(name)
            server._registered_tool_names = ["mcp_srv_t1", "mcp_srv_t2"]
            _servers[name] = server
            return ["mcp_srv_t1", "mcp_srv_t2"]

        try:
            with patch("tools.mcp_tool._MCP_AVAILABLE", True), \
                 patch("tools.mcp_tool._discover_and_register_server", side_effect=fake_register), \
                 patch("tools.mcp_tool._existing_tool_names", return_value=["mcp_srv_t1", "mcp_srv_t2"]):
                _ensure_mcp_loop()

                with patch("tools.mcp_tool.logger") as mock_logger:
                    register_mcp_servers(fake_config)

                    info_calls = [str(c) for c in mock_logger.info.call_args_list]
                    assert any("2 tool(s)" in c and "1 server(s)" in c for c in info_calls), (
                        f"Summary should report 2 tools from 1 server, got: {info_calls}"
                    )
        finally:
            _servers.pop("srv", None)