tools/mcp_oauth_manager.py
  1  #!/usr/bin/env python3
  2  """Central manager for per-server MCP OAuth state.
  3  
  4  One instance shared across the process. Holds per-server OAuth provider
  5  instances and coordinates:
  6  
  7  - **Cross-process token reload** via mtime-based disk watch. When an external
  8    process (e.g. a user cron job) refreshes tokens on disk, the next auth flow
  9    picks them up without requiring a process restart.
 10  - **401 deduplication** via in-flight futures. When N concurrent tool calls
 11    all hit 401 with the same access_token, only one recovery attempt fires;
 12    the rest await the same result.
 13  - **Reconnect signalling** for long-lived MCP sessions. The manager itself
 14    does not drive reconnection — the `MCPServerTask` in `mcp_tool.py` does —
 15    but the manager is the single source of truth that decides when reconnect
 16    is warranted.
 17  
 18  Replaces what used to be scattered across eight call sites in `mcp_oauth.py`,
 19  `mcp_tool.py`, and `hermes_cli/mcp_config.py`. This module is the ONLY place
 20  that instantiates the MCP SDK's `OAuthClientProvider` — all other code paths
 21  go through `get_manager()`.
 22  
 23  Design reference:
 24  
 25  - Claude Code's ``invalidateOAuthCacheIfDiskChanged``
 26    (``claude-code/src/utils/auth.ts:1320``, CC-1096 / GH#24317). Identical
 27    external-refresh staleness bug class.
 28  - Codex's ``refresh_oauth_if_needed`` / ``persist_if_needed``
 29    (``codex-rs/rmcp-client/src/rmcp_client.rs:805``). We lean on the MCP SDK's
 30    lazy refresh rather than calling refresh before every op, because one
 31    ``stat()`` per tool call is cheaper than an ``await`` + potential refresh
 32    round-trip, and the SDK's in-memory expiry path is already correct.
 33  """
 34  
 35  from __future__ import annotations
 36  
 37  import asyncio
 38  import logging
 39  import threading
 40  from dataclasses import dataclass, field
 41  from typing import Any, Optional
 42  
 43  logger = logging.getLogger(__name__)
 44  
 45  
 46  # ---------------------------------------------------------------------------
 47  # Per-server entry
 48  # ---------------------------------------------------------------------------
 49  
 50  
@dataclass
class _ProviderEntry:
    """Per-server OAuth state tracked by the manager.

    Fields:
        server_url: The MCP server URL used to build the provider. Tracked
            so we can discard a cached provider if the URL changes.
        oauth_config: Optional dict from ``mcp_servers.<name>.oauth``.
        provider: The ``httpx.Auth``-compatible provider wrapping the MCP
            SDK. None until first use.
        last_mtime_ns: Last-seen ``st_mtime_ns`` of the on-disk tokens file.
            Zero if never read. Used by :meth:`MCPOAuthManager.invalidate_if_disk_changed`
            to detect external refreshes.
        lock: Serialises concurrent access to this entry's state. Bound to
            whichever asyncio loop first awaits it (the MCP event loop).
        pending_401: In-flight 401-handler futures keyed by the failed
            access_token, for deduplicating thundering-herd 401s. Mirrors
            Claude Code's ``pending401Handlers`` map.
    """

    # MCP server URL this entry's provider was built against.
    server_url: str
    # Raw per-server OAuth config dict, or None when the config has none.
    oauth_config: Optional[dict]
    # Lazily built in MCPOAuthManager._build_provider; None until first use.
    provider: Optional[Any] = None
    # 0 sentinel = tokens file never stat()'d yet.
    last_mtime_ns: int = 0
    # One lock per server entry; guards mtime checks and pending_401 edits.
    lock: asyncio.Lock = field(default_factory=asyncio.Lock)
    # failed access_token -> shared recovery future (401 dedup map).
    pending_401: dict[str, "asyncio.Future[bool]"] = field(default_factory=dict)
 77  
 78  
 79  # ---------------------------------------------------------------------------
 80  # HermesMCPOAuthProvider — OAuthClientProvider subclass with disk-watch
 81  # ---------------------------------------------------------------------------
 82  
 83  
