# test_mcp_tool_401_handling.py
"""Tests for MCP tool-handler auth-failure detection.

When a tool call raises an auth failure (e.g. mcp's ``OAuthFlowError``,
``OAuthNonInteractiveError``, or ``httpx.HTTPStatusError`` with a 401
status), the handler should:
  1. Ask ``MCPOAuthManager.handle_401`` whether recovery is viable.
  2. If yes, trigger ``MCPServerTask._reconnect_event`` and retry once.
  3. If no, return a structured ``needs_reauth`` error so the model stops
     hallucinating manual refresh attempts.

The module-level registries in ``tools.mcp_tool`` are patched through
``monkeypatch`` so any pre-existing entries are restored after each test
rather than being deleted.
"""
import json
from unittest.mock import MagicMock

import pytest


pytest.importorskip("mcp.client.auth.oauth2")

# Name under which the stub MCP server is registered in the handler tests.
_SERVER_NAME = "srv"


def _http_status_error(status_code, message):
    """Build a real ``httpx.HTTPStatusError`` whose response has *status_code*."""
    import httpx

    request = httpx.Request("POST", "https://mcp.example.invalid/")
    response = httpx.Response(status_code, request=request)
    return httpx.HTTPStatusError(message, request=request, response=response)


def _clear_key_for_test(monkeypatch, mapping, key):
    """Remove *key* from *mapping* for the duration of the current test.

    ``monkeypatch.delitem`` records nothing when the key is absent, so a
    value written during the test would otherwise leak. Seeding a
    placeholder with ``setitem`` first makes monkeypatch remember the
    original state (present-with-value or absent) and restore it exactly at
    teardown, while the key is still absent while the test runs.
    """
    monkeypatch.setitem(mapping, key, None)
    monkeypatch.delitem(mapping, key)


def _install_stub_server(monkeypatch, call_tool):
    """Register a MagicMock MCP server whose ``session.call_tool`` is *call_tool*.

    The server entry and its error counter are installed via ``monkeypatch``
    so teardown restores whatever ``tools.mcp_tool`` held before the test.

    Returns:
        The stub server object.
    """
    from tools import mcp_tool

    server = MagicMock()
    server.name = _SERVER_NAME
    server.session = MagicMock()
    server.session.call_tool = call_tool
    server._reconnect_event = MagicMock()
    server._ready = MagicMock()
    server._ready.is_set.return_value = True

    monkeypatch.setitem(mcp_tool._servers, _SERVER_NAME, server)
    _clear_key_for_test(monkeypatch, mcp_tool._server_error_counts, _SERVER_NAME)

    # run_on_mcp_loop needs the MCP loop to exist.
    mcp_tool._ensure_mcp_loop()
    return server


def test_is_auth_error_detects_oauth_flow_error():
    from tools.mcp_tool import _is_auth_error
    from mcp.client.auth import OAuthFlowError

    assert _is_auth_error(OAuthFlowError("expired")) is True


def test_is_auth_error_detects_oauth_non_interactive():
    from tools.mcp_tool import _is_auth_error
    from tools.mcp_oauth import OAuthNonInteractiveError

    assert _is_auth_error(OAuthNonInteractiveError("no browser")) is True


def test_is_auth_error_detects_httpx_401():
    from tools.mcp_tool import _is_auth_error

    assert _is_auth_error(_http_status_error(401, "unauth")) is True


def test_is_auth_error_rejects_httpx_500():
    from tools.mcp_tool import _is_auth_error

    assert _is_auth_error(_http_status_error(500, "oops")) is False


def test_is_auth_error_rejects_generic_exception():
    from tools.mcp_tool import _is_auth_error

    assert _is_auth_error(ValueError("not auth")) is False
    assert _is_auth_error(RuntimeError("not auth")) is False


def test_call_tool_handler_returns_needs_reauth_on_unrecoverable_401(monkeypatch, tmp_path):
    """When session.call_tool raises an auth error and handle_401 returns
    False, the handler returns a structured needs_reauth error (not a
    generic failure)."""
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))

    from tools.mcp_tool import _make_tool_handler
    from tools.mcp_oauth_manager import get_manager, reset_manager_for_tests
    from mcp.client.auth import OAuthFlowError

    # Start from a fresh manager now that HERMES_HOME points at tmp_path.
    reset_manager_for_tests()

    async def _call_tool_raises(*args, **kwargs):
        raise OAuthFlowError("token expired")

    _install_stub_server(monkeypatch, _call_tool_raises)

    # Force handle_401 to report that no recovery is available.
    async def _h401(name, token=None):
        return False

    monkeypatch.setattr(get_manager(), "handle_401", _h401)

    try:
        handler = _make_tool_handler(_SERVER_NAME, "tool1", 10.0)
        parsed = json.loads(handler({"arg": "v"}))
    finally:
        # Don't leak the tmp_path-scoped manager singleton into later tests.
        reset_manager_for_tests()

    assert parsed.get("needs_reauth") is True, f"expected needs_reauth, got: {parsed}"
    assert parsed.get("server") == _SERVER_NAME
    # `or ""` guards against an explicit None so a bad payload fails the
    # assertion instead of raising AttributeError.
    error = (parsed.get("error") or "").lower()
    assert "re-auth" in error or "reauth" in error, f"unexpected error text: {parsed}"


def test_call_tool_handler_non_auth_error_still_generic(monkeypatch, tmp_path):
    """Non-auth exceptions still surface via the generic error path, not needs_reauth."""
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))

    from tools.mcp_tool import _make_tool_handler

    async def _raises(*args, **kwargs):
        raise RuntimeError("unrelated")

    _install_stub_server(monkeypatch, _raises)

    handler = _make_tool_handler(_SERVER_NAME, "tool1", 10.0)
    parsed = json.loads(handler({"arg": "v"}))

    assert "needs_reauth" not in parsed
    assert "MCP call failed" in (parsed.get("error") or ""), f"unexpected payload: {parsed}"