/ bridge / jwtUtils.ts
jwtUtils.ts
  1  import { logEvent } from '../services/analytics/index.js'
  2  import { logForDebugging } from '../utils/debug.js'
  3  import { logForDiagnosticsNoPII } from '../utils/diagLogs.js'
  4  import { errorMessage } from '../utils/errors.js'
  5  import { jsonParse } from '../utils/slowOperations.js'
  6  
  7  /** Format a millisecond duration as a human-readable string (e.g. "5m 30s"). */
  8  function formatDuration(ms: number): string {
  9    if (ms < 60_000) return `${Math.round(ms / 1000)}s`
 10    const m = Math.floor(ms / 60_000)
 11    const s = Math.round((ms % 60_000) / 1000)
 12    return s > 0 ? `${m}m ${s}s` : `${m}m`
 13  }
 14  
 15  /**
 16   * Decode a JWT's payload segment without verifying the signature.
 17   * Strips the `sk-ant-si-` session-ingress prefix if present.
 18   * Returns the parsed JSON payload as `unknown`, or `null` if the
 19   * token is malformed or the payload is not valid JSON.
 20   */
 21  export function decodeJwtPayload(token: string): unknown | null {
 22    const jwt = token.startsWith('sk-ant-si-')
 23      ? token.slice('sk-ant-si-'.length)
 24      : token
 25    const parts = jwt.split('.')
 26    if (parts.length !== 3 || !parts[1]) return null
 27    try {
 28      return jsonParse(Buffer.from(parts[1], 'base64url').toString('utf8'))
 29    } catch {
 30      return null
 31    }
 32  }
 33  
 34  /**
 35   * Decode the `exp` (expiry) claim from a JWT without verifying the signature.
 36   * @returns The `exp` value in Unix seconds, or `null` if unparseable
 37   */
 38  export function decodeJwtExpiry(token: string): number | null {
 39    const payload = decodeJwtPayload(token)
 40    if (
 41      payload !== null &&
 42      typeof payload === 'object' &&
 43      'exp' in payload &&
 44      typeof payload.exp === 'number'
 45    ) {
 46      return payload.exp
 47    }
 48    return null
 49  }
 50  
 51  /** Refresh buffer: request a new token before expiry. */
 52  const TOKEN_REFRESH_BUFFER_MS = 5 * 60 * 1000
 53  
 54  /** Fallback refresh interval when the new token's expiry is unknown. */
 55  const FALLBACK_REFRESH_INTERVAL_MS = 30 * 60 * 1000 // 30 minutes
 56  
 57  /** Max consecutive failures before giving up on the refresh chain. */
 58  const MAX_REFRESH_FAILURES = 3
 59  
 60  /** Retry delay when getAccessToken returns undefined. */
 61  const REFRESH_RETRY_DELAY_MS = 60_000
 62  
 63  /**
 64   * Creates a token refresh scheduler that proactively refreshes session tokens
 65   * before they expire. Used by both the standalone bridge and the REPL bridge.
 66   *
 67   * When a token is about to expire, the scheduler calls `onRefresh` with the
 68   * session ID and the bridge's OAuth access token. The caller is responsible
 69   * for delivering the token to the appropriate transport (child process stdin
 70   * for standalone bridge, WebSocket reconnect for REPL bridge).
 71   */
 72  export function createTokenRefreshScheduler({
 73    getAccessToken,
 74    onRefresh,
 75    label,
 76    refreshBufferMs = TOKEN_REFRESH_BUFFER_MS,
 77  }: {
 78    getAccessToken: () => string | undefined | Promise<string | undefined>
 79    onRefresh: (sessionId: string, oauthToken: string) => void
 80    label: string
 81    /** How long before expiry to fire refresh. Defaults to 5 min. */
 82    refreshBufferMs?: number
 83  }): {
 84    schedule: (sessionId: string, token: string) => void
 85    scheduleFromExpiresIn: (sessionId: string, expiresInSeconds: number) => void
 86    cancel: (sessionId: string) => void
 87    cancelAll: () => void
 88  } {
 89    const timers = new Map<string, ReturnType<typeof setTimeout>>()
 90    const failureCounts = new Map<string, number>()
 91    // Generation counter per session — incremented by schedule() and cancel()
 92    // so that in-flight async doRefresh() calls can detect when they've been
 93    // superseded and should skip setting follow-up timers.
 94    const generations = new Map<string, number>()
 95  
 96    function nextGeneration(sessionId: string): number {
 97      const gen = (generations.get(sessionId) ?? 0) + 1
 98      generations.set(sessionId, gen)
 99      return gen
100    }
101  
102    function schedule(sessionId: string, token: string): void {
103      const expiry = decodeJwtExpiry(token)
104      if (!expiry) {
105        // Token is not a decodable JWT (e.g. an OAuth token passed from the
106        // REPL bridge WebSocket open handler).  Preserve any existing timer
107        // (such as the follow-up refresh set by doRefresh) so the refresh
108        // chain is not broken.
109        logForDebugging(
110          `[${label}:token] Could not decode JWT expiry for sessionId=${sessionId}, token prefix=${token.slice(0, 15)}…, keeping existing timer`,
111        )
112        return
113      }
114  
115      // Clear any existing refresh timer — we have a concrete expiry to replace it.
116      const existing = timers.get(sessionId)
117      if (existing) {
118        clearTimeout(existing)
119      }
120  
121      // Bump generation to invalidate any in-flight async doRefresh.
122      const gen = nextGeneration(sessionId)
123  
124      const expiryDate = new Date(expiry * 1000).toISOString()
125      const delayMs = expiry * 1000 - Date.now() - refreshBufferMs
126      if (delayMs <= 0) {
127        logForDebugging(
128          `[${label}:token] Token for sessionId=${sessionId} expires=${expiryDate} (past or within buffer), refreshing immediately`,
129        )
130        void doRefresh(sessionId, gen)
131        return
132      }
133  
134      logForDebugging(
135        `[${label}:token] Scheduled token refresh for sessionId=${sessionId} in ${formatDuration(delayMs)} (expires=${expiryDate}, buffer=${refreshBufferMs / 1000}s)`,
136      )
137  
138      const timer = setTimeout(doRefresh, delayMs, sessionId, gen)
139      timers.set(sessionId, timer)
140    }
141  
142    /**
143     * Schedule refresh using an explicit TTL (seconds until expiry) rather
144     * than decoding a JWT's exp claim. Used by callers whose JWT is opaque
145     * (e.g. POST /v1/code/sessions/{id}/bridge returns expires_in directly).
146     */
147    function scheduleFromExpiresIn(
148      sessionId: string,
149      expiresInSeconds: number,
150    ): void {
151      const existing = timers.get(sessionId)
152      if (existing) clearTimeout(existing)
153      const gen = nextGeneration(sessionId)
154      // Clamp to 30s floor — if refreshBufferMs exceeds the server's expires_in
155      // (e.g. very large buffer for frequent-refresh testing, or server shortens
156      // expires_in unexpectedly), unclamped delayMs ≤ 0 would tight-loop.
157      const delayMs = Math.max(expiresInSeconds * 1000 - refreshBufferMs, 30_000)
158      logForDebugging(
159        `[${label}:token] Scheduled token refresh for sessionId=${sessionId} in ${formatDuration(delayMs)} (expires_in=${expiresInSeconds}s, buffer=${refreshBufferMs / 1000}s)`,
160      )
161      const timer = setTimeout(doRefresh, delayMs, sessionId, gen)
162      timers.set(sessionId, timer)
163    }
164  
165    async function doRefresh(sessionId: string, gen: number): Promise<void> {
166      let oauthToken: string | undefined
167      try {
168        oauthToken = await getAccessToken()
169      } catch (err) {
170        logForDebugging(
171          `[${label}:token] getAccessToken threw for sessionId=${sessionId}: ${errorMessage(err)}`,
172          { level: 'error' },
173        )
174      }
175  
176      // If the session was cancelled or rescheduled while we were awaiting,
177      // the generation will have changed — bail out to avoid orphaned timers.
178      if (generations.get(sessionId) !== gen) {
179        logForDebugging(
180          `[${label}:token] doRefresh for sessionId=${sessionId} stale (gen ${gen} vs ${generations.get(sessionId)}), skipping`,
181        )
182        return
183      }
184  
185      if (!oauthToken) {
186        const failures = (failureCounts.get(sessionId) ?? 0) + 1
187        failureCounts.set(sessionId, failures)
188        logForDebugging(
189          `[${label}:token] No OAuth token available for refresh, sessionId=${sessionId} (failure ${failures}/${MAX_REFRESH_FAILURES})`,
190          { level: 'error' },
191        )
192        logForDiagnosticsNoPII('error', 'bridge_token_refresh_no_oauth')
193        // Schedule a retry so the refresh chain can recover if the token
194        // becomes available again (e.g. transient cache clear during refresh).
195        // Cap retries to avoid spamming on genuine failures.
196        if (failures < MAX_REFRESH_FAILURES) {
197          const retryTimer = setTimeout(
198            doRefresh,
199            REFRESH_RETRY_DELAY_MS,
200            sessionId,
201            gen,
202          )
203          timers.set(sessionId, retryTimer)
204        }
205        return
206      }
207  
208      // Reset failure counter on successful token retrieval
209      failureCounts.delete(sessionId)
210  
211      logForDebugging(
212        `[${label}:token] Refreshing token for sessionId=${sessionId}: new token prefix=${oauthToken.slice(0, 15)}…`,
213      )
214      logEvent('tengu_bridge_token_refreshed', {})
215      onRefresh(sessionId, oauthToken)
216  
217      // Schedule a follow-up refresh so long-running sessions stay authenticated.
218      // Without this, the initial one-shot timer leaves the session vulnerable
219      // to token expiry if it runs past the first refresh window.
220      const timer = setTimeout(
221        doRefresh,
222        FALLBACK_REFRESH_INTERVAL_MS,
223        sessionId,
224        gen,
225      )
226      timers.set(sessionId, timer)
227      logForDebugging(
228        `[${label}:token] Scheduled follow-up refresh for sessionId=${sessionId} in ${formatDuration(FALLBACK_REFRESH_INTERVAL_MS)}`,
229      )
230    }
231  
232    function cancel(sessionId: string): void {
233      // Bump generation to invalidate any in-flight async doRefresh.
234      nextGeneration(sessionId)
235      const timer = timers.get(sessionId)
236      if (timer) {
237        clearTimeout(timer)
238        timers.delete(sessionId)
239      }
240      failureCounts.delete(sessionId)
241    }
242  
243    function cancelAll(): void {
244      // Bump all generations so in-flight doRefresh calls are invalidated.
245      for (const sessionId of generations.keys()) {
246        nextGeneration(sessionId)
247      }
248      for (const timer of timers.values()) {
249        clearTimeout(timer)
250      }
251      timers.clear()
252      failureCounts.clear()
253    }
254  
255    return { schedule, scheduleFromExpiresIn, cancel, cancelAll }
256  }