# test_mcp_tool.py
"""Tests for the MCP (Model Context Protocol) client support.

All tests use mocks -- no real MCP servers or subprocesses are started.
"""

import asyncio
import json
import os
import threading
import time
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch

import pytest


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _make_mcp_tool(name="read_file", description="Read a file", input_schema=None):
    """Create a fake MCP Tool object matching the SDK interface."""
    tool = SimpleNamespace()
    tool.name = name
    tool.description = description
    tool.inputSchema = input_schema or {
        "type": "object",
        "properties": {
            "path": {"type": "string", "description": "File path"},
        },
        "required": ["path"],
    }
    return tool


def _make_call_result(text="file contents here", is_error=False):
    """Create a fake MCP CallToolResult."""
    block = SimpleNamespace(text=text)
    return SimpleNamespace(content=[block], isError=is_error)


def _make_mock_server(name, session=None, tools=None):
    """Create an MCPServerTask with mock attributes for testing."""
    from tools.mcp_tool import MCPServerTask
    server = MCPServerTask(name)
    server.session = session
    server._tools = tools or []
    return server


# ---------------------------------------------------------------------------
# Config loading
# ---------------------------------------------------------------------------

class TestLoadMCPConfig:
    def test_no_config_returns_empty(self):
        """No mcp_servers key in config -> empty dict."""
        with patch("hermes_cli.config.load_config", return_value={"model": "test"}):
            from tools.mcp_tool import _load_mcp_config
            result = _load_mcp_config()
            assert result == {}

    def test_valid_config_parsed(self):
        """Valid mcp_servers config is returned as-is."""
        servers = {
            "filesystem": {
                "command": "npx",
                "args": ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"],
                "env": {},
            }
        }
        with patch("hermes_cli.config.load_config", return_value={"mcp_servers": servers}):
            from tools.mcp_tool import _load_mcp_config
            result = _load_mcp_config()
            assert "filesystem" in result
            assert result["filesystem"]["command"] == "npx"

    def test_mcp_servers_not_dict_returns_empty(self):
        """mcp_servers set to non-dict value -> empty dict."""
        with patch("hermes_cli.config.load_config", return_value={"mcp_servers": "invalid"}):
            from tools.mcp_tool import _load_mcp_config
            result = _load_mcp_config()
            assert result == {}


# ---------------------------------------------------------------------------
# Schema conversion
# ---------------------------------------------------------------------------

class TestSchemaConversion:
    def test_converts_mcp_tool_to_hermes_schema(self):
        from tools.mcp_tool import _convert_mcp_schema

        mcp_tool = _make_mcp_tool(name="read_file", description="Read a file")
        schema = _convert_mcp_schema("filesystem", mcp_tool)

        assert schema["name"] == "mcp_filesystem_read_file"
        assert schema["description"] == "Read a file"
        assert "properties" in schema["parameters"]

    def test_empty_input_schema_gets_default(self):
        from tools.mcp_tool import _convert_mcp_schema

        mcp_tool = _make_mcp_tool(name="ping", description="Ping", input_schema=None)
        mcp_tool.inputSchema = None
        schema = _convert_mcp_schema("test", mcp_tool)

        assert schema["parameters"]["type"] == "object"
        assert schema["parameters"]["properties"] == {}

    def test_object_schema_without_properties_gets_normalized(self):
        from tools.mcp_tool import _convert_mcp_schema

        mcp_tool = _make_mcp_tool(
            name="ask",
            description="Ask Crawl4AI",
            input_schema={"type": "object"},
        )
        schema = _convert_mcp_schema("crawl4ai", mcp_tool)

        assert schema["parameters"] == {"type": "object", "properties": {}}

    def test_definitions_refs_are_rewritten_to_defs(self):
        from tools.mcp_tool import _convert_mcp_schema

        mcp_tool = _make_mcp_tool(
            name="submit",
            description="Submit a payload",
            input_schema={
                "type": "object",
                "properties": {
                    "input": {"$ref": "#/definitions/Payload"},
                },
                "required": ["input"],
                "definitions": {
                    "Payload": {
                        "type": "object",
                        "properties": {
                            "query": {"type": "string"},
                        },
                        "required": ["query"],
                    }
                },
            },
        )

        schema = _convert_mcp_schema("forms", mcp_tool)

        assert schema["parameters"]["properties"]["input"]["$ref"] == "#/$defs/Payload"
        assert "$defs" in schema["parameters"]
        assert "definitions" not in schema["parameters"]

    def test_nested_definition_refs_are_rewritten_recursively(self):
        from tools.mcp_tool import _convert_mcp_schema

        mcp_tool = _make_mcp_tool(
            name="nested",
            description="Nested schema",
            input_schema={
                "type": "object",
                "properties": {
                    "items": {
                        "type": "array",
                        "items": {"$ref": "#/definitions/Entry"},
                    },
                },
                "definitions": {
                    "Entry": {
                        "type": "object",
                        "properties": {
                            "child": {"$ref": "#/definitions/Child"},
                        },
                    },
                    "Child": {
                        "type": "object",
                        "properties": {
                            "value": {"type": "string"},
                        },
                    },
                },
            },
        )

        schema = _convert_mcp_schema("forms", mcp_tool)

        assert schema["parameters"]["properties"]["items"]["items"]["$ref"] == "#/$defs/Entry"
        assert schema["parameters"]["$defs"]["Entry"]["properties"]["child"]["$ref"] == "#/$defs/Child"

    def test_missing_type_on_object_is_coerced(self):
        """Schemas that describe an object but omit ``type`` get type='object'."""
        from tools.mcp_tool import _normalize_mcp_input_schema

        schema = _normalize_mcp_input_schema({
            "properties": {"q": {"type": "string"}},
            "required": ["q"],
        })

        assert schema["type"] == "object"
        assert schema["properties"]["q"]["type"] == "string"
        assert schema["required"] == ["q"]

    def test_null_type_on_object_is_coerced(self):
        """type: None should be treated like missing type (common MCP server bug)."""
        from tools.mcp_tool import _normalize_mcp_input_schema

        schema = _normalize_mcp_input_schema({
            "type": None,
            "properties": {"x": {"type": "integer"}},
        })

        assert schema["type"] == "object"

    def test_required_pruned_when_property_missing(self):
        """Gemini 400s on required names that don't exist in properties."""
        from tools.mcp_tool import _normalize_mcp_input_schema

        schema = _normalize_mcp_input_schema({
            "type": "object",
            "properties": {"a": {"type": "string"}},
            "required": ["a", "ghost", "phantom"],
        })

        assert schema["required"] == ["a"]

    def test_required_removed_when_all_names_dangle(self):
        from tools.mcp_tool import _normalize_mcp_input_schema

        schema = _normalize_mcp_input_schema({
            "type": "object",
            "properties": {},
            "required": ["ghost"],
        })

        assert "required" not in schema

    def test_required_pruning_applies_recursively_inside_nested_objects(self):
        """Nested object schemas also get required pruning."""
        from tools.mcp_tool import _normalize_mcp_input_schema

        schema = _normalize_mcp_input_schema({
            "type": "object",
            "properties": {
                "filter": {
                    "type": "object",
                    "properties": {"field": {"type": "string"}},
                    "required": ["field", "missing"],
                },
            },
        })

        assert schema["properties"]["filter"]["required"] == ["field"]

    def test_object_in_array_items_gets_properties_filled(self):
        """Array-item object schemas without properties get an empty dict."""
        from tools.mcp_tool import _normalize_mcp_input_schema

        schema = _normalize_mcp_input_schema({
            "type": "object",
            "properties": {
                "items": {
                    "type": "array",
                    "items": {"type": "object"},
                },
            },
        })

        assert schema["properties"]["items"]["items"]["properties"] == {}

    def test_optional_nullable_field_is_collapsed_to_non_null_schema(self):
        """Anthropic rejects MCP/Pydantic anyOf-null optional parameter schemas."""
        from tools.mcp_tool import _normalize_mcp_input_schema

        schema = _normalize_mcp_input_schema({
            "type": "object",
            "properties": {
                "command": {"type": "string"},
                "workdir": {
                    "anyOf": [{"type": "string"}, {"type": "null"}],
                    "default": None,
                    "description": "Optional working directory",
                },
            },
            "required": ["command"],
        })

        assert schema["properties"]["workdir"] == {
            "type": "string",
            "nullable": True,
            "default": None,
            "description": "Optional working directory",
        }
        assert schema["required"] == ["command"]

    def test_nested_nullable_array_items_are_collapsed(self):
        from tools.mcp_tool import _normalize_mcp_input_schema

        schema = _normalize_mcp_input_schema({
            "type": "object",
            "properties": {
                "filters": {
                    "type": "array",
                    "items": {
                        "oneOf": [
                            {
                                "type": "object",
                                "properties": {"field": {"type": "string"}},
                            },
                            {"type": "null"},
                        ]
                    },
                }
            },
        })

        assert schema["properties"]["filters"]["items"] == {
            "type": "object",
            "properties": {"field": {"type": "string"}},
            "nullable": True,
        }

    def test_convert_mcp_schema_survives_missing_inputschema_attribute(self):
        """A Tool object without .inputSchema must not crash registration."""
        import types

        from tools.mcp_tool import _convert_mcp_schema

        bare_tool = types.SimpleNamespace(name="probe", description="Probe")
        schema = _convert_mcp_schema("srv", bare_tool)

        assert schema["name"] == "mcp_srv_probe"
        assert schema["parameters"] == {"type": "object", "properties": {}}

    def test_convert_mcp_schema_with_none_inputschema(self):
        """Tool with inputSchema=None produces a valid empty object schema."""
        import types

        from tools.mcp_tool import _convert_mcp_schema

        # Note: _make_mcp_tool(input_schema=None) falls back to a default —
        # build the namespace directly so .inputSchema really is None.
        mcp_tool = types.SimpleNamespace(name="probe", description="Probe", inputSchema=None)
        schema = _convert_mcp_schema("srv", mcp_tool)

        assert schema["parameters"] == {"type": "object", "properties": {}}

    def test_tool_name_prefix_format(self):
        from tools.mcp_tool import _convert_mcp_schema

        mcp_tool = _make_mcp_tool(name="list_dir")
        schema = _convert_mcp_schema("my_server", mcp_tool)

        assert schema["name"] == "mcp_my_server_list_dir"

    def test_hyphens_sanitized_to_underscores(self):
        """Hyphens in tool/server names are replaced with underscores for LLM compat."""
        from tools.mcp_tool import _convert_mcp_schema

        mcp_tool = _make_mcp_tool(name="get-sum")
        schema = _convert_mcp_schema("my-server", mcp_tool)

        assert schema["name"] == "mcp_my_server_get_sum"
        assert "-" not in schema["name"]


# ---------------------------------------------------------------------------
# Check function
# ---------------------------------------------------------------------------

class TestCheckFunction:
    def test_disconnected_returns_false(self):
        from tools.mcp_tool import _make_check_fn, _servers

        _servers.pop("test_server", None)
        check = _make_check_fn("test_server")
        assert check() is False

    def test_connected_returns_true(self):
        from tools.mcp_tool import _make_check_fn, _servers

        server = _make_mock_server("test_server", session=MagicMock())
        _servers["test_server"] = server
        try:
            check = _make_check_fn("test_server")
            assert check() is True
        finally:
            _servers.pop("test_server", None)

    def test_session_none_returns_false(self):
        from tools.mcp_tool import _make_check_fn, _servers

        server = _make_mock_server("test_server", session=None)
        _servers["test_server"] = server
        try:
            check = _make_check_fn("test_server")
            assert check() is False
        finally:
            _servers.pop("test_server", None)


# ---------------------------------------------------------------------------
# Tool handler
# ---------------------------------------------------------------------------

class TestToolHandler:
    """Tool handlers are sync functions that schedule work on the MCP loop."""

    def _patch_mcp_loop(self, coro_side_effect=None):
        """Return a patch for _run_on_mcp_loop that runs the coroutine directly."""
        def fake_run(coro, timeout=30):
            return asyncio.run(coro)
        if coro_side_effect:
            return patch("tools.mcp_tool._run_on_mcp_loop", side_effect=coro_side_effect)
        return patch("tools.mcp_tool._run_on_mcp_loop", side_effect=fake_run)

    def test_successful_call(self):
        from tools.mcp_tool import _make_tool_handler, _servers

        mock_session = MagicMock()
        mock_session.call_tool = AsyncMock(
            return_value=_make_call_result("hello world", is_error=False)
        )
        server = _make_mock_server("test_srv", session=mock_session)
        _servers["test_srv"] = server

        try:
            handler = _make_tool_handler("test_srv", "greet", 120)
            with self._patch_mcp_loop():
                result = json.loads(handler({"name": "world"}))
            assert result["result"] == "hello world"
            mock_session.call_tool.assert_called_once_with("greet", arguments={"name": "world"})
        finally:
            _servers.pop("test_srv", None)

    def test_mcp_error_result(self):
        from tools.mcp_tool import _make_tool_handler, _servers

        mock_session = MagicMock()
        mock_session.call_tool = AsyncMock(
            return_value=_make_call_result("something went wrong", is_error=True)
        )
        server = _make_mock_server("test_srv", session=mock_session)
        _servers["test_srv"] = server

        try:
            handler = _make_tool_handler("test_srv", "fail_tool", 120)
            with self._patch_mcp_loop():
                result = json.loads(handler({}))
            assert "error" in result
            assert "something went wrong" in result["error"]
        finally:
            _servers.pop("test_srv", None)

    def test_disconnected_server(self):
        from tools.mcp_tool import _make_tool_handler, _servers

        _servers.pop("ghost", None)
        handler = _make_tool_handler("ghost", "any_tool", 120)
        result = json.loads(handler({}))
        assert "error" in result
        assert "not connected" in result["error"]

    def test_exception_during_call(self):
        from tools.mcp_tool import _make_tool_handler, _servers

        mock_session = MagicMock()
        mock_session.call_tool = AsyncMock(side_effect=RuntimeError("connection lost"))
        server = _make_mock_server("test_srv", session=mock_session)
        _servers["test_srv"] = server

        try:
            handler = _make_tool_handler("test_srv", "broken_tool", 120)
            with self._patch_mcp_loop():
                result = json.loads(handler({}))
            assert "error" in result
            assert "connection lost" in result["error"]
        finally:
            _servers.pop("test_srv", None)

    def test_interrupted_call_returns_interrupted_error(self):
        from tools.mcp_tool import _make_tool_handler, _servers

        mock_session = MagicMock()
        server = _make_mock_server("test_srv", session=mock_session)
        _servers["test_srv"] = server

        try:
            handler = _make_tool_handler("test_srv", "greet", 120)
            def _interrupting_run(coro, timeout=30):
                coro.close()
                raise InterruptedError("User sent a new message")
            with patch(
                "tools.mcp_tool._run_on_mcp_loop",
                side_effect=_interrupting_run,
            ):
                result = json.loads(handler({}))
            assert result == {"error": "MCP call interrupted: user sent a new message"}
        finally:
            _servers.pop("test_srv", None)


class TestRunOnMCPLoopInterrupts:
    def test_interrupt_cancels_waiting_mcp_call(self):
        import tools.mcp_tool as mcp_mod
        from tools.interrupt import set_interrupt

        loop = asyncio.new_event_loop()
        thread = threading.Thread(target=loop.run_forever, daemon=True)
        thread.start()

        cancelled = threading.Event()

        async def _slow_call():
            try:
                await asyncio.sleep(5)
                return "done"
            except asyncio.CancelledError:
                cancelled.set()
                raise

        old_loop = mcp_mod._mcp_loop
        old_thread = mcp_mod._mcp_thread
        mcp_mod._mcp_loop = loop
        mcp_mod._mcp_thread = thread

        waiter_tid = threading.current_thread().ident

        def _interrupt_soon():
            time.sleep(0.2)
            set_interrupt(True, waiter_tid)

        interrupter = threading.Thread(target=_interrupt_soon, daemon=True)
        interrupter.start()

        try:
            with pytest.raises(InterruptedError, match="User sent a new message"):
                mcp_mod._run_on_mcp_loop(_slow_call(), timeout=2)

            deadline = time.time() + 2
            while time.time() < deadline and not cancelled.is_set():
                time.sleep(0.05)
            assert cancelled.is_set()
        finally:
            set_interrupt(False, waiter_tid)
            loop.call_soon_threadsafe(loop.stop)
            thread.join(timeout=2)
            loop.close()
            mcp_mod._mcp_loop = old_loop
            mcp_mod._mcp_thread = old_thread


