#!/usr/bin/env python3
"""Central manager for per-server MCP OAuth state.

One instance shared across the process. Holds per-server OAuth provider
instances and coordinates:

- **Cross-process token reload** via mtime-based disk watch. When an external
  process (e.g. a user cron job) refreshes tokens on disk, the next auth flow
  picks them up without requiring a process restart.
- **401 deduplication** via in-flight futures. When N concurrent tool calls
  all hit 401 with the same access_token, only one recovery attempt fires;
  the rest await the same result.
- **Reconnect signalling** for long-lived MCP sessions. The manager itself
  does not drive reconnection — the `MCPServerTask` in `mcp_tool.py` does —
  but the manager is the single source of truth that decides when reconnect
  is warranted.

Replaces what used to be scattered across eight call sites in `mcp_oauth.py`,
`mcp_tool.py`, and `hermes_cli/mcp_config.py`. This module is the ONLY place
that instantiates the MCP SDK's `OAuthClientProvider` — all other code paths
go through `get_manager()`.

Design reference:

- Claude Code's ``invalidateOAuthCacheIfDiskChanged``
  (``claude-code/src/utils/auth.ts:1320``, CC-1096 / GH#24317). Identical
  external-refresh staleness bug class.
- Codex's ``refresh_oauth_if_needed`` / ``persist_if_needed``
  (``codex-rs/rmcp-client/src/rmcp_client.rs:805``). We lean on the MCP SDK's
  lazy refresh rather than calling refresh before every op, because one
  ``stat()`` per tool call is cheaper than an ``await`` + potential refresh
  round-trip, and the SDK's in-memory expiry path is already correct.
"""

from __future__ import annotations

import asyncio
import logging
import threading
from dataclasses import dataclass, field
from typing import Any, Optional

logger = logging.getLogger(__name__)

# Strong references to in-flight 401-recovery tasks. The event loop itself only
# keeps weak references to tasks, so a fire-and-forget task has to be anchored
# somewhere until it finishes; each task removes itself via a done-callback.
_BACKGROUND_TASKS: set[asyncio.Task[None]] = set()


# ---------------------------------------------------------------------------
# Per-server entry
# ---------------------------------------------------------------------------


@dataclass
class _ProviderEntry:
    """Per-server OAuth state tracked by the manager.

    Fields:
        server_url: The MCP server URL used to build the provider. Tracked
            so we can discard a cached provider if the URL changes.
        oauth_config: Optional dict from ``mcp_servers.<name>.oauth``.
        provider: The ``httpx.Auth``-compatible provider wrapping the MCP
            SDK. None until first use.
        last_mtime_ns: Last-seen ``st_mtime_ns`` of the on-disk tokens file.
            Zero if never read. Used by :meth:`MCPOAuthManager.invalidate_if_disk_changed`
            to detect external refreshes.
        lock: Serialises concurrent access to this entry's state. Bound to
            whichever asyncio loop first awaits it (the MCP event loop).
        pending_401: In-flight 401-handler futures keyed by the failed
            access_token, for deduplicating thundering-herd 401s. Mirrors
            Claude Code's ``pending401Handlers`` map.
    """

    server_url: str
    oauth_config: Optional[dict]
    provider: Optional[Any] = None
    last_mtime_ns: int = 0
    lock: asyncio.Lock = field(default_factory=asyncio.Lock)
    pending_401: dict[str, "asyncio.Future[bool]"] = field(default_factory=dict)


# ---------------------------------------------------------------------------
# HermesMCPOAuthProvider — OAuthClientProvider subclass with disk-watch
# ---------------------------------------------------------------------------


