// utils/model/modelStrings.ts
  1  import {
  2    getModelStrings as getModelStringsState,
  3    setModelStrings as setModelStringsState,
  4  } from 'src/bootstrap/state.js'
  5  import { logError } from '../log.js'
  6  import { sequential } from '../sequential.js'
  7  import { getInitialSettings } from '../settings/settings.js'
  8  import { findFirstMatch, getBedrockInferenceProfiles } from './bedrock.js'
  9  import {
 10    ALL_MODEL_CONFIGS,
 11    CANONICAL_ID_TO_KEY,
 12    type CanonicalModelId,
 13    type ModelKey,
 14  } from './configs.js'
 15  import { type APIProvider, getAPIProvider } from './providers.js'
 16  
 17  /**
 18   * Maps each model version to its provider-specific model ID string.
 19   * Derived from ALL_MODEL_CONFIGS — adding a model there extends this type.
 20   */
 21  export type ModelStrings = Record<ModelKey, string>
 22  
 23  const MODEL_KEYS = Object.keys(ALL_MODEL_CONFIGS) as ModelKey[]
 24  
 25  function getBuiltinModelStrings(provider: APIProvider): ModelStrings {
 26    const out = {} as ModelStrings
 27    for (const key of MODEL_KEYS) {
 28      out[key] = ALL_MODEL_CONFIGS[key][provider]
 29    }
 30    return out
 31  }
 32  
 33  async function getBedrockModelStrings(): Promise<ModelStrings> {
 34    const fallback = getBuiltinModelStrings('bedrock')
 35    let profiles: string[] | undefined
 36    try {
 37      profiles = await getBedrockInferenceProfiles()
 38    } catch (error) {
 39      logError(error as Error)
 40      return fallback
 41    }
 42    if (!profiles?.length) {
 43      return fallback
 44    }
 45    // Each config's firstParty ID is the canonical substring we search for in the
 46    // user's inference profile list (e.g. "claude-opus-4-6" matches
 47    // "eu.anthropic.claude-opus-4-6-v1"). Fall back to the hardcoded bedrock ID
 48    // when no matching profile is found.
 49    const out = {} as ModelStrings
 50    for (const key of MODEL_KEYS) {
 51      const needle = ALL_MODEL_CONFIGS[key].firstParty
 52      out[key] = findFirstMatch(profiles, needle) || fallback[key]
 53    }
 54    return out
 55  }
 56  
 57  /**
 58   * Layer user-configured modelOverrides (from settings.json) on top of the
 59   * provider-derived model strings. Overrides are keyed by canonical first-party
 60   * model ID (e.g. "claude-opus-4-6") and map to arbitrary provider-specific
 61   * strings — typically Bedrock inference profile ARNs.
 62   */
 63  function applyModelOverrides(ms: ModelStrings): ModelStrings {
 64    const overrides = getInitialSettings().modelOverrides
 65    if (!overrides) {
 66      return ms
 67    }
 68    const out = { ...ms }
 69    for (const [canonicalId, override] of Object.entries(overrides)) {
 70      const key = CANONICAL_ID_TO_KEY[canonicalId as CanonicalModelId]
 71      if (key && override) {
 72        out[key] = override
 73      }
 74    }
 75    return out
 76  }
 77  
 78  /**
 79   * Resolve an overridden model ID (e.g. a Bedrock ARN) back to its canonical
 80   * first-party model ID. If the input doesn't match any current override value,
 81   * it is returned unchanged. Safe to call during module init (no-ops if settings
 82   * aren't loaded yet).
 83   */
 84  export function resolveOverriddenModel(modelId: string): string {
 85    let overrides: Record<string, string> | undefined
 86    try {
 87      overrides = getInitialSettings().modelOverrides
 88    } catch {
 89      return modelId
 90    }
 91    if (!overrides) {
 92      return modelId
 93    }
 94    for (const [canonicalId, override] of Object.entries(overrides)) {
 95      if (override === modelId) {
 96        return canonicalId
 97      }
 98    }
 99    return modelId
100  }
101  
/**
 * Fetch the Bedrock model strings and store them in the shared state, unless
 * the state has already been populated. Failures are logged, not rethrown.
 */
const updateBedrockModelStrings = sequential(async () => {
  // The "already initialized" check is made inside the `sequential` callback
  // on purpose: together with `sequential` it lets the test suite reset the
  // state between tests while still preventing multiple API calls in
  // production.
  if (getModelStringsState() === null) {
    try {
      setModelStringsState(await getBedrockModelStrings())
    } catch (err) {
      logError(err as Error)
    }
  }
})
117  
/**
 * Kick off model-string initialization if it has not happened yet.
 *
 * Non-Bedrock providers get their built-in strings stored immediately.
 * Bedrock only starts the background profile fetch and returns without
 * touching the state.
 */
function initModelStrings(): void {
  if (getModelStringsState() !== null) {
    // Nothing to do: state is already populated.
    return
  }
  if (getAPIProvider() === 'bedrock') {
    // Fire-and-forget. The state is deliberately left unset here so that
    // `updateBedrockModelStrings` (wrapped in `sequential`) can check for
    // existing state itself across repeated calls.
    void updateBedrockModelStrings()
  } else {
    setModelStringsState(getBuiltinModelStrings(getAPIProvider()))
  }
}
135  
/**
 * Return the current model strings with user overrides applied.
 *
 * When the state has not been populated yet, initialization is triggered and
 * the built-in strings for the current provider are used for this call.
 *
 * @returns The model-string table after applying modelOverrides.
 */
export function getModelStrings(): ModelStrings {
  const stored = getModelStringsState()
  if (stored !== null) {
    return applyModelOverrides(stored)
  }
  initModelStrings()
  // On Bedrock the profile fetch is still running in the background at this
  // point, so serve the interim defaults and still honor overrides on them.
  return applyModelOverrides(getBuiltinModelStrings(getAPIProvider()))
}
146  
/**
 * Make sure the model strings are fully initialized before continuing.
 *
 * Bedrock users wait here for the inference-profile fetch to finish; every
 * other provider is initialized synchronously from the built-in strings.
 * Call this before generating model options so the region-specific strings
 * are correct.
 *
 * @returns A promise that settles once initialization has been attempted.
 */
export async function ensureModelStringsInitialized(): Promise<void> {
  if (getModelStringsState() !== null) {
    return
  }

  if (getAPIProvider() === 'bedrock') {
    // Bedrock: block until the profile fetch has completed.
    await updateBedrockModelStrings()
    return
  }

  // Any other provider: the built-in strings are final, store them now.
  setModelStringsState(getBuiltinModelStrings(getAPIProvider()))
}