# ---------------------------------------------------------------------------
# Tool registration (discovery + register)
# ---------------------------------------------------------------------------

class TestDiscoverAndRegister:
    def test_tools_registered_in_registry(self):
        """_discover_and_register_server registers tools with correct names."""
        from tools.registry import ToolRegistry
        from tools.mcp_tool import _discover_and_register_server, _servers, MCPServerTask

        mock_registry = ToolRegistry()
        mock_tools = [
            _make_mcp_tool("read_file", "Read a file"),
            _make_mcp_tool("write_file", "Write a file"),
        ]
        mock_session = MagicMock()

        async def fake_connect(name, config):
            server = MCPServerTask(name)
            server.session = mock_session
            server._tools = mock_tools
            return server

        with patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \
             patch("tools.registry.registry", mock_registry):
            registered = asyncio.run(
                _discover_and_register_server("fs", {"command": "npx", "args": []})
            )

            assert "mcp_fs_read_file" in registered
            assert "mcp_fs_write_file" in registered
            assert "mcp_fs_read_file" in mock_registry.get_all_tool_names()
            assert "mcp_fs_write_file" in mock_registry.get_all_tool_names()

            _servers.pop("fs", None)

    def test_toolset_resolves_live_from_registry(self):
        """MCP toolsets resolve through the live registry without TOOLSETS mutation."""
        from tools.registry import ToolRegistry
        from tools.mcp_tool import _discover_and_register_server, _servers, MCPServerTask
        from toolsets import resolve_toolset, validate_toolset

        mock_registry = ToolRegistry()
        mock_tools = [_make_mcp_tool("ping", "Ping")]
        mock_session = MagicMock()

        async def fake_connect(name, config):
            server = MCPServerTask(name)
            server.session = mock_session
            server._tools = mock_tools
            return server

        with patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \
             patch("tools.registry.registry", mock_registry):
            asyncio.run(
                _discover_and_register_server("myserver", {"command": "test"})
            )

            assert validate_toolset("myserver") is True
            assert validate_toolset("mcp-myserver") is True
            assert "mcp_myserver_ping" in resolve_toolset("myserver")
            assert "mcp_myserver_ping" in resolve_toolset("mcp-myserver")

            _servers.pop("myserver", None)

    def test_schema_format_correct(self):
        """Registered schemas have the correct format."""
        from tools.registry import ToolRegistry
        from tools.mcp_tool import _discover_and_register_server, _servers, MCPServerTask

        mock_registry = ToolRegistry()
        mock_tools = [_make_mcp_tool("do_thing", "Do something")]
        mock_session = MagicMock()

        async def fake_connect(name, config):
            server = MCPServerTask(name)
            server.session = mock_session
            server._tools = mock_tools
            return server

        with patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \
             patch("tools.registry.registry", mock_registry):
            asyncio.run(
                _discover_and_register_server("srv", {"command": "test"})
            )

            entry = mock_registry._tools.get("mcp_srv_do_thing")
            assert entry is not None
            assert entry.schema["name"] == "mcp_srv_do_thing"
            assert "parameters" in entry.schema
            assert entry.is_async is False
            assert entry.toolset == "mcp-srv"

            _servers.pop("srv", None)


# ---------------------------------------------------------------------------
# MCPServerTask (run / start / shutdown)
# ---------------------------------------------------------------------------

class TestMCPServerTask:
    """Test the MCPServerTask lifecycle with mocked MCP SDK."""

    def _mock_stdio_and_session(self, session):
        """Return patches for stdio_client and ClientSession as async CMs."""
        mock_read, mock_write = MagicMock(), MagicMock()

        mock_stdio_cm = MagicMock()
        mock_stdio_cm.__aenter__ = AsyncMock(return_value=(mock_read, mock_write))
        mock_stdio_cm.__aexit__ = AsyncMock(return_value=False)

        mock_cs_cm = MagicMock()
        mock_cs_cm.__aenter__ = AsyncMock(return_value=session)
        mock_cs_cm.__aexit__ = AsyncMock(return_value=False)

        return (
            patch("tools.mcp_tool.stdio_client", return_value=mock_stdio_cm),
            patch("tools.mcp_tool.ClientSession", return_value=mock_cs_cm),
            mock_read, mock_write,
        )

    def test_start_connects_and_discovers_tools(self):
        """start() creates a Task that connects, discovers tools, and waits."""
        from tools.mcp_tool import MCPServerTask

        mock_tools = [_make_mcp_tool("echo")]
        mock_session = MagicMock()
        mock_session.initialize = AsyncMock()
        mock_session.list_tools = AsyncMock(
            return_value=SimpleNamespace(tools=mock_tools)
        )

        p_stdio, p_cs, _, _ = self._mock_stdio_and_session(mock_session)

        async def _test():
            with patch("tools.mcp_tool.StdioServerParameters"), p_stdio, p_cs:
                server = MCPServerTask("test_srv")
                await server.start({"command": "npx", "args": ["-y", "test"]})

                assert server.session is mock_session
                assert len(server._tools) == 1
                assert server._tools[0].name == "echo"
                mock_session.initialize.assert_called_once()

                await server.shutdown()
                assert server.session is None

        asyncio.run(_test())

    def test_no_command_raises(self):
        """Missing 'command' in config raises ValueError."""
        from tools.mcp_tool import MCPServerTask

        async def _test():
            server = MCPServerTask("bad")
            with pytest.raises(ValueError, match="no 'command'"):
                await server.start({"args": []})

        asyncio.run(_test())

    def test_refresh_tools_deregisters_removed_tools(self):
        """Dynamic refresh removes stale registry entries for deleted tools."""
        from tools.registry import ToolRegistry
        from tools.mcp_tool import MCPServerTask

        mock_registry = ToolRegistry()
        server = MCPServerTask("srv")
        server._config = {"command": "test"}
        server._tools = [_make_mcp_tool("old"), _make_mcp_tool("keep")]
        server._registered_tool_names = ["mcp_srv_old", "mcp_srv_keep"]
        server.session = MagicMock()
        server.session.list_tools = AsyncMock(
            return_value=SimpleNamespace(tools=[_make_mcp_tool("keep"), _make_mcp_tool("new")])
        )

        with patch("tools.registry.registry", mock_registry):
            mock_registry.register(
                name="mcp_srv_old",
                toolset="mcp-srv",
                schema={"name": "mcp_srv_old", "description": "Old"},
                handler=lambda *_args, **_kwargs: "{}",
            )
            mock_registry.register(
                name="mcp_srv_keep",
                toolset="mcp-srv",
                schema={"name": "mcp_srv_keep", "description": "Keep"},
                handler=lambda *_args, **_kwargs: "{}",
            )

            asyncio.run(server._refresh_tools())

            names = mock_registry.get_all_tool_names()
            assert "mcp_srv_old" not in names
            assert "mcp_srv_keep" in names
            assert "mcp_srv_new" in names
            assert set(server._registered_tool_names) == {
                "mcp_srv_keep",
                "mcp_srv_new",
                "mcp_srv_list_resources",
                "mcp_srv_read_resource",
                "mcp_srv_list_prompts",
                "mcp_srv_get_prompt",
            }

    def test_schedule_tools_refresh_keeps_task_until_done(self):
        """Background refresh tasks are strongly referenced and then discarded."""
        from tools.mcp_tool import MCPServerTask

        async def _test():
            started = asyncio.Event()
            finish = asyncio.Event()
            server = MCPServerTask("srv")

            async def fake_refresh(_server):
                started.set()
                await finish.wait()

            with patch.object(MCPServerTask, "_refresh_tools", new=fake_refresh):
                server._schedule_tools_refresh()

                await started.wait()
                assert len(server._pending_refresh_tasks) == 1
                task = next(iter(server._pending_refresh_tasks))
                assert not task.done()

                finish.set()
                await task
                await asyncio.sleep(0)
                assert server._pending_refresh_tasks == set()

        asyncio.run(_test())

    def test_shutdown_cancels_pending_refresh_tasks(self):
        """shutdown() cancels in-flight background refresh tasks."""
        from tools.mcp_tool import MCPServerTask

        async def _test():
            started = asyncio.Event()
            cancelled = asyncio.Event()
            server = MCPServerTask("srv")

            async def fake_refresh(_server):
                started.set()
                try:
                    await asyncio.sleep(3600)
                except asyncio.CancelledError:
                    cancelled.set()
                    raise

            with patch.object(MCPServerTask, "_refresh_tools", new=fake_refresh):
                server._schedule_tools_refresh()
                await started.wait()

                await server.shutdown()

                assert cancelled.is_set()
                assert server._pending_refresh_tasks == set()

        asyncio.run(_test())

    def test_empty_env_gets_safe_defaults(self):
        """Empty env dict gets safe default env vars (PATH, HOME, etc.)."""
        from tools.mcp_tool import MCPServerTask

        mock_session = MagicMock()
        mock_session.initialize = AsyncMock()
        mock_session.list_tools = AsyncMock(
            return_value=SimpleNamespace(tools=[])
        )

        p_stdio, p_cs, _, _ = self._mock_stdio_and_session(mock_session)

        async def _test():
            with patch("tools.mcp_tool.StdioServerParameters") as mock_params, \
                 p_stdio, p_cs, \
                 patch.dict("os.environ", {"PATH": "/usr/bin", "HOME": "/home/test"}, clear=False):
                server = MCPServerTask("srv")
                await server.start({"command": "node", "env": {}})

                # Empty dict -> safe env vars (not None)
                call_kwargs = mock_params.call_args
                env_arg = call_kwargs.kwargs.get("env")
                assert env_arg is not None
                assert isinstance(env_arg, dict)
                assert "PATH" in env_arg
                assert "HOME" in env_arg

                await server.shutdown()

        asyncio.run(_test())

    def test_shutdown_signals_task_exit(self):
        """shutdown() signals the event and waits for task completion."""
        from tools.mcp_tool import MCPServerTask

        mock_session = MagicMock()
        mock_session.initialize = AsyncMock()
        mock_session.list_tools = AsyncMock(
            return_value=SimpleNamespace(tools=[])
        )

        p_stdio, p_cs, _, _ = self._mock_stdio_and_session(mock_session)

        async def _test():
            with patch("tools.mcp_tool.StdioServerParameters"), p_stdio, p_cs:
                server = MCPServerTask("srv")
                await server.start({"command": "npx"})

                assert server.session is not None
                assert not server._task.done()

                await server.shutdown()

                assert server.session is None
                assert server._task.done()

        asyncio.run(_test())


# ---------------------------------------------------------------------------
# discover_mcp_tools toolset injection
# ---------------------------------------------------------------------------

class TestToolsetInjection:
    def test_mcp_tools_resolve_through_server_aliases(self):
        """Discovered MCP tools resolve through raw server-name aliases."""
        from tools.mcp_tool import MCPServerTask
        from tools.registry import ToolRegistry
        from toolsets import resolve_toolset, validate_toolset

        mock_tools = [_make_mcp_tool("list_files", "List files")]
        mock_session = MagicMock()
        mock_registry = ToolRegistry()

        fresh_servers = {}

        async def fake_connect(name, config):
            server = MCPServerTask(name)
            server.session = mock_session
            server._tools = mock_tools
            return server

        fake_config = {"fs": {"command": "npx", "args": []}}

        with patch("tools.mcp_tool._MCP_AVAILABLE", True), \
             patch("tools.mcp_tool._servers", fresh_servers), \
             patch("tools.mcp_tool._load_mcp_config", return_value=fake_config), \
             patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \
             patch("tools.registry.registry", mock_registry):
            from tools.mcp_tool import discover_mcp_tools
            result = discover_mcp_tools()

            assert "mcp_fs_list_files" in result
            assert validate_toolset("fs") is True
            assert validate_toolset("mcp-fs") is True
            assert "mcp_fs_list_files" in resolve_toolset("fs")
            assert "mcp_fs_list_files" in resolve_toolset("mcp-fs")

    def test_server_toolset_skips_builtin_collision(self):
        """MCP raw aliases never overwrite a built-in toolset name."""
        from tools.mcp_tool import MCPServerTask
        from tools.registry import ToolRegistry
        from toolsets import resolve_toolset, validate_toolset

        mock_tools = [_make_mcp_tool("run", "Run command")]
        mock_session = MagicMock()
        fresh_servers = {}
        mock_registry = ToolRegistry()

        async def fake_connect(name, config):
            server = MCPServerTask(name)
            server.session = mock_session
            server._tools = mock_tools
            return server

        fake_toolsets = {
            "hermes-cli": {"tools": ["terminal"], "description": "CLI", "includes": []},
            # Built-in toolset named "terminal" — must not be overwritten
            "terminal": {"tools": ["terminal"], "description": "Terminal tools", "includes": []},
        }
        fake_config = {"terminal": {"command": "npx", "args": []}}

        with patch("tools.mcp_tool._MCP_AVAILABLE", True), \
             patch("tools.mcp_tool._servers", fresh_servers), \
             patch("tools.mcp_tool._load_mcp_config", return_value=fake_config), \
             patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \
             patch("tools.registry.registry", mock_registry), \
             patch("toolsets.TOOLSETS", fake_toolsets):
            from tools.mcp_tool import discover_mcp_tools
            discover_mcp_tools()

            assert fake_toolsets["terminal"]["description"] == "Terminal tools"
            assert "mcp_terminal_run" not in resolve_toolset("terminal")
            assert validate_toolset("mcp-terminal") is True
            assert "mcp_terminal_run" in resolve_toolset("mcp-terminal")

    def test_server_connection_failure_skipped(self):
        """If one server fails to connect, others still proceed."""
        from tools.mcp_tool import MCPServerTask

        mock_tools = [_make_mcp_tool("ping", "Ping")]
        mock_session = MagicMock()

        fresh_servers = {}
        call_count = 0

        async def flaky_connect(name, config):
            nonlocal call_count
            call_count += 1
            if name == "broken":
                raise ConnectionError("cannot reach server")
            server = MCPServerTask(name)
            server.session = mock_session
            server._tools = mock_tools
            return server

        fake_config = {
            "broken": {"command": "bad"},
            "good": {"command": "npx", "args": []},
        }
        fake_toolsets = {
            "hermes-cli": {"tools": [], "description": "CLI", "includes": []},
        }

        with patch("tools.mcp_tool._MCP_AVAILABLE", True), \
             patch("tools.mcp_tool._servers", fresh_servers), \
             patch("tools.mcp_tool._load_mcp_config", return_value=fake_config), \
             patch("tools.mcp_tool._connect_server", side_effect=flaky_connect), \
             patch("toolsets.TOOLSETS", fake_toolsets):
            from tools.mcp_tool import discover_mcp_tools
            result = discover_mcp_tools()

            assert "mcp_good_ping" in result
            assert "mcp_broken_ping" not in result
            assert call_count == 2

    def test_partial_failure_retry_on_second_call(self):
        """Failed servers are retried on subsequent discover_mcp_tools() calls."""
        from tools.mcp_tool import MCPServerTask

        mock_tools = [_make_mcp_tool("ping", "Ping")]
        mock_session = MagicMock()

        # Use a real dict so idempotency logic works correctly
        fresh_servers = {}
        call_count = 0
        broken_fixed = False

        async def flaky_connect(name, config):
            nonlocal call_count
            call_count += 1
            if name == "broken" and not broken_fixed:
                raise ConnectionError("cannot reach server")
            server = MCPServerTask(name)
            server.session = mock_session
            server._tools = mock_tools
            return server

        fake_config = {
            "broken": {"command": "bad"},
            "good": {"command": "npx", "args": []},
        }
        fake_toolsets = {
            "hermes-cli": {"tools": [], "description": "CLI", "includes": []},
        }

        with patch("tools.mcp_tool._MCP_AVAILABLE", True), \
             patch("tools.mcp_tool._servers", fresh_servers), \
             patch("tools.mcp_tool._load_mcp_config", return_value=fake_config), \
             patch("tools.mcp_tool._connect_server", side_effect=flaky_connect), \
             patch("toolsets.TOOLSETS", fake_toolsets):
            from tools.mcp_tool import discover_mcp_tools

            # First call: good connects, broken fails
            result1 = discover_mcp_tools()
            assert "mcp_good_ping" in result1
            assert "mcp_broken_ping" not in result1
            first_attempts = call_count

            # "Fix" the broken server
            broken_fixed = True
            call_count = 0

            # Second call: should retry broken, skip good
            result2 = discover_mcp_tools()
            assert "mcp_good_ping" in result2
            assert "mcp_broken_ping" in result2
            assert call_count == 1  # Only broken retried