def _make_hermes_provider_class() -> Optional[type]:
    """Lazy-import the SDK base class and return our subclass.

    Wrapped in a function so this module imports cleanly even when the
    MCP SDK's OAuth module is unavailable (e.g. older mcp versions).

    Returns:
        The :class:`HermesMCPOAuthProvider` subclass, or ``None`` when
        ``mcp.client.auth.oauth2`` cannot be imported.
    """
    try:
        from mcp.client.auth.oauth2 import OAuthClientProvider
    except ImportError:  # pragma: no cover — SDK required in CI
        return None

    class HermesMCPOAuthProvider(OAuthClientProvider):
        """OAuthClientProvider with pre-flow disk-mtime reload.

        Before every ``async_auth_flow`` invocation, asks the manager to
        check whether the tokens file on disk has been modified externally.
        If so, the manager resets ``_initialized`` so the next flow
        re-reads from storage.

        This makes external-process refreshes (cron, another CLI instance)
        visible to the running MCP session without requiring a restart.

        Reference: Claude Code's ``invalidateOAuthCacheIfDiskChanged``
        (``src/utils/auth.ts:1320``, CC-1096 / GH#24317).
        """

        def __init__(self, *args: Any, server_name: str = "", **kwargs: Any) -> None:
            super().__init__(*args, **kwargs)
            # Manager-side entry key; used for disk-watch lookups and logs.
            self._hermes_server_name = server_name

        async def _initialize(self) -> None:
            """Load stored tokens + client info AND seed token_expiry_time.

            Also eagerly fetches OAuth authorization-server metadata (PRM +
            ASM) when we have stored tokens but no cached metadata, so the
            SDK's ``_refresh_token`` can build the correct token_endpoint
            URL on the preemptive-refresh path. Without this, the SDK
            falls back to ``{mcp_server_url}/token`` (wrong for providers
            whose AS is a different origin — BetterStack's MCP lives at
            ``https://mcp.betterstack.com`` but its token endpoint is at
            ``https://betterstack.com/oauth/token``), the refresh 404s, and
            we drop through to full browser reauth.

            The SDK's base ``_initialize`` populates ``current_tokens`` but
            does NOT call ``update_token_expiry``, so ``token_expiry_time``
            stays ``None`` and ``is_token_valid()`` returns True for any
            loaded token regardless of actual age. After a process restart
            this ships stale Bearer tokens to the server; some providers
            return HTTP 401 (caught by the 401 handler), others return 200
            with an app-level auth error (invisible to the transport layer,
            e.g. BetterStack returning "No teams found. Please check your
            authentication.").

            Seeding ``token_expiry_time`` from the reloaded token fixes that:
            ``is_token_valid()`` correctly reports False for expired tokens,
            ``async_auth_flow`` takes the ``can_refresh_token()`` branch,
            and the SDK quietly refreshes before the first real request.

            Paired with :class:`HermesTokenStorage` persisting an absolute
            ``expires_at`` timestamp (``mcp_oauth.py:set_tokens``) so the
            remaining TTL we compute here reflects real wall-clock age.
            """
            await super()._initialize()
            tokens = self.context.current_tokens
            if tokens is not None and tokens.expires_in is not None:
                self.context.update_token_expiry(tokens)

            # Pre-flight OAuth AS discovery so ``_refresh_token`` has a
            # correct ``token_endpoint`` before the first refresh attempt.
            # Only runs when we have tokens on cold-load but no cached
            # metadata — i.e. the exact scenario where the SDK's built-in
            # 401-branch discovery hasn't had a chance to run yet.
            if (
                tokens is not None
                and self.context.oauth_metadata is None
            ):
                try:
                    await self._prefetch_oauth_metadata()
                except Exception as exc:  # pragma: no cover — defensive
                    # Non-fatal: if discovery fails, the SDK's normal 401-
                    # branch discovery will run on the next request.
                    logger.debug(
                        "MCP OAuth '%s': pre-flight metadata discovery "
                        "failed (non-fatal): %s",
                        self._hermes_server_name, exc,
                    )

        async def _prefetch_oauth_metadata(self) -> None:
            """Fetch PRM + ASM from the well-known endpoints, cache on context.

            Mirrors the SDK's 401-branch discovery (oauth2.py ~line 511-551)
            but runs synchronously before the first request instead of
            inside the httpx auth_flow generator. Uses the SDK's own URL
            builders and response handlers so we track whatever the SDK
            version we're pinned to expects.
            """
            import httpx  # local import: httpx is an MCP SDK dependency
            from mcp.client.auth.utils import (
                build_oauth_authorization_server_metadata_discovery_urls,
                build_protected_resource_metadata_discovery_urls,
                create_oauth_metadata_request,
                handle_auth_metadata_response,
                handle_protected_resource_response,
            )

            server_url = self.context.server_url
            async with httpx.AsyncClient(timeout=10.0) as client:
                # Step 1: PRM discovery to learn the authorization_server URL.
                for url in build_protected_resource_metadata_discovery_urls(
                    None, server_url
                ):
                    req = create_oauth_metadata_request(url)
                    try:
                        resp = await client.send(req)
                    except httpx.HTTPError as exc:
                        # Network-level failure on this candidate URL only;
                        # keep probing the remaining discovery URLs.
                        logger.debug(
                            "MCP OAuth '%s': PRM discovery to %s failed: %s",
                            self._hermes_server_name, url, exc,
                        )
                        continue
                    prm = await handle_protected_resource_response(resp)
                    if prm:
                        self.context.protected_resource_metadata = prm
                        # First advertised AS wins — mirrors the SDK's own
                        # 401-branch behavior.
                        if prm.authorization_servers:
                            self.context.auth_server_url = str(
                                prm.authorization_servers[0]
                            )
                        break

                # Step 2: ASM discovery against the auth_server_url (or
                # server_url fallback for legacy providers).
                for url in build_oauth_authorization_server_metadata_discovery_urls(
                    self.context.auth_server_url, server_url
                ):
                    req = create_oauth_metadata_request(url)
                    try:
                        resp = await client.send(req)
                    except httpx.HTTPError as exc:
                        logger.debug(
                            "MCP OAuth '%s': ASM discovery to %s failed: %s",
                            self._hermes_server_name, url, exc,
                        )
                        continue
                    # NOTE(review): ok=False is treated here as "stop probing
                    # further URLs" — presumably the SDK signals a definitive
                    # non-retryable response; confirm against the pinned SDK.
                    ok, asm = await handle_auth_metadata_response(resp)
                    if not ok:
                        break
                    if asm:
                        self.context.oauth_metadata = asm
                        logger.debug(
                            "MCP OAuth '%s': pre-flight ASM discovered "
                            "token_endpoint=%s",
                            self._hermes_server_name, asm.token_endpoint,
                        )
                        break

        async def async_auth_flow(self, request):  # type: ignore[override]
            # Pre-flow hook: ask the manager to refresh from disk if needed.
            # Any failure here is non-fatal — we just log and proceed with
            # whatever state the SDK already has.
            try:
                await get_manager().invalidate_if_disk_changed(
                    self._hermes_server_name
                )
            except Exception as exc:  # pragma: no cover — defensive
                logger.debug(
                    "MCP OAuth '%s': pre-flow disk-watch failed (non-fatal): %s",
                    self._hermes_server_name, exc,
                )

            # Manually bridge the bidirectional generator protocol. httpx's
            # auth_flow driver (httpx._client._send_handling_auth) calls
            # ``auth_flow.asend(response)`` to feed HTTP responses back into
            # the generator. A naive wrapper using ``async for item in inner:
            # yield item`` DISCARDS those .asend(response) values and resumes
            # the inner generator with None, so the SDK's
            # ``response = yield request`` branch in
            # mcp/client/auth/oauth2.py sees response=None and crashes at
            # ``if response.status_code == 401`` with AttributeError.
            #
            # The bridge below forwards each .asend() value into the inner
            # generator via inner.asend(incoming), preserving the bidirectional
            # contract. Regression from PR #11383 caught by
            # tests/tools/test_mcp_oauth_bidirectional.py.
            inner = super().async_auth_flow(request)
            try:
                outgoing = await inner.__anext__()
                while True:
                    incoming = yield outgoing
                    outgoing = await inner.asend(incoming)
            except StopAsyncIteration:
                # Inner flow finished cleanly; end our generator too.
                return

    return HermesMCPOAuthProvider
