/ tests / tools / test_mcp_tool_401_handling.py
test_mcp_tool_401_handling.py
  1  """Tests for MCP tool-handler auth-failure detection.
  2  
  3  When a tool call raises UnauthorizedError / OAuthNonInteractiveError /
  4  httpx.HTTPStatusError(401), the handler should:
  5    1. Ask MCPOAuthManager.handle_401 if recovery is viable.
  6    2. If yes, trigger MCPServerTask._reconnect_event and retry once.
  7    3. If no, return a structured needs_reauth error so the model stops
  8       hallucinating manual refresh attempts.
  9  """
 10  import json
 11  from unittest.mock import MagicMock
 12  
 13  import pytest
 14  
 15  
 16  pytest.importorskip("mcp.client.auth.oauth2")
 17  
 18  
 19  def test_is_auth_error_detects_oauth_flow_error():
 20      from tools.mcp_tool import _is_auth_error
 21      from mcp.client.auth import OAuthFlowError
 22  
 23      assert _is_auth_error(OAuthFlowError("expired")) is True
 24  
 25  
 26  def test_is_auth_error_detects_oauth_non_interactive():
 27      from tools.mcp_tool import _is_auth_error
 28      from tools.mcp_oauth import OAuthNonInteractiveError
 29  
 30      assert _is_auth_error(OAuthNonInteractiveError("no browser")) is True
 31  
 32  
 33  def test_is_auth_error_detects_httpx_401():
 34      from tools.mcp_tool import _is_auth_error
 35      import httpx
 36  
 37      response = MagicMock()
 38      response.status_code = 401
 39      exc = httpx.HTTPStatusError("unauth", request=MagicMock(), response=response)
 40      assert _is_auth_error(exc) is True
 41  
 42  
 43  def test_is_auth_error_rejects_httpx_500():
 44      from tools.mcp_tool import _is_auth_error
 45      import httpx
 46  
 47      response = MagicMock()
 48      response.status_code = 500
 49      exc = httpx.HTTPStatusError("oops", request=MagicMock(), response=response)
 50      assert _is_auth_error(exc) is False
 51  
 52  
 53  def test_is_auth_error_rejects_generic_exception():
 54      from tools.mcp_tool import _is_auth_error
 55      assert _is_auth_error(ValueError("not auth")) is False
 56      assert _is_auth_error(RuntimeError("not auth")) is False
 57  
 58  
 59  def test_call_tool_handler_returns_needs_reauth_on_unrecoverable_401(monkeypatch, tmp_path):
 60      """When session.call_tool raises 401 and handle_401 returns False,
 61      handler returns a structured needs_reauth error (not a generic failure)."""
 62      monkeypatch.setenv("HERMES_HOME", str(tmp_path))
 63  
 64      from tools.mcp_tool import _make_tool_handler
 65      from tools.mcp_oauth_manager import get_manager, reset_manager_for_tests
 66      from mcp.client.auth import OAuthFlowError
 67  
 68      reset_manager_for_tests()
 69  
 70      # Stub server
 71      server = MagicMock()
 72      server.name = "srv"
 73      session = MagicMock()
 74  
 75      async def _call_tool_raises(*a, **kw):
 76          raise OAuthFlowError("token expired")
 77  
 78      session.call_tool = _call_tool_raises
 79      server.session = session
 80      server._reconnect_event = MagicMock()
 81      server._ready = MagicMock()
 82      server._ready.is_set.return_value = True
 83  
 84      from tools import mcp_tool
 85      mcp_tool._servers["srv"] = server
 86      mcp_tool._server_error_counts.pop("srv", None)
 87  
 88      # Ensure the MCP loop exists (run_on_mcp_loop needs it)
 89      mcp_tool._ensure_mcp_loop()
 90  
 91      # Force handle_401 to return False (no recovery available)
 92      mgr = get_manager()
 93  
 94      async def _h401(name, token=None):
 95          return False
 96  
 97      monkeypatch.setattr(mgr, "handle_401", _h401)
 98  
 99      try:
100          handler = _make_tool_handler("srv", "tool1", 10.0)
101          result = handler({"arg": "v"})
102          parsed = json.loads(result)
103          assert parsed.get("needs_reauth") is True, f"expected needs_reauth, got: {parsed}"
104          assert parsed.get("server") == "srv"
105          assert "re-auth" in parsed.get("error", "").lower() or "reauth" in parsed.get("error", "").lower()
106      finally:
107          mcp_tool._servers.pop("srv", None)
108          mcp_tool._server_error_counts.pop("srv", None)
109  
110  
def test_call_tool_handler_non_auth_error_still_generic(monkeypatch, tmp_path):
    """Non-auth exceptions still surface via the generic error path, not needs_reauth."""
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))

    from tools import mcp_tool
    from tools.mcp_tool import _make_tool_handler

    async def _raises(*a, **kw):
        raise RuntimeError("unrelated")

    # Stub server whose session fails with an error unrelated to auth. The
    # stub is built the same way as in the needs_reauth test, so the only
    # difference between the two is the exception type.
    session = MagicMock()
    session.call_tool = _raises
    server = MagicMock()
    server.name = "srv"
    server.session = session
    server._ready = MagicMock()
    server._ready.is_set.return_value = True

    # setitem restores any pre-existing "srv" entry on teardown instead of
    # unconditionally deleting it.
    monkeypatch.setitem(mcp_tool._servers, "srv", server)
    mcp_tool._server_error_counts.pop("srv", None)
    mcp_tool._ensure_mcp_loop()

    try:
        handler = _make_tool_handler("srv", "tool1", 10.0)
        parsed = json.loads(handler({"arg": "v"}))
        assert "needs_reauth" not in parsed, f"unexpected needs_reauth: {parsed}"
        assert "MCP call failed" in parsed.get("error", "")
    finally:
        mcp_tool._server_error_counts.pop("srv", None)