# ---------------------------------------------------------------------------
# Graceful fallback
# ---------------------------------------------------------------------------

class TestGracefulFallback:
    def test_mcp_unavailable_returns_empty(self):
        """When _MCP_AVAILABLE is False, discover_mcp_tools is a no-op."""
        with patch("tools.mcp_tool._MCP_AVAILABLE", False):
            from tools.mcp_tool import discover_mcp_tools
            result = discover_mcp_tools()
            assert result == []

    def test_no_servers_returns_empty(self):
        """No MCP servers configured -> empty list."""
        with patch("tools.mcp_tool._MCP_AVAILABLE", True), \
             patch("tools.mcp_tool._servers", {}), \
             patch("tools.mcp_tool._load_mcp_config", return_value={}):
            from tools.mcp_tool import discover_mcp_tools
            result = discover_mcp_tools()
            assert result == []


# ---------------------------------------------------------------------------
# Shutdown (public API)
# ---------------------------------------------------------------------------

class TestShutdown:
    def test_no_servers_safe(self):
        """shutdown_mcp_servers with no servers does nothing."""
        from tools.mcp_tool import shutdown_mcp_servers, _servers

        _servers.clear()
        shutdown_mcp_servers()  # Should not raise

    def test_shutdown_clears_servers(self):
        """shutdown_mcp_servers calls shutdown() on each server and clears dict."""
        import tools.mcp_tool as mcp_mod
        from tools.mcp_tool import shutdown_mcp_servers, _servers

        _servers.clear()
        mock_server = MagicMock()
        mock_server.name = "test"
        mock_server.shutdown = AsyncMock()
        _servers["test"] = mock_server

        mcp_mod._ensure_mcp_loop()
        try:
            shutdown_mcp_servers()
        finally:
            mcp_mod._mcp_loop = None
            mcp_mod._mcp_thread = None

        assert len(_servers) == 0
        mock_server.shutdown.assert_called_once()

    def test_shutdown_deregisters_registered_tools(self):
        """shutdown_mcp_servers removes MCP tools and their raw alias."""
        import tools.mcp_tool as mcp_mod
        from tools.mcp_tool import MCPServerTask, shutdown_mcp_servers, _servers
        from tools.registry import registry
        from toolsets import resolve_toolset, validate_toolset

        _servers.clear()
        registry.register(
            name="mcp_test_ping",
            toolset="mcp-test",
            schema={
                "name": "mcp_test_ping",
                "description": "Ping",
                "parameters": {"type": "object", "properties": {}},
            },
            handler=lambda *_args, **_kwargs: "{}",
        )
        registry.register_toolset_alias("test", "mcp-test")

        server = MCPServerTask("test")
        server._registered_tool_names = ["mcp_test_ping"]
        _servers["test"] = server

        mcp_mod._ensure_mcp_loop()
        try:
            assert validate_toolset("test") is True
            assert "mcp_test_ping" in resolve_toolset("test")
            shutdown_mcp_servers()
        finally:
            mcp_mod._mcp_loop = None
            mcp_mod._mcp_thread = None

        assert "mcp_test_ping" not in registry.get_all_tool_names()
        assert validate_toolset("test") is False

    def test_shutdown_handles_errors(self):
        """shutdown_mcp_servers handles errors during close gracefully."""
        import tools.mcp_tool as mcp_mod
        from tools.mcp_tool import shutdown_mcp_servers, _servers

        _servers.clear()
        mock_server = MagicMock()
        mock_server.name = "broken"
        mock_server.shutdown = AsyncMock(side_effect=RuntimeError("close failed"))
        _servers["broken"] = mock_server

        mcp_mod._ensure_mcp_loop()
        try:
            shutdown_mcp_servers()  # Should not raise
        finally:
            mcp_mod._mcp_loop = None
            mcp_mod._mcp_thread = None

        assert len(_servers) == 0

    def test_shutdown_is_parallel(self):
        """Multiple servers are shut down in parallel via asyncio.gather."""
        import tools.mcp_tool as mcp_mod
        from tools.mcp_tool import shutdown_mcp_servers, _servers
        import time

        _servers.clear()

        # 3 servers each taking 1s to shut down
        for i in range(3):
            mock_server = MagicMock()
            mock_server.name = f"srv_{i}"
            async def slow_shutdown():
                await asyncio.sleep(1)
            mock_server.shutdown = slow_shutdown
            _servers[f"srv_{i}"] = mock_server

        mcp_mod._ensure_mcp_loop()
        try:
            start = time.monotonic()
            shutdown_mcp_servers()
            elapsed = time.monotonic() - start
        finally:
            mcp_mod._mcp_loop = None
            mcp_mod._mcp_thread = None

        assert len(_servers) == 0
        # Parallel: ~1s, not ~3s. Allow some margin.
