/ utils / model / modelCapabilities.ts
modelCapabilities.ts
  1  import { readFileSync } from 'fs'
  2  import { mkdir, writeFile } from 'fs/promises'
  3  import isEqual from 'lodash-es/isEqual.js'
  4  import memoize from 'lodash-es/memoize.js'
  5  import { join } from 'path'
  6  import { z } from 'zod/v4'
  7  import { OAUTH_BETA_HEADER } from '../../constants/oauth.js'
  8  import { getAnthropicClient } from '../../services/api/client.js'
  9  import { isClaudeAISubscriber } from '../auth.js'
 10  import { logForDebugging } from '../debug.js'
 11  import { getClaudeConfigHomeDir } from '../envUtils.js'
 12  import { safeParseJSON } from '../json.js'
 13  import { lazySchema } from '../lazySchema.js'
 14  import { isEssentialTrafficOnly } from '../privacyLevel.js'
 15  import { jsonStringify } from '../slowOperations.js'
 16  import { getAPIProvider, isFirstPartyAnthropicBaseUrl } from './providers.js'
 17  
 18  // .strip() — don't persist internal-only fields (mycro_deployments etc.) to disk
 19  const ModelCapabilitySchema = lazySchema(() =>
 20    z
 21      .object({
 22        id: z.string(),
 23        max_input_tokens: z.number().optional(),
 24        max_tokens: z.number().optional(),
 25      })
 26      .strip(),
 27  )
 28  
 29  const CacheFileSchema = lazySchema(() =>
 30    z.object({
 31      models: z.array(ModelCapabilitySchema()),
 32      timestamp: z.number(),
 33    }),
 34  )
 35  
 36  export type ModelCapability = z.infer<ReturnType<typeof ModelCapabilitySchema>>
 37  
 38  function getCacheDir(): string {
 39    return join(getClaudeConfigHomeDir(), 'cache')
 40  }
 41  
 42  function getCachePath(): string {
 43    return join(getCacheDir(), 'model-capabilities.json')
 44  }
 45  
 46  function isModelCapabilitiesEligible(): boolean {
 47    if (process.env.USER_TYPE !== 'ant') return false
 48    if (getAPIProvider() !== 'firstParty') return false
 49    if (!isFirstPartyAnthropicBaseUrl()) return false
 50    return true
 51  }
 52  
 53  // Longest-id-first so substring match prefers most specific; secondary key for stable isEqual
 54  function sortForMatching(models: ModelCapability[]): ModelCapability[] {
 55    return [...models].sort(
 56      (a, b) => b.id.length - a.id.length || a.id.localeCompare(b.id),
 57    )
 58  }
 59  
 60  // Keyed on cache path so tests that set CLAUDE_CONFIG_DIR get a fresh read
 61  const loadCache = memoize(
 62    (path: string): ModelCapability[] | null => {
 63      try {
 64        // eslint-disable-next-line custom-rules/no-sync-fs -- memoized; called from sync getContextWindowForModel
 65        const raw = readFileSync(path, 'utf-8')
 66        const parsed = CacheFileSchema().safeParse(safeParseJSON(raw, false))
 67        return parsed.success ? parsed.data.models : null
 68      } catch {
 69        return null
 70      }
 71    },
 72    path => path,
 73  )
 74  
 75  export function getModelCapability(model: string): ModelCapability | undefined {
 76    if (!isModelCapabilitiesEligible()) return undefined
 77    const cached = loadCache(getCachePath())
 78    if (!cached || cached.length === 0) return undefined
 79    const m = model.toLowerCase()
 80    const exact = cached.find(c => c.id.toLowerCase() === m)
 81    if (exact) return exact
 82    return cached.find(c => m.includes(c.id.toLowerCase()))
 83  }
 84  
 85  export async function refreshModelCapabilities(): Promise<void> {
 86    if (!isModelCapabilitiesEligible()) return
 87    if (isEssentialTrafficOnly()) return
 88  
 89    try {
 90      const anthropic = await getAnthropicClient({ maxRetries: 1 })
 91      const betas = isClaudeAISubscriber() ? [OAUTH_BETA_HEADER] : undefined
 92      const parsed: ModelCapability[] = []
 93      for await (const entry of anthropic.models.list({ betas })) {
 94        const result = ModelCapabilitySchema().safeParse(entry)
 95        if (result.success) parsed.push(result.data)
 96      }
 97      if (parsed.length === 0) return
 98  
 99      const path = getCachePath()
100      const models = sortForMatching(parsed)
101      if (isEqual(loadCache(path), models)) {
102        logForDebugging('[modelCapabilities] cache unchanged, skipping write')
103        return
104      }
105  
106      await mkdir(getCacheDir(), { recursive: true })
107      await writeFile(path, jsonStringify({ models, timestamp: Date.now() }), {
108        encoding: 'utf-8',
109        mode: 0o600,
110      })
111      loadCache.cache.delete(path)
112      logForDebugging(`[modelCapabilities] cached ${models.length} models`)
113    } catch (error) {
114      logForDebugging(
115        `[modelCapabilities] fetch failed: ${error instanceof Error ? error.message : 'unknown'}`,
116      )
117    }
118  }