277  
278  
# Cached at import time. Tested and used by :class:`MCPOAuthManager`.
# None when the MCP SDK's OAuth module could not be imported.
_HERMES_PROVIDER_CLS: Optional[type] = _make_hermes_provider_class()
281  
282  
283  # ---------------------------------------------------------------------------
284  # Manager
285  # ---------------------------------------------------------------------------
286  
287  
288  class MCPOAuthManager:
289      """Single source of truth for per-server MCP OAuth state.
290  
291      Thread-safe: the ``_entries`` dict is guarded by ``_entries_lock`` for
292      get-or-create semantics. Per-entry state is guarded by the entry's own
293      ``asyncio.Lock`` (used from the MCP event loop thread).
294      """
295  
296      def __init__(self) -> None:
297          self._entries: dict[str, _ProviderEntry] = {}
298          self._entries_lock = threading.Lock()
299  
300      # -- Provider construction / caching -------------------------------------
301  
302      def get_or_build_provider(
303          self,
304          server_name: str,
305          server_url: str,
306          oauth_config: Optional[dict],
307      ) -> Optional[Any]:
308          """Return a cached OAuth provider for ``server_name`` or build one.
309  
310          Idempotent: repeat calls with the same name return the same instance.
311          If ``server_url`` changes for a given name, the cached entry is
312          discarded and a fresh provider is built.
313  
314          Returns None if the MCP SDK's OAuth support is unavailable.
315          """
316          with self._entries_lock:
317              entry = self._entries.get(server_name)
318              if entry is not None and entry.server_url != server_url:
319                  logger.info(
320                      "MCP OAuth '%s': URL changed from %s to %s, discarding cache",
321                      server_name, entry.server_url, server_url,
322                  )
323                  entry = None
324  
325              if entry is None:
326                  entry = _ProviderEntry(
327                      server_url=server_url,
328                      oauth_config=oauth_config,
329                  )
330                  self._entries[server_name] = entry
331  
332              if entry.provider is None:
333                  entry.provider = self._build_provider(server_name, entry)
334  
335              return entry.provider
336  
337      def _build_provider(
338          self,
339          server_name: str,
340          entry: _ProviderEntry,
341      ) -> Optional[Any]:
342          """Build the underlying OAuth provider.
343  
344          Constructs :class:`HermesMCPOAuthProvider` directly using the helpers
345          extracted from ``tools.mcp_oauth``. The subclass injects a pre-flow
346          disk-watch hook so external token refreshes (cron, other CLI
347          instances) are visible to running MCP sessions.
348  
349          Returns None if the MCP SDK's OAuth support is unavailable.
350          """
351          if _HERMES_PROVIDER_CLS is None:
352              logger.warning(
353                  "MCP OAuth '%s': SDK auth module unavailable", server_name,
354              )
355              return None
356  
357          # Local imports avoid circular deps at module import time.
358          from tools.mcp_oauth import (
359              HermesTokenStorage,
360              _OAUTH_AVAILABLE,
361              _build_client_metadata,
362              _configure_callback_port,
363              _is_interactive,
364              _maybe_preregister_client,
365              _redirect_handler,
366              _wait_for_callback,
367          )
368  
369          if not _OAUTH_AVAILABLE:
370              return None
371  
372          cfg = dict(entry.oauth_config or {})
373          storage = HermesTokenStorage(server_name)
374  
375          if not _is_interactive() and not storage.has_cached_tokens():
376              logger.warning(
377                  "MCP OAuth for '%s': non-interactive environment and no "
378                  "cached tokens found. Run interactively first to complete "
379                  "initial authorization.",
380                  server_name,
381              )
382  
383          _configure_callback_port(cfg)
384          client_metadata = _build_client_metadata(cfg)
385          _maybe_preregister_client(storage, cfg, client_metadata)
386  
387          return _HERMES_PROVIDER_CLS(
388              server_name=server_name,
389              server_url=entry.server_url,
390              client_metadata=client_metadata,
391              storage=storage,
392              redirect_handler=_redirect_handler,
393              callback_handler=_wait_for_callback,
394              timeout=float(cfg.get("timeout", 300)),
395          )
396  
397      def remove(self, server_name: str) -> None:
398          """Evict the provider from cache AND delete tokens from disk.
399  
400          Called by ``hermes mcp remove <name>`` and (indirectly) by
401          ``hermes mcp login <name>`` during forced re-auth.
402          """
403          with self._entries_lock:
404              self._entries.pop(server_name, None)
405  
406          from tools.mcp_oauth import remove_oauth_tokens
407          remove_oauth_tokens(server_name)
408          logger.info(
409              "MCP OAuth '%s': evicted from cache and removed from disk",
410              server_name,
411          )
412  
413      # -- Disk watch ----------------------------------------------------------
414  
415      async def invalidate_if_disk_changed(self, server_name: str) -> bool:
416          """If the tokens file on disk has a newer mtime than last-seen, force
417          the MCP SDK provider to reload its in-memory state.
418  
419          Returns True if the cache was invalidated (mtime differed). This is
420          the core fix for the external-refresh workflow: a cron job writes
421          fresh tokens to disk, and on the next tool call the running MCP
422          session picks them up without a restart.
423          """
424          from tools.mcp_oauth import _get_token_dir, _safe_filename
425  
426          entry = self._entries.get(server_name)
427          if entry is None or entry.provider is None:
428              return False
429  
430          async with entry.lock:
431              tokens_path = _get_token_dir() / f"{_safe_filename(server_name)}.json"
432              try:
433                  mtime_ns = tokens_path.stat().st_mtime_ns
434              except (FileNotFoundError, OSError):
435                  return False
436  
437              if mtime_ns != entry.last_mtime_ns:
438                  old = entry.last_mtime_ns
439                  entry.last_mtime_ns = mtime_ns
440                  # Force the SDK's OAuthClientProvider to reload from storage
441                  # on its next auth flow. `_initialized` is private API but
442                  # stable across the MCP SDK versions we pin (>=1.26.0).
443                  if hasattr(entry.provider, "_initialized"):
444                      entry.provider._initialized = False  # noqa: SLF001
445                  logger.info(
446                      "MCP OAuth '%s': tokens file changed (mtime %d -> %d), "
447                      "forcing reload",
448                      server_name, old, mtime_ns,
449                  )
450                  return True
451              return False
452  
453      # -- 401 handler (dedup'd) -----------------------------------------------
454  
455      async def handle_401(
456          self,
457          server_name: str,
458          failed_access_token: Optional[str] = None,
459      ) -> bool:
460          """Handle a 401 from a tool call, deduplicated across concurrent callers.
461  
462          Returns:
463              True  if a (possibly new) access token is now available — caller
464                    should trigger a reconnect and retry the operation.
465              False if no recovery path exists — caller should surface a
466                    ``needs_reauth`` error to the model so it stops hallucinating
467                    manual refresh attempts.
468  
469          Thundering-herd protection: if N concurrent tool calls hit 401 with
470          the same ``failed_access_token``, only one recovery attempt fires.
471          Others await the same future.
472          """
473          entry = self._entries.get(server_name)
474          if entry is None or entry.provider is None:
475              return False
476  
477          key = failed_access_token or "<unknown>"
478          loop = asyncio.get_running_loop()
479  
480          async with entry.lock:
481              pending = entry.pending_401.get(key)
482              if pending is None:
483                  pending = loop.create_future()
484                  entry.pending_401[key] = pending
485  
486                  async def _do_handle() -> None:
487                      try:
488                          # Step 1: Did disk change? Picks up external refresh.
489                          disk_changed = await self.invalidate_if_disk_changed(
490                              server_name
491                          )
492                          if disk_changed:
493                              if not pending.done():
494                                  pending.set_result(True)
495                              return
496  
497                          # Step 2: No disk change — if the SDK can refresh
498                          # in-place, let the caller retry. The SDK's httpx.Auth
499                          # flow will issue the refresh on the next request.
500                          provider = entry.provider
501                          ctx = getattr(provider, "context", None)
502                          can_refresh = False
503                          if ctx is not None:
504                              can_refresh_fn = getattr(ctx, "can_refresh_token", None)
505                              if callable(can_refresh_fn):
506                                  try:
507                                      can_refresh = bool(can_refresh_fn())
508                                  except Exception:
509                                      can_refresh = False
510                          if not pending.done():
511                              pending.set_result(can_refresh)
512                      except Exception as exc:  # pragma: no cover — defensive
513                          logger.warning(
514                              "MCP OAuth '%s': 401 handler failed: %s",
515                              server_name, exc,
516                          )
517                          if not pending.done():
518                              pending.set_result(False)
519                      finally:
520                          entry.pending_401.pop(key, None)
521  
522                  asyncio.create_task(_do_handle())
523  
524          try:
525              return await pending
526          except Exception as exc:  # pragma: no cover — defensive
527              logger.warning(
528                  "MCP OAuth '%s': awaiting 401 handler failed: %s",
529                  server_name, exc,
530              )
531              return False
532  
533  
534  # ---------------------------------------------------------------------------
535  # Module-level singleton
536  # ---------------------------------------------------------------------------
537  
538  
# Process-wide singleton, created lazily by :func:`get_manager`.
_MANAGER: Optional[MCPOAuthManager] = None
# Guards singleton creation (get_manager) and reset (reset_manager_for_tests).
_MANAGER_LOCK = threading.Lock()
541  
542  
def get_manager() -> MCPOAuthManager:
    """Return the process-wide :class:`MCPOAuthManager` singleton.

    The manager is constructed lazily on first call; every later call
    hands back that same instance. Creation is serialised under
    ``_MANAGER_LOCK`` so concurrent first calls cannot race.
    """
    global _MANAGER
    with _MANAGER_LOCK:
        manager = _MANAGER
        if manager is None:
            manager = MCPOAuthManager()
            _MANAGER = manager
    return manager
550  
551  
def reset_manager_for_tests() -> None:
    """Discard the process-wide singleton so test fixtures start clean.

    Test-only helper; the next :func:`get_manager` call builds a fresh
    instance.
    """
    global _MANAGER
    with _MANAGER_LOCK:
        _MANAGER = None