1180 assert elapsed < 2.5, f"Shutdown took {elapsed:.1f}s, expected ~1s (parallel)" 1181 1182 1183 # --------------------------------------------------------------------------- 1184 # _build_safe_env 1185 # --------------------------------------------------------------------------- 1186 1187 class TestBuildSafeEnv: 1188 """Tests for _build_safe_env() environment filtering.""" 1189 1190 def test_only_safe_vars_passed(self): 1191 """Only safe baseline vars and XDG_* from os.environ are included.""" 1192 from tools.mcp_tool import _build_safe_env 1193 1194 fake_env = { 1195 "PATH": "/usr/bin", 1196 "HOME": "/home/test", 1197 "USER": "test", 1198 "LANG": "en_US.UTF-8", 1199 "LC_ALL": "C", 1200 "TERM": "xterm", 1201 "SHELL": "/bin/bash", 1202 "TMPDIR": "/tmp", 1203 "XDG_DATA_HOME": "/home/test/.local/share", 1204 "SECRET_KEY": "should_not_appear", 1205 "AWS_ACCESS_KEY_ID": "AKIAIOSFODNN7EXAMPLE", 1206 } 1207 with patch.dict("os.environ", fake_env, clear=True): 1208 result = _build_safe_env(None) 1209 1210 # Safe vars present 1211 assert result["PATH"] == "/usr/bin" 1212 assert result["HOME"] == "/home/test" 1213 assert result["USER"] == "test" 1214 assert result["LANG"] == "en_US.UTF-8" 1215 assert result["XDG_DATA_HOME"] == "/home/test/.local/share" 1216 # Unsafe vars excluded 1217 assert "SECRET_KEY" not in result 1218 assert "AWS_ACCESS_KEY_ID" not in result 1219 1220 def test_user_env_merged(self): 1221 """User-specified env vars are merged into the safe env.""" 1222 from tools.mcp_tool import _build_safe_env 1223 1224 with patch.dict("os.environ", {"PATH": "/usr/bin"}, clear=True): 1225 result = _build_safe_env({"MY_CUSTOM_VAR": "hello"}) 1226 1227 assert result["PATH"] == "/usr/bin" 1228 assert result["MY_CUSTOM_VAR"] == "hello" 1229 1230 def test_user_env_overrides_safe(self): 1231 """User env can override safe defaults.""" 1232 from tools.mcp_tool import _build_safe_env 1233 1234 with patch.dict("os.environ", {"PATH": "/usr/bin"}, clear=True): 1235 result = 
_build_safe_env({"PATH": "/custom/bin"}) 1236 1237 assert result["PATH"] == "/custom/bin" 1238 1239 def test_none_user_env(self): 1240 """None user_env still returns safe vars from os.environ.""" 1241 from tools.mcp_tool import _build_safe_env 1242 1243 with patch.dict("os.environ", {"PATH": "/usr/bin", "HOME": "/root"}, clear=True): 1244 result = _build_safe_env(None) 1245 1246 assert isinstance(result, dict) 1247 assert result["PATH"] == "/usr/bin" 1248 assert result["HOME"] == "/root" 1249 1250 def test_secret_vars_excluded(self): 1251 """Sensitive env vars from os.environ are NOT passed through.""" 1252 from tools.mcp_tool import _build_safe_env 1253 1254 fake_env = { 1255 "PATH": "/usr/bin", 1256 "AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", 1257 "GITHUB_TOKEN": "ghp_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", 1258 "OPENAI_API_KEY": "sk-proj-abc123", 1259 "DATABASE_URL": "postgres://user:pass@localhost/db", 1260 "API_SECRET": "supersecret", 1261 } 1262 with patch.dict("os.environ", fake_env, clear=True): 1263 result = _build_safe_env(None) 1264 1265 assert "PATH" in result 1266 assert "AWS_SECRET_ACCESS_KEY" not in result 1267 assert "GITHUB_TOKEN" not in result 1268 assert "OPENAI_API_KEY" not in result 1269 assert "DATABASE_URL" not in result 1270 assert "API_SECRET" not in result 1271 1272 1273 # --------------------------------------------------------------------------- 1274 # _sanitize_error 1275 # --------------------------------------------------------------------------- 1276 1277 class TestSanitizeError: 1278 """Tests for _sanitize_error() credential stripping.""" 1279 1280 def test_strips_github_pat(self): 1281 from tools.mcp_tool import _sanitize_error 1282 result = _sanitize_error("Error with ghp_abc123def456") 1283 assert result == "Error with [REDACTED]" 1284 1285 def test_strips_openai_key(self): 1286 from tools.mcp_tool import _sanitize_error 1287 result = _sanitize_error("key sk-projABC123xyz") 1288 assert result == "key 
[REDACTED]" 1289 1290 def test_strips_bearer_token(self): 1291 from tools.mcp_tool import _sanitize_error 1292 result = _sanitize_error("Authorization: Bearer eyJabc123def") 1293 assert result == "Authorization: [REDACTED]" 1294 1295 def test_strips_token_param(self): 1296 from tools.mcp_tool import _sanitize_error 1297 result = _sanitize_error("url?token=secret123") 1298 assert result == "url?[REDACTED]" 1299 1300 def test_no_credentials_unchanged(self): 1301 from tools.mcp_tool import _sanitize_error 1302 result = _sanitize_error("normal error message") 1303 assert result == "normal error message" 1304 1305 def test_multiple_credentials(self): 1306 from tools.mcp_tool import _sanitize_error 1307 result = _sanitize_error("ghp_abc123 and sk-projXyz789 and token=foo") 1308 assert "ghp_" not in result 1309 assert "sk-" not in result 1310 assert "token=" not in result 1311 assert result.count("[REDACTED]") == 3 1312 1313 1314 # --------------------------------------------------------------------------- 1315 # HTTP config 1316 # --------------------------------------------------------------------------- 1317 1318 class TestHTTPConfig: 1319 """Tests for HTTP transport detection and handling.""" 1320 1321 def test_is_http_with_url(self): 1322 from tools.mcp_tool import MCPServerTask 1323 server = MCPServerTask("remote") 1324 server._config = {"url": "https://example.com/mcp"} 1325 assert server._is_http() is True 1326 1327 def test_is_stdio_with_command(self): 1328 from tools.mcp_tool import MCPServerTask 1329 server = MCPServerTask("local") 1330 server._config = {"command": "npx", "args": []} 1331 assert server._is_http() is False 1332 1333 def test_conflicting_url_and_command_warns(self): 1334 """Config with both url and command logs a warning and uses HTTP.""" 1335 from tools.mcp_tool import MCPServerTask 1336 server = MCPServerTask("conflict") 1337 config = {"url": "https://example.com/mcp", "command": "npx", "args": []} 1338 # url takes precedence 1339 
server._config = config 1340 assert server._is_http() is True 1341 1342 def test_http_unavailable_raises(self): 1343 from tools.mcp_tool import MCPServerTask 1344 1345 server = MCPServerTask("remote") 1346 config = {"url": "https://example.com/mcp"} 1347 1348 async def _test(): 1349 with patch("tools.mcp_tool._MCP_HTTP_AVAILABLE", False): 1350 with pytest.raises(ImportError, match="HTTP transport"): 1351 await server._run_http(config) 1352 1353 asyncio.run(_test()) 1354 1355 def test_http_seeds_initial_protocol_header(self): 1356 from tools.mcp_tool import LATEST_PROTOCOL_VERSION, MCPServerTask 1357 1358 server = MCPServerTask("remote") 1359 captured = {} 1360 1361 class DummyAsyncClient: 1362 def __init__(self, **kwargs): 1363 captured.update(kwargs) 1364 1365 async def __aenter__(self): 1366 return self 1367 1368 async def __aexit__(self, exc_type, exc, tb): 1369 return False 1370 1371 class DummyTransportCtx: 1372 async def __aenter__(self): 1373 return MagicMock(), MagicMock(), (lambda: None) 1374 1375 async def __aexit__(self, exc_type, exc, tb): 1376 return False 1377 1378 class DummySession: 1379 def __init__(self, *args, **kwargs): 1380 pass 1381 1382 async def __aenter__(self): 1383 return self 1384 1385 async def __aexit__(self, exc_type, exc, tb): 1386 return False 1387 1388 async def initialize(self): 1389 return None 1390 1391 class DummyLegacyTransportCtx: 1392 def __init__(self, **kwargs): 1393 captured["legacy_headers"] = kwargs.get("headers") 1394 1395 async def __aenter__(self): 1396 return MagicMock(), MagicMock(), (lambda: None) 1397 1398 async def __aexit__(self, exc_type, exc, tb): 1399 return False 1400 1401 async def _discover_tools(self): 1402 self._shutdown_event.set() 1403 1404 async def _run(config, *, new_http): 1405 captured.clear() 1406 with patch("tools.mcp_tool._MCP_HTTP_AVAILABLE", True), \ 1407 patch("tools.mcp_tool._MCP_NEW_HTTP", new_http), \ 1408 patch("httpx.AsyncClient", DummyAsyncClient), \ 1409 
patch("tools.mcp_tool.streamable_http_client", return_value=DummyTransportCtx()), \ 1410 patch("tools.mcp_tool.streamablehttp_client", side_effect=lambda url, **kwargs: DummyLegacyTransportCtx(**kwargs)), \ 1411 patch("tools.mcp_tool.ClientSession", DummySession), \ 1412 patch.object(MCPServerTask, "_discover_tools", _discover_tools): 1413 await server._run_http(config) 1414 1415 asyncio.run(_run({"url": "https://example.com/mcp"}, new_http=True)) 1416 assert captured["headers"]["mcp-protocol-version"] == LATEST_PROTOCOL_VERSION 1417 1418 asyncio.run(_run({ 1419 "url": "https://example.com/mcp", 1420 "headers": {"mcp-protocol-version": "custom-version"}, 1421 }, new_http=True)) 1422 assert captured["headers"]["mcp-protocol-version"] == "custom-version" 1423 1424 asyncio.run(_run({ 1425 "url": "https://example.com/mcp", 1426 "headers": {"MCP-Protocol-Version": "custom-version"}, 1427 }, new_http=True)) 1428 assert captured["headers"]["MCP-Protocol-Version"] == "custom-version" 1429 assert "mcp-protocol-version" not in captured["headers"] 1430 1431 asyncio.run(_run({"url": "https://example.com/mcp"}, new_http=False)) 1432 assert captured["legacy_headers"]["mcp-protocol-version"] == LATEST_PROTOCOL_VERSION 1433 1434 asyncio.run(_run({ 1435 "url": "https://example.com/mcp", 1436 "headers": {"MCP-Protocol-Version": "custom-version"}, 1437 }, new_http=False)) 1438 assert captured["legacy_headers"]["MCP-Protocol-Version"] == "custom-version" 1439 assert "mcp-protocol-version" not in captured["legacy_headers"] 1440 1441 1442 # --------------------------------------------------------------------------- 1443 # Reconnection logic 1444 # --------------------------------------------------------------------------- 1445 1446 class TestReconnection: 1447 """Tests for automatic reconnection behavior in MCPServerTask.run().""" 1448 1449 def test_reconnect_on_disconnect(self): 1450 """After initial success, a connection drop triggers reconnection.""" 1451 from tools.mcp_tool import 
MCPServerTask 1452 1453 run_count = 0 1454 target_server = None 1455 1456 original_run_stdio = MCPServerTask._run_stdio 1457 1458 async def patched_run_stdio(self_srv, config): 1459 nonlocal run_count, target_server 1460 run_count += 1 1461 if target_server is not self_srv: 1462 return await original_run_stdio(self_srv, config) 1463 if run_count == 1: 1464 # First connection succeeds, then simulate disconnect 1465 self_srv.session = MagicMock() 1466 self_srv._tools = [] 1467 self_srv._ready.set() 1468 raise ConnectionError("connection dropped") 1469 else: 1470 # Reconnection succeeds; signal shutdown so run() exits 1471 self_srv.session = MagicMock() 1472 self_srv._shutdown_event.set() 1473 await self_srv._shutdown_event.wait() 1474 1475 async def _test(): 1476 nonlocal target_server 1477 server = MCPServerTask("test_srv") 1478 target_server = server 1479 1480 with patch.object(MCPServerTask, "_run_stdio", patched_run_stdio), \ 1481 patch("asyncio.sleep", new_callable=AsyncMock): 1482 await server.run({"command": "test"}) 1483 1484 assert run_count >= 2 # At least one reconnection attempt 1485 1486 asyncio.run(_test()) 1487 1488 def test_no_reconnect_on_shutdown(self): 1489 """If shutdown is requested, don't attempt reconnection.""" 1490 from tools.mcp_tool import MCPServerTask 1491 1492 run_count = 0 1493 target_server = None 1494 1495 original_run_stdio = MCPServerTask._run_stdio 1496 1497 async def patched_run_stdio(self_srv, config): 1498 nonlocal run_count, target_server 1499 run_count += 1 1500 if target_server is not self_srv: 1501 return await original_run_stdio(self_srv, config) 1502 self_srv.session = MagicMock() 1503 self_srv._tools = [] 1504 self_srv._ready.set() 1505 raise ConnectionError("connection dropped") 1506 1507 async def _test(): 1508 nonlocal target_server 1509 server = MCPServerTask("test_srv") 1510 target_server = server 1511 server._shutdown_event.set() # Shutdown already requested 1512 1513 with patch.object(MCPServerTask, "_run_stdio", 
patched_run_stdio), \ 1514 patch("asyncio.sleep", new_callable=AsyncMock): 1515 await server.run({"command": "test"}) 1516 1517 # Should not retry because shutdown was set 1518 assert run_count == 1 1519 1520 asyncio.run(_test()) 1521 1522 def test_no_reconnect_on_initial_failure(self): 1523 """First connection failure retries up to _MAX_INITIAL_CONNECT_RETRIES times. 1524 1525 Before the MCP resilience fix, initial failures gave up immediately. 1526 Now they retry with backoff to handle transient DNS/network blips. 1527 """ 1528 from tools.mcp_tool import MCPServerTask, _MAX_INITIAL_CONNECT_RETRIES 1529 1530 run_count = 0 1531 target_server = None 1532 1533 original_run_stdio = MCPServerTask._run_stdio 1534 1535 async def patched_run_stdio(self_srv, config): 1536 nonlocal run_count, target_server 1537 run_count += 1 1538 if target_server is not self_srv: 1539 return await original_run_stdio(self_srv, config) 1540 raise ConnectionError("cannot connect") 1541 1542 async def _test(): 1543 nonlocal target_server 1544 server = MCPServerTask("test_srv") 1545 target_server = server 1546 1547 with patch.object(MCPServerTask, "_run_stdio", patched_run_stdio), \ 1548 patch("asyncio.sleep", new_callable=AsyncMock): 1549 await server.run({"command": "test"}) 1550 1551 # Now retries up to _MAX_INITIAL_CONNECT_RETRIES before giving up 1552 assert run_count == _MAX_INITIAL_CONNECT_RETRIES + 1 1553 assert server._error is not None 1554 assert "cannot connect" in str(server._error) 1555 1556 asyncio.run(_test()) 1557 1558 1559 # --------------------------------------------------------------------------- 1560 # Configurable timeouts 1561 # --------------------------------------------------------------------------- 1562 1563 class TestConfigurableTimeouts: 1564 """Tests for configurable per-server timeouts.""" 1565 1566 def test_default_timeout(self): 1567 """Server with no timeout config gets _DEFAULT_TOOL_TIMEOUT.""" 1568 from tools.mcp_tool import MCPServerTask, 
_DEFAULT_TOOL_TIMEOUT 1569 1570 server = MCPServerTask("test_srv") 1571 assert server.tool_timeout == _DEFAULT_TOOL_TIMEOUT 1572 assert server.tool_timeout == 120 1573 1574 def test_custom_timeout(self): 1575 """Server with timeout=180 in config gets 180.""" 1576 from tools.mcp_tool import MCPServerTask 1577 1578 target_server = None 1579 1580 original_run_stdio = MCPServerTask._run_stdio 1581 1582 async def patched_run_stdio(self_srv, config): 1583 if target_server is not self_srv: 1584 return await original_run_stdio(self_srv, config) 1585 self_srv.session = MagicMock() 1586 self_srv._tools = [] 1587 self_srv._ready.set() 1588 await self_srv._shutdown_event.wait() 1589 1590 async def _test(): 1591 nonlocal target_server 1592 server = MCPServerTask("test_srv") 1593 target_server = server 1594 1595 with patch.object(MCPServerTask, "_run_stdio", patched_run_stdio): 1596 task = asyncio.ensure_future( 1597 server.run({"command": "test", "timeout": 180}) 1598 ) 1599 await server._ready.wait() 1600 assert server.tool_timeout == 180 1601 server._shutdown_event.set() 1602 await task 1603 1604 asyncio.run(_test()) 1605 1606 def test_timeout_passed_to_handler(self): 1607 """The tool handler uses the server's configured timeout.""" 1608 from tools.mcp_tool import _make_tool_handler, _servers, MCPServerTask 1609 1610 mock_session = MagicMock() 1611 mock_session.call_tool = AsyncMock( 1612 return_value=_make_call_result("ok", is_error=False) 1613 ) 1614 server = _make_mock_server("test_srv", session=mock_session) 1615 server.tool_timeout = 180 1616 _servers["test_srv"] = server 1617 1618 try: 1619 handler = _make_tool_handler("test_srv", "my_tool", 180) 1620 with patch("tools.mcp_tool._run_on_mcp_loop") as mock_run: 1621 def fake_run(coro, timeout=30): 1622 coro.close() 1623 return json.dumps({"result": "ok"}) 1624 1625 mock_run.side_effect = fake_run 1626 handler({}) 1627 # Verify timeout=180 was passed 1628 call_kwargs = mock_run.call_args 1629 assert 
call_kwargs.kwargs.get("timeout") == 180 or \ 1630 (len(call_kwargs.args) > 1 and call_kwargs.args[1] == 180) or \ 1631 call_kwargs[1].get("timeout") == 180 1632 finally: 1633 _servers.pop("test_srv", None) 1634 1635 1636 # --------------------------------------------------------------------------- 1637 # Utility tool schemas (Resources & Prompts) 1638 # --------------------------------------------------------------------------- 1639 1640 class TestUtilitySchemas: 1641 """Tests for _build_utility_schemas() and the schema format of utility tools.""" 1642 1643 def test_builds_four_utility_schemas(self): 1644 from tools.mcp_tool import _build_utility_schemas 1645 1646 schemas = _build_utility_schemas("myserver") 1647 assert len(schemas) == 4 1648 names = [s["schema"]["name"] for s in schemas] 1649 assert "mcp_myserver_list_resources" in names 1650 assert "mcp_myserver_read_resource" in names 1651 assert "mcp_myserver_list_prompts" in names 1652 assert "mcp_myserver_get_prompt" in names 1653 1654 def test_hyphens_sanitized_in_utility_names(self): 1655 from tools.mcp_tool import _build_utility_schemas 1656 1657 schemas = _build_utility_schemas("my-server") 1658 names = [s["schema"]["name"] for s in schemas] 1659 for name in names: 1660 assert "-" not in name 1661 assert "mcp_my_server_list_resources" in names 1662 1663 def test_list_resources_schema_no_required_params(self): 1664 from tools.mcp_tool import _build_utility_schemas 1665 1666 schemas = _build_utility_schemas("srv") 1667 lr = next(s for s in schemas if s["handler_key"] == "list_resources") 1668 params = lr["schema"]["parameters"] 1669 assert params["type"] == "object" 1670 assert params["properties"] == {} 1671 assert "required" not in params 1672 1673 def test_read_resource_schema_requires_uri(self): 1674 from tools.mcp_tool import _build_utility_schemas 1675 1676 schemas = _build_utility_schemas("srv") 1677 rr = next(s for s in schemas if s["handler_key"] == "read_resource") 1678 params = 
rr["schema"]["parameters"] 1679 assert "uri" in params["properties"] 1680 assert params["properties"]["uri"]["type"] == "string" 1681 assert params["required"] == ["uri"] 1682 1683 def test_list_prompts_schema_no_required_params(self): 1684 from tools.mcp_tool import _build_utility_schemas 1685 1686 schemas = _build_utility_schemas("srv") 1687 lp = next(s for s in schemas if s["handler_key"] == "list_prompts") 1688 params = lp["schema"]["parameters"] 1689 assert params["type"] == "object" 1690 assert params["properties"] == {} 1691 assert "required" not in params 1692 1693 def test_get_prompt_schema_requires_name(self): 1694 from tools.mcp_tool import _build_utility_schemas 1695 1696 schemas = _build_utility_schemas("srv") 1697 gp = next(s for s in schemas if s["handler_key"] == "get_prompt") 1698 params = gp["schema"]["parameters"] 1699 assert "name" in params["properties"] 1700 assert params["properties"]["name"]["type"] == "string" 1701 assert "arguments" in params["properties"] 1702 assert params["properties"]["arguments"]["type"] == "object" 1703 assert params["required"] == ["name"] 1704 1705 def test_schemas_have_descriptions(self): 1706 from tools.mcp_tool import _build_utility_schemas 1707 1708 schemas = _build_utility_schemas("test_srv") 1709 for entry in schemas: 1710 desc = entry["schema"]["description"] 1711 assert desc and len(desc) > 0 1712 assert "test_srv" in desc 1713 1714 1715 # --------------------------------------------------------------------------- 1716 # Utility tool handlers (Resources & Prompts) 1717 # --------------------------------------------------------------------------- 1718 1719 class TestUtilityHandlers: 1720 """Tests for the MCP Resources & Prompts handler functions.""" 1721 1722 def _patch_mcp_loop(self): 1723 """Return a patch for _run_on_mcp_loop that runs the coroutine directly.""" 1724 def fake_run(coro, timeout=30): 1725 return asyncio.run(coro) 1726 return patch("tools.mcp_tool._run_on_mcp_loop", side_effect=fake_run) 
1727 1728 # -- list_resources -- 1729 1730 def test_list_resources_success(self): 1731 from tools.mcp_tool import _make_list_resources_handler, _servers 1732 1733 mock_resource = SimpleNamespace( 1734 uri="file:///tmp/test.txt", name="test.txt", 1735 description="A test file", mimeType="text/plain", 1736 ) 1737 mock_session = MagicMock() 1738 mock_session.list_resources = AsyncMock( 1739 return_value=SimpleNamespace(resources=[mock_resource]) 1740 ) 1741 server = _make_mock_server("srv", session=mock_session) 1742 _servers["srv"] = server 1743 1744 try: 1745 handler = _make_list_resources_handler("srv", 120) 1746 with self._patch_mcp_loop(): 1747 result = json.loads(handler({})) 1748 assert "resources" in result 1749 assert len(result["resources"]) == 1 1750 assert result["resources"][0]["uri"] == "file:///tmp/test.txt" 1751 assert result["resources"][0]["name"] == "test.txt" 1752 finally: 1753 _servers.pop("srv", None) 1754 1755 def test_list_resources_empty(self): 1756 from tools.mcp_tool import _make_list_resources_handler, _servers 1757 1758 mock_session = MagicMock() 1759 mock_session.list_resources = AsyncMock( 1760 return_value=SimpleNamespace(resources=[]) 1761 ) 1762 server = _make_mock_server("srv", session=mock_session) 1763 _servers["srv"] = server 1764 1765 try: 1766 handler = _make_list_resources_handler("srv", 120) 1767 with self._patch_mcp_loop(): 1768 result = json.loads(handler({})) 1769 assert result["resources"] == [] 1770 finally: 1771 _servers.pop("srv", None) 1772 1773 def test_list_resources_disconnected(self): 1774 from tools.mcp_tool import _make_list_resources_handler, _servers 1775 _servers.pop("ghost", None) 1776 handler = _make_list_resources_handler("ghost", 120) 1777 result = json.loads(handler({})) 1778 assert "error" in result 1779 assert "not connected" in result["error"] 1780 1781 # -- read_resource -- 1782 1783 def test_read_resource_success(self): 1784 from tools.mcp_tool import _make_read_resource_handler, _servers 1785 1786 
content_block = SimpleNamespace(text="Hello from resource") 1787 mock_session = MagicMock() 1788 mock_session.read_resource = AsyncMock( 1789 return_value=SimpleNamespace(contents=[content_block]) 1790 ) 1791 server = _make_mock_server("srv", session=mock_session) 1792 _servers["srv"] = server 1793 1794 try: 1795 handler = _make_read_resource_handler("srv", 120) 1796 with self._patch_mcp_loop(): 1797 result = json.loads(handler({"uri": "file:///tmp/test.txt"})) 1798 assert result["result"] == "Hello from resource" 1799 mock_session.read_resource.assert_called_once_with("file:///tmp/test.txt") 1800 finally: 1801 _servers.pop("srv", None) 1802 1803 def test_read_resource_missing_uri(self): 1804 from tools.mcp_tool import _make_read_resource_handler, _servers 1805 1806 server = _make_mock_server("srv", session=MagicMock()) 1807 _servers["srv"] = server 1808 1809 try: 1810 handler = _make_read_resource_handler("srv", 120) 1811 result = json.loads(handler({})) 1812 assert "error" in result 1813 assert "uri" in result["error"].lower() 1814 finally: 1815 _servers.pop("srv", None) 1816 1817 def test_read_resource_disconnected(self): 1818 from tools.mcp_tool import _make_read_resource_handler, _servers 1819 _servers.pop("ghost", None) 1820 handler = _make_read_resource_handler("ghost", 120) 1821 result = json.loads(handler({"uri": "test://x"})) 1822 assert "error" in result 1823 assert "not connected" in result["error"] 1824 1825 # -- list_prompts -- 1826 1827 def test_list_prompts_success(self): 1828 from tools.mcp_tool import _make_list_prompts_handler, _servers 1829 1830 mock_prompt = SimpleNamespace( 1831 name="summarize", description="Summarize text", 1832 arguments=[ 1833 SimpleNamespace(name="text", description="Text to summarize", required=True), 1834 ], 1835 ) 1836 mock_session = MagicMock() 1837 mock_session.list_prompts = AsyncMock( 1838 return_value=SimpleNamespace(prompts=[mock_prompt]) 1839 ) 1840 server = _make_mock_server("srv", session=mock_session) 1841 
_servers["srv"] = server 1842 1843 try: 1844 handler = _make_list_prompts_handler("srv", 120) 1845 with self._patch_mcp_loop(): 1846 result = json.loads(handler({})) 1847 assert "prompts" in result 1848 assert len(result["prompts"]) == 1 1849 assert result["prompts"][0]["name"] == "summarize" 1850 assert result["prompts"][0]["arguments"][0]["name"] == "text" 1851 finally: 1852 _servers.pop("srv", None) 1853 1854 def test_list_prompts_empty(self): 1855 from tools.mcp_tool import _make_list_prompts_handler, _servers 1856 1857 mock_session = MagicMock() 1858 mock_session.list_prompts = AsyncMock( 1859 return_value=SimpleNamespace(prompts=[]) 1860 ) 1861 server = _make_mock_server("srv", session=mock_session) 1862 _servers["srv"] = server 1863 1864 try: 1865 handler = _make_list_prompts_handler("srv", 120) 1866 with self._patch_mcp_loop(): 1867 result = json.loads(handler({})) 1868 assert result["prompts"] == [] 1869 finally: 1870 _servers.pop("srv", None) 1871 1872 def test_list_prompts_disconnected(self): 1873 from tools.mcp_tool import _make_list_prompts_handler, _servers 1874 _servers.pop("ghost", None) 1875 handler = _make_list_prompts_handler("ghost", 120) 1876 result = json.loads(handler({})) 1877 assert "error" in result 1878 assert "not connected" in result["error"] 1879 1880 # -- get_prompt -- 1881 1882 def test_get_prompt_success(self): 1883 from tools.mcp_tool import _make_get_prompt_handler, _servers 1884 1885 mock_msg = SimpleNamespace( 1886 role="assistant", 1887 content=SimpleNamespace(text="Here is a summary of your text."), 1888 ) 1889 mock_session = MagicMock() 1890 mock_session.get_prompt = AsyncMock( 1891 return_value=SimpleNamespace(messages=[mock_msg], description=None) 1892 ) 1893 server = _make_mock_server("srv", session=mock_session) 1894 _servers["srv"] = server 1895 1896 try: 1897 handler = _make_get_prompt_handler("srv", 120) 1898 with self._patch_mcp_loop(): 1899 result = json.loads(handler({"name": "summarize", "arguments": {"text": 
"hello"}})) 1900 assert "messages" in result 1901 assert len(result["messages"]) == 1 1902 assert result["messages"][0]["role"] == "assistant" 1903 assert "summary" in result["messages"][0]["content"].lower() 1904 mock_session.get_prompt.assert_called_once_with( 1905 "summarize", arguments={"text": "hello"} 1906 ) 1907 finally: 1908 _servers.pop("srv", None) 1909 1910 def test_get_prompt_missing_name(self): 1911 from tools.mcp_tool import _make_get_prompt_handler, _servers 1912 1913 server = _make_mock_server("srv", session=MagicMock()) 1914 _servers["srv"] = server 1915 1916 try: 1917 handler = _make_get_prompt_handler("srv", 120) 1918 result = json.loads(handler({})) 1919 assert "error" in result 1920 assert "name" in result["error"].lower() 1921 finally: 1922 _servers.pop("srv", None) 1923 1924 def test_get_prompt_disconnected(self): 1925 from tools.mcp_tool import _make_get_prompt_handler, _servers 1926 _servers.pop("ghost", None) 1927 handler = _make_get_prompt_handler("ghost", 120) 1928 result = json.loads(handler({"name": "test"})) 1929 assert "error" in result 1930 assert "not connected" in result["error"] 1931 1932 def test_get_prompt_default_arguments(self): 1933 from tools.mcp_tool import _make_get_prompt_handler, _servers 1934 1935 mock_session = MagicMock() 1936 mock_session.get_prompt = AsyncMock( 1937 return_value=SimpleNamespace(messages=[], description=None) 1938 ) 1939 server = _make_mock_server("srv", session=mock_session) 1940 _servers["srv"] = server 1941 1942 try: 1943 handler = _make_get_prompt_handler("srv", 120) 1944 with self._patch_mcp_loop(): 1945 handler({"name": "test_prompt"}) 1946 # arguments defaults to {} when not provided 1947 mock_session.get_prompt.assert_called_once_with( 1948 "test_prompt", arguments={} 1949 ) 1950 finally: 1951 _servers.pop("srv", None) 1952 1953 1954 # --------------------------------------------------------------------------- 1955 # Utility tools registration in _discover_and_register_server 1956 # 
# ---------------------------------------------------------------------------