def _make_hermes_provider_class() -> Optional[type]:
    """Lazy-import the SDK base class and return our subclass.

    Wrapped in a function so this module imports cleanly even when the
    MCP SDK's OAuth module is unavailable (e.g. older mcp versions).

    Returns:
        The ``HermesMCPOAuthProvider`` class, or None when
        ``mcp.client.auth.oauth2`` cannot be imported.
    """
    try:
        from mcp.client.auth.oauth2 import OAuthClientProvider
    except ImportError:  # pragma: no cover — SDK required in CI
        return None

    class HermesMCPOAuthProvider(OAuthClientProvider):
        """OAuthClientProvider with pre-flow disk-mtime reload.

        Before every ``async_auth_flow`` invocation, asks the manager to
        check whether the tokens file on disk has been modified externally.
        If so, the manager resets ``_initialized`` so the next flow
        re-reads from storage.

        This makes external-process refreshes (cron, another CLI instance)
        visible to the running MCP session without requiring a restart.

        Reference: Claude Code's ``invalidateOAuthCacheIfDiskChanged``
        (``src/utils/auth.ts:1320``, CC-1096 / GH#24317).
        """

        def __init__(self, *args: Any, server_name: str = "", **kwargs: Any):
            super().__init__(*args, **kwargs)
            self._hermes_server_name = server_name

        async def _initialize(self) -> None:
            """Load stored tokens + client info AND seed token_expiry_time.

            Also eagerly fetches OAuth authorization-server metadata (PRM +
            ASM) when we have stored tokens but no cached metadata, so the
            SDK's ``_refresh_token`` can build the correct token_endpoint
            URL on the preemptive-refresh path. Without this, the SDK
            falls back to ``{mcp_server_url}/token`` (wrong for providers
            whose AS is a different origin — BetterStack's MCP lives at
            ``https://mcp.betterstack.com`` but its token endpoint is at
            ``https://betterstack.com/oauth/token``), the refresh 404s, and
            we drop through to full browser reauth.

            The SDK's base ``_initialize`` populates ``current_tokens`` but
            does NOT call ``update_token_expiry``, so ``token_expiry_time``
            stays ``None`` and ``is_token_valid()`` returns True for any
            loaded token regardless of actual age. After a process restart
            this ships stale Bearer tokens to the server; some providers
            return HTTP 401 (caught by the 401 handler), others return 200
            with an app-level auth error (invisible to the transport layer,
            e.g. BetterStack returning "No teams found. Please check your
            authentication.").

            Seeding ``token_expiry_time`` from the reloaded token fixes that:
            ``is_token_valid()`` correctly reports False for expired tokens,
            ``async_auth_flow`` takes the ``can_refresh_token()`` branch,
            and the SDK quietly refreshes before the first real request.

            Paired with :class:`HermesTokenStorage` persisting an absolute
            ``expires_at`` timestamp (``mcp_oauth.py:set_tokens``) so the
            remaining TTL we compute here reflects real wall-clock age.
            """
            await super()._initialize()
            tokens = self.context.current_tokens
            if tokens is not None and tokens.expires_in is not None:
                self.context.update_token_expiry(tokens)

            # Pre-flight OAuth AS discovery so ``_refresh_token`` has a
            # correct ``token_endpoint`` before the first refresh attempt.
            # Only runs when we have tokens on cold-load but no cached
            # metadata — i.e. the exact scenario where the SDK's built-in
            # 401-branch discovery hasn't had a chance to run yet.
            if (
                tokens is not None
                and self.context.oauth_metadata is None
            ):
                try:
                    await self._prefetch_oauth_metadata()
                except Exception as exc:  # pragma: no cover — defensive
                    # Non-fatal: if discovery fails, the SDK's normal 401-
                    # branch discovery will run on the next request.
                    logger.debug(
                        "MCP OAuth '%s': pre-flight metadata discovery "
                        "failed (non-fatal): %s",
                        self._hermes_server_name, exc,
                    )

        async def _prefetch_oauth_metadata(self) -> None:
            """Fetch PRM + ASM from the well-known endpoints, cache on context.

            Mirrors the SDK's 401-branch discovery (oauth2.py ~line 511-551)
            but runs up front, before the first request, instead of
            inside the httpx auth_flow generator. Uses the SDK's own URL
            builders and response handlers so we track whatever the SDK
            version we're pinned to expects.

            Transport errors (``httpx.HTTPError``) on an individual
            discovery URL are logged at debug level and the next candidate
            URL is tried; any other exception propagates to the caller.
            """
            import httpx  # local import: httpx is an MCP SDK dependency
            from mcp.client.auth.utils import (
                build_oauth_authorization_server_metadata_discovery_urls,
                build_protected_resource_metadata_discovery_urls,
                create_oauth_metadata_request,
                handle_auth_metadata_response,
                handle_protected_resource_response,
            )

            server_url = self.context.server_url
            async with httpx.AsyncClient(timeout=10.0) as client:
                # Step 1: PRM discovery to learn the authorization_server URL.
                for url in build_protected_resource_metadata_discovery_urls(
                    None, server_url
                ):
                    req = create_oauth_metadata_request(url)
                    try:
                        resp = await client.send(req)
                    except httpx.HTTPError as exc:
                        logger.debug(
                            "MCP OAuth '%s': PRM discovery to %s failed: %s",
                            self._hermes_server_name, url, exc,
                        )
                        continue
                    prm = await handle_protected_resource_response(resp)
                    if prm:
                        self.context.protected_resource_metadata = prm
                        if prm.authorization_servers:
                            self.context.auth_server_url = str(
                                prm.authorization_servers[0]
                            )
                        break

                # Step 2: ASM discovery against the auth_server_url (or
                # server_url fallback for legacy providers).
                for url in build_oauth_authorization_server_metadata_discovery_urls(
                    self.context.auth_server_url, server_url
                ):
                    req = create_oauth_metadata_request(url)
                    try:
                        resp = await client.send(req)
                    except httpx.HTTPError as exc:
                        logger.debug(
                            "MCP OAuth '%s': ASM discovery to %s failed: %s",
                            self._hermes_server_name, url, exc,
                        )
                        continue
                    ok, asm = await handle_auth_metadata_response(resp)
                    if not ok:
                        break
                    if asm:
                        self.context.oauth_metadata = asm
                        logger.debug(
                            "MCP OAuth '%s': pre-flight ASM discovered "
                            "token_endpoint=%s",
                            self._hermes_server_name, asm.token_endpoint,
                        )
                        break

        async def async_auth_flow(self, request):  # type: ignore[override]
            """Run the SDK auth flow, preceded by a disk-change check.

            Behaves as a transparent bidirectional proxy around the SDK's
            ``async_auth_flow`` generator: every value sent into this
            generator is forwarded to the SDK generator, and closing this
            generator closes the SDK generator as well.
            """
            # Pre-flow hook: ask the manager to refresh from disk if needed.
            # Any failure here is non-fatal — we just log and proceed with
            # whatever state the SDK already has.
            try:
                await get_manager().invalidate_if_disk_changed(
                    self._hermes_server_name
                )
            except Exception as exc:  # pragma: no cover — defensive
                logger.debug(
                    "MCP OAuth '%s': pre-flow disk-watch failed (non-fatal): %s",
                    self._hermes_server_name, exc,
                )

            # Manually bridge the bidirectional generator protocol. httpx's
            # auth_flow driver (httpx._client._send_handling_auth) calls
            # ``auth_flow.asend(response)`` to feed HTTP responses back into
            # the generator. A naive wrapper using ``async for item in inner:
            # yield item`` DISCARDS those .asend(response) values and resumes
            # the inner generator with None, so the SDK's
            # ``response = yield request`` branch in
            # mcp/client/auth/oauth2.py sees response=None and crashes at
            # ``if response.status_code == 401`` with AttributeError.
            #
            # The bridge below forwards each .asend() value into the inner
            # generator via inner.asend(incoming), preserving the bidirectional
            # contract. Regression from PR #11383 caught by
            # tests/tools/test_mcp_oauth_bidirectional.py.
            inner = super().async_auth_flow(request)
            try:
                outgoing = await inner.__anext__()
                while True:
                    incoming = yield outgoing
                    outgoing = await inner.asend(incoming)
            except StopAsyncIteration:
                return
            finally:
                # Forward early termination too: if the driver closes this
                # generator before the flow finishes (e.g. the request
                # failed), close the wrapped generator right away instead of
                # leaving it suspended until garbage collection. A no-op when
                # the inner generator has already finished.
                await inner.aclose()

    return HermesMCPOAuthProvider