class TestUtilityToolRegistration:
    """Verify utility tools are registered alongside regular MCP tools."""

    def test_utility_tools_registered(self):
        """_discover_and_register_server registers all 4 utility tools."""
        from tools.registry import ToolRegistry
        from tools.mcp_tool import _discover_and_register_server, _servers, MCPServerTask

        mock_registry = ToolRegistry()
        mock_tools = [_make_mcp_tool("read_file", "Read a file")]
        mock_session = MagicMock()

        # Skip the real subprocess/stdio handshake entirely: hand back a
        # pre-populated MCPServerTask as if the connection had succeeded.
        async def fake_connect(name, config):
            server = MCPServerTask(name)
            server.session = mock_session
            server._tools = mock_tools
            return server

        with patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \
             patch("tools.registry.registry", mock_registry):
            registered = asyncio.run(
                _discover_and_register_server("fs", {"command": "npx", "args": []})
            )

        # Regular tool + 4 utility tools
        assert "mcp_fs_read_file" in registered
        assert "mcp_fs_list_resources" in registered
        assert "mcp_fs_read_resource" in registered
        assert "mcp_fs_list_prompts" in registered
        assert "mcp_fs_get_prompt" in registered
        assert len(registered) == 5

        # All in the registry
        all_names = mock_registry.get_all_tool_names()
        for name in registered:
            assert name in all_names

        # Drop the fake server from the module-level table so later tests
        # start from a clean slate.
        _servers.pop("fs", None)

    def test_utility_tools_in_same_toolset(self):
        """Utility tools belong to the same mcp-{server} toolset."""
        from tools.registry import ToolRegistry
        from tools.mcp_tool import _discover_and_register_server, _servers, MCPServerTask

        mock_registry = ToolRegistry()
        mock_session = MagicMock()

        # Server exposes no regular tools: only the utility tools register.
        async def fake_connect(name, config):
            server = MCPServerTask(name)
            server.session = mock_session
            server._tools = []
            return server

        with patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \
             patch("tools.registry.registry", mock_registry):
            asyncio.run(
                _discover_and_register_server("myserv", {"command": "test"})
            )

        # Check that utility tools are in the right toolset
        for tool_name in ["mcp_myserv_list_resources", "mcp_myserv_read_resource",
                          "mcp_myserv_list_prompts", "mcp_myserv_get_prompt"]:
            entry = mock_registry._tools.get(tool_name)
            assert entry is not None, f"{tool_name} not found in registry"
            assert entry.toolset == "mcp-myserv"

        _servers.pop("myserv", None)

    def test_utility_tools_have_check_fn(self):
        """Utility tools have a working check_fn."""
        from tools.registry import ToolRegistry
        from tools.mcp_tool import _discover_and_register_server, _servers, MCPServerTask

        mock_registry = ToolRegistry()
        mock_session = MagicMock()

        async def fake_connect(name, config):
            server = MCPServerTask(name)
            server.session = mock_session
            server._tools = []
            return server

        with patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \
             patch("tools.registry.registry", mock_registry):
            asyncio.run(
                _discover_and_register_server("chk", {"command": "test"})
            )

        entry = mock_registry._tools.get("mcp_chk_list_resources")
        assert entry is not None
        # Server is connected, check_fn should return True
        assert entry.check_fn() is True

        # Disconnect the server
        _servers["chk"].session = None
        assert entry.check_fn() is False

        _servers.pop("chk", None)


# ===========================================================================
# SamplingHandler tests
# ===========================================================================

import math
import time

class _CompatType:
    # Attribute-bag stand-in for mcp.types classes that may be missing on
    # older SDK versions; keeps the import fallbacks below importable.
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)
# The installed MCP SDK may predate the sampling types; fall back to
# _CompatType so this module stays importable and the isinstance() checks
# below degrade gracefully instead of crashing at collection time.
try:
    from mcp.types import (
        CreateMessageResult,
        ErrorData,
        SamplingCapability,
        TextContent,
    )
except ImportError:
    CreateMessageResult = _CompatType
    ErrorData = _CompatType
    SamplingCapability = _CompatType
    TextContent = _CompatType

# The tool-use sampling types come from tools.mcp_tool, which carries its
# own SDK-version fallbacks.  NOTE: earlier per-name
# ``try: from mcp.types import ...`` blocks for these three names were dead
# code -- their results were unconditionally overwritten by this import --
# so they have been removed.
from tools.mcp_tool import (
    CreateMessageResultWithTools,
    SamplingHandler,
    SamplingToolsCapability,
    ToolUseContent,
    _safe_numeric,
)


# ---------------------------------------------------------------------------
# Helpers for sampling tests
# ---------------------------------------------------------------------------

def _make_sampling_params(
    messages=None,
    max_tokens=100,
    system_prompt=None,
    model_preferences=None,
    temperature=None,
    stop_sequences=None,
    tools=None,
    tool_choice=None,
):
    """Create a fake CreateMessageRequestParams using SimpleNamespace.

    Each message must have a ``content_as_list`` attribute that mirrors
    the SDK helper so that ``_convert_messages`` works correctly.
    """
    if messages is None:
        content = SimpleNamespace(text="Hello")
        msg = SimpleNamespace(role="user", content=content, content_as_list=[content])
        messages = [msg]

    params = SimpleNamespace(
        messages=messages,
        maxTokens=max_tokens,
        modelPreferences=model_preferences,
        temperature=temperature,
        stopSequences=stop_sequences,
        tools=tools,
        toolChoice=tool_choice,
    )
    # systemPrompt is only attached when a prompt was actually given, so
    # tests can also exercise the attribute-missing path.
    if system_prompt is not None:
        params.systemPrompt = system_prompt
    return params


def _make_llm_response(
    content="LLM response",
    model="test-model",
    finish_reason="stop",
    tool_calls=None,
):
    """Create a fake OpenAI chat completion response (text).

    Mirrors the attribute layout the handler reads: ``choices[0].message``,
    ``choices[0].finish_reason``, ``model`` and ``usage.total_tokens``.
    """
    message = SimpleNamespace(content=content, tool_calls=tool_calls)
    choice = SimpleNamespace(
        finish_reason=finish_reason,
        message=message,
    )
    usage = SimpleNamespace(total_tokens=42)
    return SimpleNamespace(choices=[choice], model=model, usage=usage)


def _make_llm_tool_response(tool_calls_data=None, model="test-model"):
    """Create a fake response with tool_calls.

    ``tool_calls_data``: list of (id, name, arguments_json) tuples.
    """
    if tool_calls_data is None:
        tool_calls_data = [("call_1", "get_weather", '{"city": "London"}')]

    tc_list = [
        SimpleNamespace(
            id=tc_id,
            function=SimpleNamespace(name=name, arguments=args),
        )
        for tc_id, name, args in tool_calls_data
    ]
    return _make_llm_response(
        content=None,
        model=model,
        finish_reason="tool_calls",
        tool_calls=tc_list,
    )
_safe_numeric helper 2186 # --------------------------------------------------------------------------- 2187 2188 class TestSafeNumeric: 2189 def test_int_passthrough(self): 2190 assert _safe_numeric(10, 5, int) == 10 2191 2192 def test_string_coercion(self): 2193 assert _safe_numeric("20", 5, int) == 20 2194 2195 def test_none_returns_default(self): 2196 assert _safe_numeric(None, 7, int) == 7 2197 2198 def test_inf_returns_default(self): 2199 assert _safe_numeric(float("inf"), 3.0, float) == 3.0 2200 2201 def test_nan_returns_default(self): 2202 assert _safe_numeric(float("nan"), 4.0, float) == 4.0 2203 2204 def test_below_minimum_clamps(self): 2205 assert _safe_numeric(-5, 10, int, minimum=1) == 1 2206 2207 def test_minimum_zero_allowed(self): 2208 assert _safe_numeric(0, 10, int, minimum=0) == 0 2209 2210 def test_non_numeric_string_returns_default(self): 2211 assert _safe_numeric("abc", 42, int) == 42 2212 2213 def test_float_coercion(self): 2214 assert _safe_numeric("3.5", 1.0, float) == 3.5 2215 2216 2217 # --------------------------------------------------------------------------- 2218 # 2. 
SamplingHandler initialization and config parsing 2219 # --------------------------------------------------------------------------- 2220 2221 class TestSamplingHandlerInit: 2222 def test_defaults(self): 2223 h = SamplingHandler("srv", {}) 2224 assert h.server_name == "srv" 2225 assert h.max_rpm == 10 2226 assert h.timeout == 30 2227 assert h.max_tokens_cap == 4096 2228 assert h.max_tool_rounds == 5 2229 assert h.model_override is None 2230 assert h.allowed_models == [] 2231 assert h.metrics == {"requests": 0, "errors": 0, "tokens_used": 0, "tool_use_count": 0} 2232 2233 def test_custom_config(self): 2234 cfg = { 2235 "max_rpm": 20, 2236 "timeout": 60, 2237 "max_tokens_cap": 2048, 2238 "max_tool_rounds": 3, 2239 "model": "gpt-4o", 2240 "allowed_models": ["gpt-4o", "gpt-3.5-turbo"], 2241 "log_level": "debug", 2242 } 2243 h = SamplingHandler("custom", cfg) 2244 assert h.max_rpm == 20 2245 assert h.timeout == 60.0 2246 assert h.max_tokens_cap == 2048 2247 assert h.max_tool_rounds == 3 2248 assert h.model_override == "gpt-4o" 2249 assert h.allowed_models == ["gpt-4o", "gpt-3.5-turbo"] 2250 2251 def test_string_numeric_config_values(self): 2252 """YAML sometimes delivers numeric values as strings.""" 2253 cfg = {"max_rpm": "15", "timeout": "45.5", "max_tokens_cap": "1024"} 2254 h = SamplingHandler("s", cfg) 2255 assert h.max_rpm == 15 2256 assert h.timeout == 45.5 2257 assert h.max_tokens_cap == 1024 2258 2259 2260 # --------------------------------------------------------------------------- 2261 # 3. 
Rate limiting 2262 # --------------------------------------------------------------------------- 2263 2264 class TestRateLimit: 2265 def setup_method(self): 2266 self.handler = SamplingHandler("rl", {"max_rpm": 3}) 2267 2268 def test_allows_under_limit(self): 2269 assert self.handler._check_rate_limit() is True 2270 assert self.handler._check_rate_limit() is True 2271 assert self.handler._check_rate_limit() is True 2272 2273 def test_rejects_over_limit(self): 2274 for _ in range(3): 2275 self.handler._check_rate_limit() 2276 assert self.handler._check_rate_limit() is False 2277 2278 def test_window_expiry(self): 2279 """Old timestamps should be purged from the sliding window.""" 2280 for _ in range(3): 2281 self.handler._check_rate_limit() 2282 # Simulate timestamps from 61 seconds ago 2283 self.handler._rate_timestamps[:] = [time.time() - 61] * 3 2284 assert self.handler._check_rate_limit() is True 2285 2286 2287 # --------------------------------------------------------------------------- 2288 # 4. 
Model resolution 2289 # --------------------------------------------------------------------------- 2290 2291 class TestResolveModel: 2292 def setup_method(self): 2293 self.handler = SamplingHandler("mr", {}) 2294 2295 def test_no_preference_no_override(self): 2296 assert self.handler._resolve_model(None) is None 2297 2298 def test_config_override_wins(self): 2299 self.handler.model_override = "override-model" 2300 prefs = SimpleNamespace(hints=[SimpleNamespace(name="hint-model")]) 2301 assert self.handler._resolve_model(prefs) == "override-model" 2302 2303 def test_hint_used_when_no_override(self): 2304 prefs = SimpleNamespace(hints=[SimpleNamespace(name="hint-model")]) 2305 assert self.handler._resolve_model(prefs) == "hint-model" 2306 2307 def test_empty_hints(self): 2308 prefs = SimpleNamespace(hints=[]) 2309 assert self.handler._resolve_model(prefs) is None 2310 2311 def test_hint_without_name(self): 2312 prefs = SimpleNamespace(hints=[SimpleNamespace(name=None)]) 2313 assert self.handler._resolve_model(prefs) is None 2314 2315 2316 # --------------------------------------------------------------------------- 2317 # 5. 
# ---------------------------------------------------------------------------
# 5. Message conversion
# ---------------------------------------------------------------------------

class TestConvertMessages:
    """MCP sampling messages -> OpenAI chat-completions message dicts."""

    def setup_method(self):
        self.handler = SamplingHandler("mc", {})

    def test_single_text_message(self):
        # Plain text collapses to a simple {"role", "content"} dict.
        content = SimpleNamespace(text="Hello world")
        msg = SimpleNamespace(role="user", content=content, content_as_list=[content])
        params = _make_sampling_params(messages=[msg])
        result = self.handler._convert_messages(params)
        assert len(result) == 1
        assert result[0] == {"role": "user", "content": "Hello world"}

    def test_image_message(self):
        # Mixed text + image becomes a multi-part content list with an
        # image_url data URI built from the base64 payload and mimeType.
        text_block = SimpleNamespace(text="Look at this")
        img_block = SimpleNamespace(data="abc123", mimeType="image/png")
        msg = SimpleNamespace(
            role="user",
            content=[text_block, img_block],
            content_as_list=[text_block, img_block],
        )
        params = _make_sampling_params(messages=[msg])
        result = self.handler._convert_messages(params)
        assert len(result) == 1
        parts = result[0]["content"]
        assert len(parts) == 2
        assert parts[0] == {"type": "text", "text": "Look at this"}
        assert parts[1]["type"] == "image_url"
        assert "data:image/png;base64,abc123" in parts[1]["image_url"]["url"]

    def test_tool_result_message(self):
        # A tool-result block (has toolUseId) maps to an OpenAI "tool" role
        # message carrying the matching tool_call_id.
        inner = SimpleNamespace(text="42 degrees")
        tr_block = SimpleNamespace(toolUseId="call_1", content=[inner])
        msg = SimpleNamespace(
            role="user",
            content=[tr_block],
            content_as_list=[tr_block],
        )
        params = _make_sampling_params(messages=[msg])
        result = self.handler._convert_messages(params)
        assert len(result) == 1
        assert result[0]["role"] == "tool"
        assert result[0]["tool_call_id"] == "call_1"
        assert result[0]["content"] == "42 degrees"

    def test_tool_use_message(self):
        # An assistant tool-use block becomes an OpenAI tool_calls entry
        # whose arguments are JSON-serialised from the block's input dict.
        tu_block = SimpleNamespace(
            id="call_2", name="get_weather", input={"city": "London"}
        )
        msg = SimpleNamespace(
            role="assistant",
            content=[tu_block],
            content_as_list=[tu_block],
        )
        params = _make_sampling_params(messages=[msg])
        result = self.handler._convert_messages(params)
        assert len(result) == 1
        assert result[0]["role"] == "assistant"
        assert len(result[0]["tool_calls"]) == 1
        assert result[0]["tool_calls"][0]["function"]["name"] == "get_weather"
        assert json.loads(result[0]["tool_calls"][0]["function"]["arguments"]) == {"city": "London"}

    def test_mixed_text_and_tool_use(self):
        """Assistant message with both text and tool_calls."""
        text_block = SimpleNamespace(text="Let me check the weather")
        tu_block = SimpleNamespace(
            id="call_3", name="get_weather", input={"city": "Paris"}
        )
        msg = SimpleNamespace(
            role="assistant",
            content=[text_block, tu_block],
            content_as_list=[text_block, tu_block],
        )
        params = _make_sampling_params(messages=[msg])
        result = self.handler._convert_messages(params)
        assert len(result) == 1
        assert result[0]["content"] == "Let me check the weather"
        assert len(result[0]["tool_calls"]) == 1

    def test_fallback_without_content_as_list(self):
        """When content_as_list is absent, falls back to content."""
        content = SimpleNamespace(text="Fallback text")
        msg = SimpleNamespace(role="user", content=content)
        params = _make_sampling_params(messages=[msg])
        result = self.handler._convert_messages(params)
        assert len(result) == 1
        assert result[0]["content"] == "Fallback text"
# ---------------------------------------------------------------------------
# 6. Text-only sampling callback (full flow)
# ---------------------------------------------------------------------------

class TestSamplingCallbackText:
    """End-to-end callback flow for plain text completions."""

    def setup_method(self):
        self.handler = SamplingHandler("txt", {})

    def test_text_response(self):
        """Full flow: text response returns CreateMessageResult."""
        fake_client = MagicMock()
        fake_client.chat.completions.create.return_value = _make_llm_response(
            content="Hello from LLM"
        )

        with patch(
            "agent.auxiliary_client.call_llm",
            return_value=fake_client.chat.completions.create.return_value,
        ):
            params = _make_sampling_params()
            # The handler itself is the sampling callback (first arg is the
            # session/context, unused here).
            result = asyncio.run(self.handler(None, params))

        assert isinstance(result, CreateMessageResult)
        assert isinstance(result.content, TextContent)
        assert result.content.text == "Hello from LLM"
        assert result.model == "test-model"
        assert result.role == "assistant"
        assert result.stopReason == "endTurn"

    def test_system_prompt_prepended(self):
        """System prompt is inserted as the first message."""
        fake_client = MagicMock()
        fake_client.chat.completions.create.return_value = _make_llm_response()

        with patch(
            "agent.auxiliary_client.call_llm",
            return_value=fake_client.chat.completions.create.return_value,
        ) as mock_call:
            params = _make_sampling_params(system_prompt="Be helpful")
            asyncio.run(self.handler(None, params))

        # Inspect the outgoing LLM call, not the result.
        call_args = mock_call.call_args
        messages = call_args.kwargs["messages"]
        assert messages[0] == {"role": "system", "content": "Be helpful"}

    def test_server_tools_with_object_schema_are_normalized(self):
        """Server-provided tools should gain empty properties for object schemas."""
        fake_client = MagicMock()
        fake_client.chat.completions.create.return_value = _make_llm_response()
        # A bare object schema with no "properties" key, as some servers send.
        server_tool = SimpleNamespace(
            name="ask",
            description="Ask Crawl4AI",
            inputSchema={"type": "object"},
        )

        with patch(
            "agent.auxiliary_client.call_llm",
            return_value=fake_client.chat.completions.create.return_value,
        ) as mock_call:
            params = _make_sampling_params(tools=[server_tool])
            asyncio.run(self.handler(None, params))

        tools = mock_call.call_args.kwargs["tools"]
        assert tools == [{
            "type": "function",
            "function": {
                "name": "ask",
                "description": "Ask Crawl4AI",
                "parameters": {"type": "object", "properties": {}},
            },
        }]

    def test_length_stop_reason(self):
        """finish_reason='length' maps to stopReason='maxTokens'."""
        fake_client = MagicMock()
        fake_client.chat.completions.create.return_value = _make_llm_response(
            finish_reason="length"
        )

        with patch(
            "agent.auxiliary_client.call_llm",
            return_value=fake_client.chat.completions.create.return_value,
        ):
            params = _make_sampling_params()
            result = asyncio.run(self.handler(None, params))

        assert isinstance(result, CreateMessageResult)
        assert result.stopReason == "maxTokens"
Tool use sampling callback 2500 # --------------------------------------------------------------------------- 2501 2502 class TestSamplingCallbackToolUse: 2503 def setup_method(self): 2504 self.handler = SamplingHandler("tu", {}) 2505 2506 def test_tool_use_response(self): 2507 """LLM tool_calls response returns CreateMessageResultWithTools.""" 2508 fake_client = MagicMock() 2509 fake_client.chat.completions.create.return_value = _make_llm_tool_response() 2510 2511 with patch( 2512 "agent.auxiliary_client.call_llm", 2513 return_value=fake_client.chat.completions.create.return_value, 2514 ): 2515 params = _make_sampling_params() 2516 result = asyncio.run(self.handler(None, params)) 2517 2518 assert isinstance(result, CreateMessageResultWithTools) 2519 assert result.stopReason == "toolUse" 2520 assert result.model == "test-model" 2521 assert len(result.content) == 1 2522 tc = result.content[0] 2523 assert isinstance(tc, ToolUseContent) 2524 assert tc.name == "get_weather" 2525 assert tc.id == "call_1" 2526 assert tc.input == {"city": "London"} 2527 2528 def test_multiple_tool_calls(self): 2529 """Multiple tool_calls in a single response.""" 2530 fake_client = MagicMock() 2531 fake_client.chat.completions.create.return_value = _make_llm_tool_response( 2532 tool_calls_data=[ 2533 ("call_a", "func_a", '{"x": 1}'), 2534 ("call_b", "func_b", '{"y": 2}'), 2535 ] 2536 ) 2537 2538 with patch( 2539 "agent.auxiliary_client.call_llm", 2540 return_value=fake_client.chat.completions.create.return_value, 2541 ): 2542 result = asyncio.run(self.handler(None, _make_sampling_params())) 2543 2544 assert isinstance(result, CreateMessageResultWithTools) 2545 assert len(result.content) == 2 2546 assert result.content[0].name == "func_a" 2547 assert result.content[1].name == "func_b" 2548 2549 2550 # --------------------------------------------------------------------------- 2551 # 8. 
# ---------------------------------------------------------------------------
# 8. Tool loop governance
# ---------------------------------------------------------------------------

class TestToolLoopGovernance:
    """The handler caps how many consecutive tool rounds a server may drive."""

    def test_max_tool_rounds_enforcement(self):
        """After max_tool_rounds consecutive tool responses, an error is returned."""
        handler = SamplingHandler("tl", {"max_tool_rounds": 2})
        fake_client = MagicMock()
        fake_client.chat.completions.create.return_value = _make_llm_tool_response()

        with patch(
            "agent.auxiliary_client.call_llm",
            return_value=fake_client.chat.completions.create.return_value,
        ):
            params = _make_sampling_params()
            # Round 1, 2: allowed
            r1 = asyncio.run(handler(None, params))
            assert isinstance(r1, CreateMessageResultWithTools)
            r2 = asyncio.run(handler(None, params))
            assert isinstance(r2, CreateMessageResultWithTools)
            # Round 3: exceeds limit
            r3 = asyncio.run(handler(None, params))
            assert isinstance(r3, ErrorData)
            assert "Tool loop limit exceeded" in r3.message

    def test_text_response_resets_counter(self):
        """A text response resets the tool loop counter."""
        handler = SamplingHandler("tl2", {"max_tool_rounds": 1})

        # Use a list to hold the current response, so the side_effect can
        # pick up changes between calls.
        responses = [_make_llm_tool_response()]

        with patch(
            "agent.auxiliary_client.call_llm",
            side_effect=lambda **kw: responses[0],
        ):
            # Tool response (round 1 of 1 allowed)
            r1 = asyncio.run(handler(None, _make_sampling_params()))
            assert isinstance(r1, CreateMessageResultWithTools)

            # Text response resets counter
            responses[0] = _make_llm_response()
            r2 = asyncio.run(handler(None, _make_sampling_params()))
            assert isinstance(r2, CreateMessageResult)

            # Tool response again (should succeed since counter was reset)
            responses[0] = _make_llm_tool_response()
            r3 = asyncio.run(handler(None, _make_sampling_params()))
            assert isinstance(r3, CreateMessageResultWithTools)

    def test_max_tool_rounds_zero_disables(self):
        """max_tool_rounds=0 means tool loops are disabled entirely."""
        handler = SamplingHandler("tl3", {"max_tool_rounds": 0})
        fake_client = MagicMock()
        fake_client.chat.completions.create.return_value = _make_llm_tool_response()

        with patch(
            "agent.auxiliary_client.call_llm",
            return_value=fake_client.chat.completions.create.return_value,
        ):
            result = asyncio.run(handler(None, _make_sampling_params()))
            assert isinstance(result, ErrorData)
            assert "Tool loops disabled" in result.message
# ---------------------------------------------------------------------------
# 9. Error paths: rate limit, timeout, no provider
# ---------------------------------------------------------------------------

class TestSamplingErrors:
    """Failure modes must surface as ErrorData and bump the error metric."""

    def test_rate_limit_error(self):
        """With max_rpm=1 the second request is refused with ErrorData."""
        handler = SamplingHandler("rle", {"max_rpm": 1})
        response = _make_llm_response()

        with patch("agent.auxiliary_client.call_llm", return_value=response):
            # First call succeeds
            r1 = asyncio.run(handler(None, _make_sampling_params()))
            assert isinstance(r1, CreateMessageResult)
            # Second call is rate limited
            r2 = asyncio.run(handler(None, _make_sampling_params()))
            assert isinstance(r2, ErrorData)
            assert "rate limit" in r2.message.lower()
            assert handler.metrics["errors"] == 1

    def test_timeout_error(self):
        """An LLM call that blocks past `timeout` yields a timeout ErrorData."""
        handler = SamplingHandler("to", {"timeout": 0.05})

        def slow_call(**kwargs):
            # Block for 0.5s -- 10x the handler's 0.05s timeout, so the
            # timeout reliably fires first.  Kept short because the blocked
            # thread is NOT cancelled by the timeout; if the handler runs
            # the call on asyncio's default executor (presumably -- TODO
            # confirm), asyncio.run() joins those threads at shutdown, so a
            # long wait here would stall the test for its full duration.
            # (``threading`` is already imported at module scope; the
            # previous local re-import was redundant.)
            evt = threading.Event()
            evt.wait(0.5)
            return _make_llm_response()

        with patch("agent.auxiliary_client.call_llm", side_effect=slow_call):
            result = asyncio.run(handler(None, _make_sampling_params()))
        assert isinstance(result, ErrorData)
        assert "timed out" in result.message.lower()
        assert handler.metrics["errors"] == 1

    def test_no_provider_error(self):
        """A provider-level RuntimeError is converted into ErrorData."""
        handler = SamplingHandler("np", {})

        with patch(
            "agent.auxiliary_client.call_llm",
            side_effect=RuntimeError("No LLM provider configured"),
        ):
            result = asyncio.run(handler(None, _make_sampling_params()))
        assert isinstance(result, ErrorData)
        assert handler.metrics["errors"] == 1

    def test_empty_choices_returns_error(self):
        """LLM returning choices=[] is handled gracefully, not IndexError."""
        handler = SamplingHandler("ec", {})
        response = SimpleNamespace(
            choices=[],
            model="test-model",
            usage=SimpleNamespace(total_tokens=0),
        )

        with patch("agent.auxiliary_client.call_llm", return_value=response):
            result = asyncio.run(handler(None, _make_sampling_params()))

        assert isinstance(result, ErrorData)
        assert "empty response" in result.message.lower()
        assert handler.metrics["errors"] == 1

    def test_none_choices_returns_error(self):
        """LLM returning choices=None is handled gracefully, not TypeError."""
        handler = SamplingHandler("nc", {})
        response = SimpleNamespace(
            choices=None,
            model="test-model",
            usage=SimpleNamespace(total_tokens=0),
        )

        with patch("agent.auxiliary_client.call_llm", return_value=response):
            result = asyncio.run(handler(None, _make_sampling_params()))

        assert isinstance(result, ErrorData)
        assert "empty response" in result.message.lower()
        assert handler.metrics["errors"] == 1

    def test_missing_choices_attr_returns_error(self):
        """LLM response without choices attribute is handled gracefully."""
        handler = SamplingHandler("mc", {})
        response = SimpleNamespace(
            model="test-model",
            usage=SimpleNamespace(total_tokens=0),
        )

        with patch("agent.auxiliary_client.call_llm", return_value=response):
            result = asyncio.run(handler(None, _make_sampling_params()))

        assert isinstance(result, ErrorData)
        assert "empty response" in result.message.lower()
        assert handler.metrics["errors"] == 1