# Cached at import time. Tested and used by :class:`MCPOAuthManager`.
_HERMES_PROVIDER_CLS: Optional[type] = _make_hermes_provider_class()


# ---------------------------------------------------------------------------
# Manager
# ---------------------------------------------------------------------------


class MCPOAuthManager:
    """Single source of truth for per-server MCP OAuth state.

    Thread-safe: the ``_entries`` dict is guarded by ``_entries_lock`` for
    get-or-create semantics. Per-entry state is guarded by the entry's own
    ``asyncio.Lock`` (used from the MCP event loop thread).
    """

    def __init__(self) -> None:
        self._entries: dict[str, _ProviderEntry] = {}
        self._entries_lock = threading.Lock()

    # -- Provider construction / caching -------------------------------------

    def get_or_build_provider(
        self,
        server_name: str,
        server_url: str,
        oauth_config: Optional[dict],
    ) -> Optional[Any]:
        """Return a cached OAuth provider for ``server_name`` or build one.

        Idempotent: repeat calls with the same name return the same instance.
        If ``server_url`` changes for a given name, the cached entry is
        discarded and a fresh provider is built.

        Returns None if the MCP SDK's OAuth support is unavailable.
        """
        with self._entries_lock:
            entry = self._entries.get(server_name)
            if entry is not None and entry.server_url != server_url:
                logger.info(
                    "MCP OAuth '%s': URL changed from %s to %s, discarding cache",
                    server_name, entry.server_url, server_url,
                )
                entry = None

            if entry is None:
                entry = _ProviderEntry(
                    server_url=server_url,
                    oauth_config=oauth_config,
                )
                self._entries[server_name] = entry

            if entry.provider is None:
                entry.provider = self._build_provider(server_name, entry)

            return entry.provider

    def _build_provider(
        self,
        server_name: str,
        entry: _ProviderEntry,
    ) -> Optional[Any]:
        """Build the underlying OAuth provider.

        Constructs :class:`HermesMCPOAuthProvider` directly using the helpers
        extracted from ``tools.mcp_oauth``. The subclass injects a pre-flow
        disk-watch hook so external token refreshes (cron, other CLI
        instances) are visible to running MCP sessions.

        Returns None if the MCP SDK's OAuth support is unavailable.
        """
        if _HERMES_PROVIDER_CLS is None:
            logger.warning(
                "MCP OAuth '%s': SDK auth module unavailable", server_name,
            )
            return None

        # Local imports avoid circular deps at module import time.
        from tools.mcp_oauth import (
            HermesTokenStorage,
            _OAUTH_AVAILABLE,
            _build_client_metadata,
            _configure_callback_port,
            _is_interactive,
            _maybe_preregister_client,
            _redirect_handler,
            _wait_for_callback,
        )

        if not _OAUTH_AVAILABLE:
            return None

        # Work on a copy so the helpers below never mutate the caller's dict.
        cfg = dict(entry.oauth_config or {})
        storage = HermesTokenStorage(server_name)

        # Warn only; the provider is still built so a later flow can proceed
        # once tokens show up.
        if not _is_interactive() and not storage.has_cached_tokens():
            logger.warning(
                "MCP OAuth for '%s': non-interactive environment and no "
                "cached tokens found. Run interactively first to complete "
                "initial authorization.",
                server_name,
            )

        _configure_callback_port(cfg)
        client_metadata = _build_client_metadata(cfg)
        _maybe_preregister_client(storage, cfg, client_metadata)

        return _HERMES_PROVIDER_CLS(
            server_name=server_name,
            server_url=entry.server_url,
            client_metadata=client_metadata,
            storage=storage,
            redirect_handler=_redirect_handler,
            callback_handler=_wait_for_callback,
            timeout=float(cfg.get("timeout", 300)),
        )

    def remove(self, server_name: str) -> None:
        """Evict the provider from cache AND delete tokens from disk.

        Called by ``hermes mcp remove <name>`` and (indirectly) by
        ``hermes mcp login <name>`` during forced re-auth.
        """
        with self._entries_lock:
            self._entries.pop(server_name, None)

        from tools.mcp_oauth import remove_oauth_tokens
        remove_oauth_tokens(server_name)
        logger.info(
            "MCP OAuth '%s': evicted from cache and removed from disk",
            server_name,
        )

    # -- Disk watch ----------------------------------------------------------

    async def invalidate_if_disk_changed(self, server_name: str) -> bool:
        """If the tokens file on disk has a different mtime than last-seen,
        force the MCP SDK provider to reload its in-memory state.

        Returns True if the cache was invalidated (mtime differed). This is
        the core fix for the external-refresh workflow: a cron job writes
        fresh tokens to disk, and on the next tool call the running MCP
        session picks them up without a restart.

        Returns False when no provider is cached for ``server_name``, when
        the tokens file cannot be stat'ed, or when the mtime is unchanged.

        Note: ``last_mtime_ns`` starts at zero, so the first check after a
        provider is built always reports a change (given the file exists).
        """
        from tools.mcp_oauth import _get_token_dir, _safe_filename

        entry = self._entries.get(server_name)
        if entry is None or entry.provider is None:
            return False

        async with entry.lock:
            tokens_path = _get_token_dir() / f"{_safe_filename(server_name)}.json"
            try:
                mtime_ns = tokens_path.stat().st_mtime_ns
            except OSError:
                # Covers FileNotFoundError (a subclass) as well as permission
                # and other I/O failures: nothing readable, nothing to reload.
                return False

            if mtime_ns != entry.last_mtime_ns:
                old = entry.last_mtime_ns
                entry.last_mtime_ns = mtime_ns
                # Force the SDK's OAuthClientProvider to reload from storage
                # on its next auth flow. `_initialized` is private API but
                # stable across the MCP SDK versions we pin (>=1.26.0).
                if hasattr(entry.provider, "_initialized"):
                    entry.provider._initialized = False  # noqa: SLF001
                logger.info(
                    "MCP OAuth '%s': tokens file changed (mtime %d -> %d), "
                    "forcing reload",
                    server_name, old, mtime_ns,
                )
                return True
            return False

    # -- 401 handler (dedup'd) -----------------------------------------------

    async def handle_401(
        self,
        server_name: str,
        failed_access_token: Optional[str] = None,
    ) -> bool:
        """Handle a 401 from a tool call, deduplicated across concurrent callers.

        Returns:
            True if a (possibly new) access token is now available — caller
            should trigger a reconnect and retry the operation.
            False if no recovery path exists — caller should surface a
            ``needs_reauth`` error to the model so it stops hallucinating
            manual refresh attempts.

        Thundering-herd protection: if N concurrent tool calls hit 401 with
        the same ``failed_access_token``, only one recovery attempt fires.
        Others await the same future.

        Cancellation: each caller waits on the shared future through
        ``asyncio.shield``, so cancelling one caller does not cancel the
        recovery attempt or the result seen by the other callers.
        """
        entry = self._entries.get(server_name)
        if entry is None or entry.provider is None:
            return False

        key = failed_access_token or "<unknown>"
        loop = asyncio.get_running_loop()

        async with entry.lock:
            pending = entry.pending_401.get(key)
            if pending is None:
                pending = loop.create_future()
                entry.pending_401[key] = pending

                async def _do_handle() -> None:
                    try:
                        # Step 1: Did disk change? Picks up external refresh.
                        disk_changed = await self.invalidate_if_disk_changed(
                            server_name
                        )
                        if disk_changed:
                            if not pending.done():
                                pending.set_result(True)
                            return

                        # Step 2: No disk change — if the SDK can refresh
                        # in-place, let the caller retry. The SDK's httpx.Auth
                        # flow will issue the refresh on the next request.
                        provider = entry.provider
                        ctx = getattr(provider, "context", None)
                        can_refresh = False
                        if ctx is not None:
                            can_refresh_fn = getattr(ctx, "can_refresh_token", None)
                            if callable(can_refresh_fn):
                                try:
                                    can_refresh = bool(can_refresh_fn())
                                except Exception:
                                    can_refresh = False
                        if not pending.done():
                            pending.set_result(can_refresh)
                    except Exception as exc:  # pragma: no cover — defensive
                        logger.warning(
                            "MCP OAuth '%s': 401 handler failed: %s",
                            server_name, exc,
                        )
                        if not pending.done():
                            pending.set_result(False)
                    finally:
                        entry.pending_401.pop(key, None)

                # The recovery task takes ``entry.lock`` itself (inside
                # ``invalidate_if_disk_changed``), so it must run as a
                # separate task that starts after this block releases the
                # lock. Keep a strong reference until it completes so it
                # cannot be garbage-collected mid-flight.
                task = asyncio.create_task(_do_handle())
                _BACKGROUND_TASKS.add(task)
                task.add_done_callback(_BACKGROUND_TASKS.discard)

        try:
            # Shield the shared future: awaiting it directly would let the
            # cancellation of any one caller cancel the future itself and
            # thereby fail every other deduplicated caller.
            return await asyncio.shield(pending)
        except Exception as exc:  # pragma: no cover — defensive
            logger.warning(
                "MCP OAuth '%s': awaiting 401 handler failed: %s",
                server_name, exc,
            )
            return False


# ---------------------------------------------------------------------------
# Module-level singleton
# ---------------------------------------------------------------------------


_MANAGER: Optional[MCPOAuthManager] = None
_MANAGER_LOCK = threading.Lock()


def get_manager() -> MCPOAuthManager:
    """Return the process-wide :class:`MCPOAuthManager` singleton."""
    global _MANAGER
    with _MANAGER_LOCK:
        if _MANAGER is None:
            _MANAGER = MCPOAuthManager()
        return _MANAGER


def reset_manager_for_tests() -> None:
    """Test-only helper: drop the singleton so fixtures start clean."""
    global _MANAGER
    with _MANAGER_LOCK:
        _MANAGER = None