--------------------------------------------------------------------------- 2730 # 10. Model whitelist 2731 # --------------------------------------------------------------------------- 2732 2733 class TestModelWhitelist: 2734 def test_allowed_model_passes(self): 2735 handler = SamplingHandler("wl", {"allowed_models": ["gpt-4o", "test-model"]}) 2736 fake_client = MagicMock() 2737 fake_client.chat.completions.create.return_value = _make_llm_response() 2738 2739 with patch( 2740 "agent.auxiliary_client.call_llm", 2741 return_value=fake_client.chat.completions.create.return_value, 2742 ): 2743 result = asyncio.run(handler(None, _make_sampling_params())) 2744 assert isinstance(result, CreateMessageResult) 2745 2746 def test_disallowed_model_rejected(self): 2747 handler = SamplingHandler("wl2", {"allowed_models": ["gpt-4o"], "model": "test-model"}) 2748 fake_client = MagicMock() 2749 2750 with patch( 2751 "agent.auxiliary_client.call_llm", 2752 return_value=fake_client.chat.completions.create.return_value, 2753 ): 2754 result = asyncio.run(handler(None, _make_sampling_params())) 2755 assert isinstance(result, ErrorData) 2756 assert "not allowed" in result.message 2757 assert handler.metrics["errors"] == 1 2758 2759 def test_empty_whitelist_allows_all(self): 2760 handler = SamplingHandler("wl3", {"allowed_models": []}) 2761 fake_client = MagicMock() 2762 fake_client.chat.completions.create.return_value = _make_llm_response() 2763 2764 with patch( 2765 "agent.auxiliary_client.call_llm", 2766 return_value=fake_client.chat.completions.create.return_value, 2767 ): 2768 result = asyncio.run(handler(None, _make_sampling_params())) 2769 assert isinstance(result, CreateMessageResult) 2770 2771 2772 # --------------------------------------------------------------------------- 2773 # 11. 
Malformed tool_call arguments 2774 # --------------------------------------------------------------------------- 2775 2776 class TestMalformedToolCallArgs: 2777 def test_invalid_json_wrapped_as_raw(self): 2778 """Malformed JSON arguments get wrapped in {"_raw": ...}.""" 2779 handler = SamplingHandler("mf", {}) 2780 fake_client = MagicMock() 2781 fake_client.chat.completions.create.return_value = _make_llm_tool_response( 2782 tool_calls_data=[("call_x", "some_tool", "not valid json {{{")] 2783 ) 2784 2785 with patch( 2786 "agent.auxiliary_client.call_llm", 2787 return_value=fake_client.chat.completions.create.return_value, 2788 ): 2789 result = asyncio.run(handler(None, _make_sampling_params())) 2790 2791 assert isinstance(result, CreateMessageResultWithTools) 2792 tc = result.content[0] 2793 assert isinstance(tc, ToolUseContent) 2794 assert tc.input == {"_raw": "not valid json {{{"} 2795 2796 def test_dict_args_pass_through(self): 2797 """When arguments are already a dict, they pass through directly.""" 2798 handler = SamplingHandler("mf2", {}) 2799 2800 # Build a tool call where arguments is already a dict 2801 tc_obj = SimpleNamespace( 2802 id="call_d", 2803 function=SimpleNamespace(name="do_stuff", arguments={"key": "val"}), 2804 ) 2805 message = SimpleNamespace(content=None, tool_calls=[tc_obj]) 2806 choice = SimpleNamespace(finish_reason="tool_calls", message=message) 2807 usage = SimpleNamespace(total_tokens=10) 2808 response = SimpleNamespace(choices=[choice], model="m", usage=usage) 2809 2810 fake_client = MagicMock() 2811 fake_client.chat.completions.create.return_value = response 2812 2813 with patch( 2814 "agent.auxiliary_client.call_llm", 2815 return_value=fake_client.chat.completions.create.return_value, 2816 ): 2817 result = asyncio.run(handler(None, _make_sampling_params())) 2818 2819 assert isinstance(result, CreateMessageResultWithTools) 2820 assert result.content[0].input == {"key": "val"} 2821 2822 2823 # 
--------------------------------------------------------------------------- 2824 # 12. Metrics tracking 2825 # --------------------------------------------------------------------------- 2826 2827 class TestMetricsTracking: 2828 def test_request_and_token_metrics(self): 2829 handler = SamplingHandler("met", {}) 2830 fake_client = MagicMock() 2831 fake_client.chat.completions.create.return_value = _make_llm_response() 2832 2833 with patch( 2834 "agent.auxiliary_client.call_llm", 2835 return_value=fake_client.chat.completions.create.return_value, 2836 ): 2837 asyncio.run(handler(None, _make_sampling_params())) 2838 2839 assert handler.metrics["requests"] == 1 2840 assert handler.metrics["tokens_used"] == 42 2841 assert handler.metrics["errors"] == 0 2842 2843 def test_tool_use_count_metric(self): 2844 handler = SamplingHandler("met2", {}) 2845 fake_client = MagicMock() 2846 fake_client.chat.completions.create.return_value = _make_llm_tool_response() 2847 2848 with patch( 2849 "agent.auxiliary_client.call_llm", 2850 return_value=fake_client.chat.completions.create.return_value, 2851 ): 2852 asyncio.run(handler(None, _make_sampling_params())) 2853 2854 assert handler.metrics["tool_use_count"] == 1 2855 assert handler.metrics["requests"] == 1 2856 2857 def test_error_metric_incremented(self): 2858 handler = SamplingHandler("met3", {}) 2859 2860 with patch( 2861 "agent.auxiliary_client.call_llm", 2862 side_effect=RuntimeError("No LLM provider configured"), 2863 ): 2864 asyncio.run(handler(None, _make_sampling_params())) 2865 2866 assert handler.metrics["errors"] == 1 2867 assert handler.metrics["requests"] == 0 2868 2869 2870 # --------------------------------------------------------------------------- 2871 # 13. 
# session_kwargs()
# ---------------------------------------------------------------------------

class TestSessionKwargs:
    """session_kwargs() exposes the handler as an MCP sampling callback."""

    def test_returns_correct_keys(self):
        handler = SamplingHandler("sk", {})
        kwargs = handler.session_kwargs()
        assert "sampling_callback" in kwargs
        assert "sampling_capabilities" in kwargs
        assert kwargs["sampling_callback"] is handler

    def test_sampling_capabilities_type(self):
        handler = SamplingHandler("sk2", {})
        kwargs = handler.session_kwargs()
        cap = kwargs["sampling_capabilities"]
        assert isinstance(cap, SamplingCapability)
        assert isinstance(cap.tools, SamplingToolsCapability)


# ---------------------------------------------------------------------------
# 14. MCPServerTask integration
# ---------------------------------------------------------------------------

class TestMCPServerTaskSamplingIntegration:
    def test_sampling_handler_created_when_enabled(self):
        """MCPServerTask.run() creates a SamplingHandler when sampling is enabled."""
        from tools.mcp_tool import MCPServerTask, _MCP_SAMPLING_TYPES

        server = MCPServerTask("int_test")
        config = {
            "command": "fake",
            "sampling": {"enabled": True, "max_rpm": 5},
        }
        # We only need to test the setup logic, not the actual connection.
        # Calling run() would attempt a real connection, so we test the
        # sampling setup portion directly.
        server._config = config
        sampling_config = config.get("sampling", {})
        if sampling_config.get("enabled", True) and _MCP_SAMPLING_TYPES:
            server._sampling = SamplingHandler(server.name, sampling_config)
        else:
            server._sampling = None

        assert server._sampling is not None
        assert isinstance(server._sampling, SamplingHandler)
        assert server._sampling.server_name == "int_test"
        assert server._sampling.max_rpm == 5

    def test_sampling_handler_none_when_disabled(self):
        """MCPServerTask._sampling is None when sampling is disabled."""
        from tools.mcp_tool import MCPServerTask, _MCP_SAMPLING_TYPES

        server = MCPServerTask("int_test2")
        config = {
            "command": "fake",
            "sampling": {"enabled": False},
        }
        server._config = config
        sampling_config = config.get("sampling", {})
        if sampling_config.get("enabled", True) and _MCP_SAMPLING_TYPES:
            server._sampling = SamplingHandler(server.name, sampling_config)
        else:
            server._sampling = None

        assert server._sampling is None

    def test_session_kwargs_used_in_stdio(self):
        """When sampling is set, session_kwargs() are passed to ClientSession."""
        from tools.mcp_tool import MCPServerTask

        server = MCPServerTask("sk_test")
        server._sampling = SamplingHandler("sk_test", {"max_rpm": 7})
        kwargs = server._sampling.session_kwargs()
        assert "sampling_callback" in kwargs
        assert "sampling_capabilities" in kwargs


# ---------------------------------------------------------------------------
# Discovery failed_count tracking
# ---------------------------------------------------------------------------

class TestDiscoveryFailedCount:
    """Verify discover_mcp_tools() correctly tracks failed server connections."""

    def test_failed_server_increments_failed_count(self):
        """When _discover_and_register_server raises, failed_count increments."""
        from tools.mcp_tool import discover_mcp_tools, _servers, _ensure_mcp_loop

        fake_config = {
            "good_server": {"command": "npx", "args": ["good"]},
            "bad_server": {"command": "npx", "args": ["bad"]},
        }

        async def fake_register(name, cfg):
            if name == "bad_server":
                raise ConnectionError("Connection refused")
            # Simulate successful registration
            from tools.mcp_tool import MCPServerTask
            server = MCPServerTask(name)
            server.session = MagicMock()
            server._tools = [_make_mcp_tool("tool_a")]
            _servers[name] = server
            return [f"mcp_{name}_tool_a"]

        # try/finally so the module-global _servers is cleaned up even when
        # an assertion fails (same pattern as TestMCPSelectiveToolLoading);
        # otherwise a failure here leaks state into unrelated tests.
        try:
            with patch("tools.mcp_tool._load_mcp_config", return_value=fake_config), \
                 patch("tools.mcp_tool._discover_and_register_server", side_effect=fake_register), \
                 patch("tools.mcp_tool._MCP_AVAILABLE", True), \
                 patch("tools.mcp_tool._existing_tool_names", return_value=["mcp_good_server_tool_a"]):
                _ensure_mcp_loop()

                # Capture the logger to verify failed_count in summary
                with patch("tools.mcp_tool.logger") as mock_logger:
                    discover_mcp_tools()

                # Find the summary info call
                info_calls = [
                    str(call)
                    for call in mock_logger.info.call_args_list
                    if "failed" in str(call).lower() or "MCP:" in str(call)
                ]
                # The summary should mention the failure
                assert any("1 failed" in str(c) for c in info_calls), (
                    f"Summary should report 1 failed server, got: {info_calls}"
                )
        finally:
            _servers.pop("good_server", None)
            _servers.pop("bad_server", None)

    def test_all_servers_fail_still_prints_summary(self):
        """When all servers fail, a summary with failure count is still printed."""
        from tools.mcp_tool import discover_mcp_tools, _servers, _ensure_mcp_loop

        fake_config = {
            "srv1": {"command": "npx", "args": ["a"]},
            "srv2": {"command": "npx", "args": ["b"]},
        }

        async def always_fail(name, cfg):
            raise ConnectionError(f"Server {name} refused")

        try:
            with patch("tools.mcp_tool._load_mcp_config", return_value=fake_config), \
                 patch("tools.mcp_tool._discover_and_register_server", side_effect=always_fail), \
                 patch("tools.mcp_tool._MCP_AVAILABLE", True), \
                 patch("tools.mcp_tool._existing_tool_names", return_value=[]):
                _ensure_mcp_loop()

                with patch("tools.mcp_tool.logger") as mock_logger:
                    discover_mcp_tools()

                # Summary must be printed even when all servers fail
                info_calls = [str(call) for call in mock_logger.info.call_args_list]
                assert any("2 failed" in str(c) for c in info_calls), (
                    f"Summary should report 2 failed servers, got: {info_calls}"
                )
        finally:
            # always_fail never registers anything, but guard anyway in case
            # discover_mcp_tools() itself touched _servers.
            _servers.pop("srv1", None)
            _servers.pop("srv2", None)

    def test_ok_servers_excludes_failures(self):
        """ok_servers count correctly excludes failed servers."""
        from tools.mcp_tool import discover_mcp_tools, _servers, _ensure_mcp_loop

        fake_config = {
            "ok1": {"command": "npx", "args": ["ok1"]},
            "ok2": {"command": "npx", "args": ["ok2"]},
            "fail1": {"command": "npx", "args": ["fail"]},
        }

        async def selective_register(name, cfg):
            if name == "fail1":
                raise ConnectionError("Refused")
            from tools.mcp_tool import MCPServerTask
            server = MCPServerTask(name)
            server.session = MagicMock()
            server._tools = [_make_mcp_tool("t")]
            _servers[name] = server
            return [f"mcp_{name}_t"]

        try:
            with patch("tools.mcp_tool._load_mcp_config", return_value=fake_config), \
                 patch("tools.mcp_tool._discover_and_register_server", side_effect=selective_register), \
                 patch("tools.mcp_tool._MCP_AVAILABLE", True), \
                 patch("tools.mcp_tool._existing_tool_names", return_value=["mcp_ok1_t", "mcp_ok2_t"]):
                _ensure_mcp_loop()

                with patch("tools.mcp_tool.logger") as mock_logger:
                    discover_mcp_tools()

                info_calls = [str(call) for call in mock_logger.info.call_args_list]
                # Should say "2 server(s)" not "3 server(s)"
                assert any("2 server" in str(c) for c in info_calls), (
                    f"Summary should report 2 ok servers, got: {info_calls}"
                )
                assert any("1 failed" in str(c) for c in info_calls), (
                    f"Summary should report 1 failed, got: {info_calls}"
                )
        finally:
            _servers.pop("ok1", None)
            _servers.pop("ok2", None)
            _servers.pop("fail1", None)


class TestMCPSelectiveToolLoading:
    """Tests for per-server MCP filtering and utility tool policies."""

    def _make_server(self, name, tool_names, session=None):
        server = _make_mock_server(
            name,
            session=session or SimpleNamespace(),
            tools=[_make_mcp_tool(n, n) for n in tool_names],
        )
        return server

    def _run_discover(self, name, tool_names, config, session=None):
        """Run _discover_and_register_server against a fake server/registry."""
        from tools.registry import ToolRegistry
        from tools.mcp_tool import _discover_and_register_server, _servers

        mock_registry = ToolRegistry()
        server = self._make_server(name, tool_names, session=session)

        async def fake_connect(_name, _config):
            return server

        async def run():
            with patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \
                 patch("tools.registry.registry", mock_registry), \
                 patch("toolsets.create_custom_toolset"):
                return await _discover_and_register_server(name, config)

        try:
            registered = asyncio.run(run())
        finally:
            _servers.pop(name, None)
        return registered, mock_registry

    def test_include_takes_precedence_over_exclude(self):
        config = {
            "url": "https://mcp.example.com",
            "tools": {
                "include": ["create_service"],
                "exclude": ["create_service", "delete_service"],
            },
        }
        registered, _ = self._run_discover(
            "ink",
            ["create_service", "delete_service", "list_services"],
            config,
            session=SimpleNamespace(),
        )
        assert registered == ["mcp_ink_create_service"]

    def test_exclude_filter_registers_all_except_listed_tools(self):
        config = {
            "url": "https://mcp.example.com",
            "tools": {"exclude": ["delete_service"]},
        }
        registered, _ = self._run_discover(
            "ink_exclude",
            ["create_service", "delete_service", "list_services"],
            config,
            session=SimpleNamespace(),
        )
        assert registered == [
            "mcp_ink_exclude_create_service",
            "mcp_ink_exclude_list_services",
        ]

    def test_include_filter_skips_utility_tools_without_capabilities(self):
        config = {
            "url": "https://mcp.example.com",
            "tools": {"include": ["create_service"]},
        }
        registered, mock_registry = self._run_discover(
            "ink_no_caps",
            ["create_service", "delete_service"],
            config,
            session=SimpleNamespace(),
        )
        assert registered == ["mcp_ink_no_caps_create_service"]
        assert set(mock_registry.get_all_tool_names()) == {"mcp_ink_no_caps_create_service"}

    def test_no_filter_registers_all_server_tools_when_no_utilities_supported(self):
        registered, _ = self._run_discover(
            "ink_no_filter",
            ["create_service", "delete_service", "list_services"],
            {"url": "https://mcp.example.com"},
            session=SimpleNamespace(),
        )
        assert registered == [
            "mcp_ink_no_filter_create_service",
            "mcp_ink_no_filter_delete_service",
            "mcp_ink_no_filter_list_services",
        ]

    def test_resources_and_prompts_can_be_disabled_explicitly(self):
        # Session advertises all four utility methods, but config disables
        # both resources and prompts, so only the server tool registers.
        session = SimpleNamespace(
            list_resources=AsyncMock(),
            read_resource=AsyncMock(),
            list_prompts=AsyncMock(),
            get_prompt=AsyncMock(),
        )
        config = {
            "url": "https://mcp.example.com",
            "tools": {
                "resources": False,
                "prompts": False,
            },
        }
        registered, _ = self._run_discover(
            "ink_disabled_utils",
            ["create_service"],
            config,
            session=session,
        )
        assert registered == ["mcp_ink_disabled_utils_create_service"]

    def test_registers_only_utility_tools_supported_by_server_capabilities(self):
        # Session exposes only the resource methods -> resource utilities
        # register, prompt utilities do not.
        session = SimpleNamespace(
            list_resources=AsyncMock(return_value=SimpleNamespace(resources=[])),
            read_resource=AsyncMock(return_value=SimpleNamespace(contents=[])),
        )
        registered, _ = self._run_discover(
            "ink_resources_only",
            ["create_service"],
            {"url": "https://mcp.example.com"},
            session=session,
        )
        assert "mcp_ink_resources_only_create_service" in registered
        assert "mcp_ink_resources_only_list_resources" in registered
        assert "mcp_ink_resources_only_read_resource" in registered
        assert "mcp_ink_resources_only_list_prompts" not in registered
        assert "mcp_ink_resources_only_get_prompt" not in registered

    def test_existing_tool_names_reflect_registered_subset(self):
        from tools.mcp_tool import _existing_tool_names, _servers, _discover_and_register_server
        from tools.registry import ToolRegistry

        mock_registry = ToolRegistry()
        server = self._make_server(
            "ink_existing",
            ["create_service", "delete_service"],
            session=SimpleNamespace(),
        )

        async def fake_connect(_name, _config):
            return server

        async def run():
            with patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \
                 patch.dict("tools.mcp_tool._servers", {}, clear=True), \
                 patch("tools.registry.registry", mock_registry), \
                 patch("toolsets.create_custom_toolset"):
                registered = await _discover_and_register_server(
                    "ink_existing",
                    {"url": "https://mcp.example.com", "tools": {"include": ["create_service"]}},
                )
                return registered, _existing_tool_names()

        try:
            registered, existing = asyncio.run(run())
            assert registered == ["mcp_ink_existing_create_service"]
            assert existing == ["mcp_ink_existing_create_service"]
        finally:
            _servers.pop("ink_existing", None)

    def test_no_toolset_created_when_everything_is_filtered_out(self):
        from tools.registry import ToolRegistry
        from tools.mcp_tool import _discover_and_register_server, _servers

        mock_registry = ToolRegistry()
        server = self._make_server("ink_none", ["create_service"], session=SimpleNamespace())
        mock_create = MagicMock()

        async def fake_connect(_name, _config):
            return server

        async def run():
            with patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \
                 patch("tools.registry.registry", mock_registry), \
                 patch("toolsets.create_custom_toolset", mock_create):
                return await _discover_and_register_server(
                    "ink_none",
                    {
                        "url": "https://mcp.example.com",
                        "tools": {
                            "include": ["missing_tool"],
                            "resources": False,
                            "prompts": False,
                        },
                    },
                )

        try:
            registered = asyncio.run(run())
            assert registered == []
            mock_create.assert_not_called()
            assert mock_registry.get_all_tool_names() == []
        finally:
            _servers.pop("ink_none", None)

    def test_enabled_false_skips_connection_attempt(self):
        from tools.mcp_tool import discover_mcp_tools

        connect_called = []

        async def fake_connect(name, config):
            connect_called.append(name)
            return self._make_server(name, ["create_service"])

        fake_config = {
            "ink": {
                "url": "https://mcp.example.com",
                "enabled": False,
            }
        }
        fake_toolsets = {
            "hermes-cli": {"tools": [], "description": "CLI", "includes": []},
        }

        with patch("tools.mcp_tool._MCP_AVAILABLE", True), \
             patch("tools.mcp_tool._servers", {}), \
             patch("tools.mcp_tool._load_mcp_config", return_value=fake_config), \
             patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \
             patch("toolsets.TOOLSETS", fake_toolsets):
            result = discover_mcp_tools()

        assert connect_called == []
        assert result == []
3299 3300 # --------------------------------------------------------------------------- 3301 # Tool name collision protection 3302 # --------------------------------------------------------------------------- 3303 3304 class TestRegistryCollisionWarning: 3305 """registry.register() warns when a tool name is overwritten by a different toolset.""" 3306 3307 def test_overwrite_different_toolset_logs_warning(self, caplog): 3308 """Overwriting a tool from a different toolset is REJECTED with an error.""" 3309 from tools.registry import ToolRegistry 3310 import logging 3311 3312 reg = ToolRegistry() 3313 schema = {"name": "my_tool", "description": "test", "parameters": {"type": "object", "properties": {}}} 3314 handler = lambda args, **kw: "{}" 3315 3316 reg.register(name="my_tool", toolset="builtin", schema=schema, handler=handler) 3317 3318 with caplog.at_level(logging.ERROR, logger="tools.registry"): 3319 reg.register(name="my_tool", toolset="mcp-ext", schema=schema, handler=handler) 3320 3321 assert any("rejected" in r.message.lower() for r in caplog.records) 3322 assert any("builtin" in r.message and "mcp-ext" in r.message for r in caplog.records) 3323 # The original tool should still be from 'builtin', not overwritten 3324 assert reg.get_toolset_for_tool("my_tool") == "builtin" 3325 3326 def test_overwrite_same_toolset_no_warning(self, caplog): 3327 """Re-registering within the same toolset is silent (e.g. 
reconnect).""" 3328 from tools.registry import ToolRegistry 3329 import logging 3330 3331 reg = ToolRegistry() 3332 schema = {"name": "my_tool", "description": "test", "parameters": {"type": "object", "properties": {}}} 3333 handler = lambda args, **kw: "{}" 3334 3335 reg.register(name="my_tool", toolset="mcp-server", schema=schema, handler=handler) 3336 3337 with caplog.at_level(logging.WARNING, logger="tools.registry"): 3338 reg.register(name="my_tool", toolset="mcp-server", schema=schema, handler=handler) 3339 3340 assert not any("collision" in r.message.lower() for r in caplog.records) 3341 3342 3343 class TestMCPBuiltinCollisionGuard: 3344 """MCP tools that collide with built-in tool names are skipped.""" 3345 3346 def test_mcp_tool_skipped_when_builtin_exists(self): 3347 """An MCP tool whose prefixed name collides with a built-in is skipped.""" 3348 from tools.registry import ToolRegistry 3349 from tools.mcp_tool import _discover_and_register_server, _servers, MCPServerTask 3350 3351 mock_registry = ToolRegistry() 3352 3353 # Pre-register a "built-in" tool with the name that the MCP tool would produce. 
3354 # Server "abc", tool "search" → mcp_abc_search 3355 builtin_schema = { 3356 "name": "mcp_abc_search", 3357 "description": "A hypothetical built-in", 3358 "parameters": {"type": "object", "properties": {}}, 3359 } 3360 mock_registry.register( 3361 name="mcp_abc_search", toolset="web", 3362 schema=builtin_schema, handler=lambda a, **k: "{}", 3363 ) 3364 3365 mock_tools = [_make_mcp_tool("search", "Search the web")] 3366 mock_session = MagicMock() 3367 3368 async def fake_connect(name, config): 3369 server = MCPServerTask(name) 3370 server.session = mock_session 3371 server._tools = mock_tools 3372 return server 3373 3374 with patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \ 3375 patch("tools.registry.registry", mock_registry): 3376 registered = asyncio.run( 3377 _discover_and_register_server("abc", {"command": "test", "args": []}) 3378 ) 3379 3380 # The MCP tool should have been skipped — built-in preserved. 3381 assert "mcp_abc_search" not in registered 3382 assert mock_registry.get_toolset_for_tool("mcp_abc_search") == "web" 3383 3384 _servers.pop("abc", None) 3385 3386 def test_mcp_tool_registered_when_no_builtin_collision(self): 3387 """MCP tools register normally when there's no collision.""" 3388 from tools.registry import ToolRegistry 3389 from tools.mcp_tool import _discover_and_register_server, _servers, MCPServerTask 3390 3391 mock_registry = ToolRegistry() 3392 mock_tools = [_make_mcp_tool("web_search", "Search the web")] 3393 mock_session = MagicMock() 3394 3395 async def fake_connect(name, config): 3396 server = MCPServerTask(name) 3397 server.session = mock_session 3398 server._tools = mock_tools 3399 return server 3400 3401 with patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \ 3402 patch("tools.registry.registry", mock_registry): 3403 registered = asyncio.run( 3404 _discover_and_register_server("minimax", {"command": "test", "args": []}) 3405 ) 3406 3407 assert "mcp_minimax_web_search" in registered 3408 
assert mock_registry.get_toolset_for_tool("mcp_minimax_web_search") == "mcp-minimax" 3409 3410 _servers.pop("minimax", None) 3411 3412 def test_mcp_tool_allowed_when_collision_is_another_mcp(self): 3413 """Collision between two MCP toolsets is allowed (last wins).""" 3414 from tools.registry import ToolRegistry 3415 from tools.mcp_tool import _discover_and_register_server, _servers, MCPServerTask 3416 3417 mock_registry = ToolRegistry() 3418 3419 # Pre-register an MCP tool from a different server. 3420 mcp_schema = { 3421 "name": "mcp_srv_do_thing", 3422 "description": "From another MCP server", 3423 "parameters": {"type": "object", "properties": {}}, 3424 } 3425 mock_registry.register( 3426 name="mcp_srv_do_thing", toolset="mcp-old", 3427 schema=mcp_schema, handler=lambda a, **k: "{}", 3428 ) 3429 3430 mock_tools = [_make_mcp_tool("do_thing", "Do a thing")] 3431 mock_session = MagicMock() 3432 3433 async def fake_connect(name, config): 3434 server = MCPServerTask(name) 3435 server.session = mock_session 3436 server._tools = mock_tools 3437 return server 3438 3439 with patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \ 3440 patch("tools.registry.registry", mock_registry): 3441 registered = asyncio.run( 3442 _discover_and_register_server("srv", {"command": "test", "args": []}) 3443 ) 3444 3445 # MCP-to-MCP collision is allowed — the new server wins. 
3446 assert "mcp_srv_do_thing" in registered 3447 assert mock_registry.get_toolset_for_tool("mcp_srv_do_thing") == "mcp-srv" 3448 3449 _servers.pop("srv", None) 3450 3451 3452 # --------------------------------------------------------------------------- 3453 # sanitize_mcp_name_component 3454 # --------------------------------------------------------------------------- 3455 3456 3457 class TestSanitizeMcpNameComponent: 3458 """Verify sanitize_mcp_name_component handles all edge cases.""" 3459 3460 def test_hyphens_replaced(self): 3461 from tools.mcp_tool import sanitize_mcp_name_component 3462 assert sanitize_mcp_name_component("my-server") == "my_server" 3463 3464 def test_dots_replaced(self): 3465 from tools.mcp_tool import sanitize_mcp_name_component 3466 assert sanitize_mcp_name_component("ai.exa") == "ai_exa" 3467 3468 def test_slashes_replaced(self): 3469 from tools.mcp_tool import sanitize_mcp_name_component 3470 assert sanitize_mcp_name_component("ai.exa/exa") == "ai_exa_exa" 3471 3472 def test_mixed_special_characters(self): 3473 from tools.mcp_tool import sanitize_mcp_name_component 3474 assert sanitize_mcp_name_component("@scope/my-pkg.v2") == "_scope_my_pkg_v2" 3475 3476 def test_alphanumeric_and_underscores_preserved(self): 3477 from tools.mcp_tool import sanitize_mcp_name_component 3478 assert sanitize_mcp_name_component("my_server_123") == "my_server_123" 3479 3480 def test_empty_string(self): 3481 from tools.mcp_tool import sanitize_mcp_name_component 3482 assert sanitize_mcp_name_component("") == "" 3483 3484 def test_none_returns_empty(self): 3485 from tools.mcp_tool import sanitize_mcp_name_component 3486 assert sanitize_mcp_name_component(None) == "" 3487 3488 def test_slash_in_convert_mcp_schema(self): 3489 """Server names with slashes produce valid tool names via _convert_mcp_schema.""" 3490 from tools.mcp_tool import _convert_mcp_schema 3491 3492 mcp_tool = _make_mcp_tool(name="search") 3493 schema = _convert_mcp_schema("ai.exa/exa", 
mcp_tool) 3494 assert schema["name"] == "mcp_ai_exa_exa_search" 3495 # Must match Anthropic's pattern: ^[a-zA-Z0-9_-]{1,128}$ 3496 import re 3497 assert re.match(r"^[a-zA-Z0-9_-]{1,128}$", schema["name"]) 3498 3499 def test_slash_in_build_utility_schemas(self): 3500 """Server names with slashes produce valid utility tool names.""" 3501 from tools.mcp_tool import _build_utility_schemas 3502 3503 schemas = _build_utility_schemas("ai.exa/exa") 3504 for s in schemas: 3505 name = s["schema"]["name"] 3506 assert "/" not in name 3507 assert "." not in name 3508 3509 def test_slash_in_server_alias_resolution(self): 3510 """Server names with slashes resolve through their live MCP alias.""" 3511 from tools.registry import ToolRegistry 3512 from toolsets import resolve_toolset, validate_toolset 3513 3514 reg = ToolRegistry() 3515 reg.register( 3516 name="mcp_ai_exa_exa_search", 3517 toolset="mcp-ai.exa/exa", 3518 schema={"name": "mcp_ai_exa_exa_search", "description": "Search", "parameters": {"type": "object", "properties": {}}}, 3519 handler=lambda *_args, **_kwargs: "{}", 3520 ) 3521 reg.register_toolset_alias("ai.exa/exa", "mcp-ai.exa/exa") 3522 3523 with patch("tools.registry.registry", reg): 3524 assert validate_toolset("ai.exa/exa") is True 3525 assert "mcp_ai_exa_exa_search" in resolve_toolset("ai.exa/exa") 3526 3527 3528 # --------------------------------------------------------------------------- 3529 # register_mcp_servers public API 3530 # --------------------------------------------------------------------------- 3531 3532 3533 class TestRegisterMcpServers: 3534 """Verify the new register_mcp_servers() public API.""" 3535 3536 def test_empty_servers_returns_empty(self): 3537 from tools.mcp_tool import register_mcp_servers 3538 3539 with patch("tools.mcp_tool._MCP_AVAILABLE", True): 3540 result = register_mcp_servers({}) 3541 assert result == [] 3542 3543 def test_mcp_not_available_returns_empty(self): 3544 from tools.mcp_tool import register_mcp_servers 3545 3546 
with patch("tools.mcp_tool._MCP_AVAILABLE", False): 3547 result = register_mcp_servers({"srv": {"command": "test"}}) 3548 assert result == [] 3549 3550 def test_skips_already_connected_servers(self): 3551 from tools.mcp_tool import register_mcp_servers, _servers 3552 3553 mock_server = _make_mock_server("existing") 3554 _servers["existing"] = mock_server 3555 3556 try: 3557 with patch("tools.mcp_tool._MCP_AVAILABLE", True), \ 3558 patch("tools.mcp_tool._existing_tool_names", return_value=["mcp_existing_tool"]): 3559 result = register_mcp_servers({"existing": {"command": "test"}}) 3560 assert result == ["mcp_existing_tool"] 3561 finally: 3562 _servers.pop("existing", None) 3563 3564 def test_skips_disabled_servers(self): 3565 from tools.mcp_tool import register_mcp_servers, _servers 3566 3567 try: 3568 with patch("tools.mcp_tool._MCP_AVAILABLE", True), \ 3569 patch("tools.mcp_tool._existing_tool_names", return_value=[]): 3570 result = register_mcp_servers({"srv": {"command": "test", "enabled": False}}) 3571 assert result == [] 3572 finally: 3573 _servers.pop("srv", None) 3574 3575 def test_connects_new_servers(self): 3576 from tools.mcp_tool import register_mcp_servers, _servers, _ensure_mcp_loop 3577 3578 fake_config = {"my_server": {"command": "npx", "args": ["test"]}} 3579 3580 async def fake_register(name, cfg): 3581 server = _make_mock_server(name) 3582 server._registered_tool_names = ["mcp_my_server_tool1"] 3583 _servers[name] = server 3584 return ["mcp_my_server_tool1"] 3585 3586 with patch("tools.mcp_tool._MCP_AVAILABLE", True), \ 3587 patch("tools.mcp_tool._discover_and_register_server", side_effect=fake_register), \ 3588 patch("tools.mcp_tool._existing_tool_names", return_value=["mcp_my_server_tool1"]): 3589 _ensure_mcp_loop() 3590 result = register_mcp_servers(fake_config) 3591 3592 assert "mcp_my_server_tool1" in result 3593 _servers.pop("my_server", None) 3594 3595 def test_logs_summary_on_success(self): 3596 from tools.mcp_tool import 
register_mcp_servers, _servers, _ensure_mcp_loop 3597 3598 fake_config = {"srv": {"command": "npx", "args": ["test"]}} 3599 3600 async def fake_register(name, cfg): 3601 server = _make_mock_server(name) 3602 server._registered_tool_names = ["mcp_srv_t1", "mcp_srv_t2"] 3603 _servers[name] = server 3604 return ["mcp_srv_t1", "mcp_srv_t2"] 3605 3606 with patch("tools.mcp_tool._MCP_AVAILABLE", True), \ 3607 patch("tools.mcp_tool._discover_and_register_server", side_effect=fake_register), \ 3608 patch("tools.mcp_tool._existing_tool_names", return_value=["mcp_srv_t1", "mcp_srv_t2"]): 3609 _ensure_mcp_loop() 3610 3611 with patch("tools.mcp_tool.logger") as mock_logger: 3612 register_mcp_servers(fake_config) 3613 3614 info_calls = [str(c) for c in mock_logger.info.call_args_list] 3615 assert any("2 tool(s)" in c and "1 server(s)" in c for c in info_calls), ( 3616 f"Summary should report 2 tools from 1 server, got: {info_calls}" 3617 ) 3618 3619 _servers.pop("srv", None)