update_model_catalog.py
1 """Update the MLflow model catalog from upstream data sources. 2 3 Usage: 4 uv run python dev/update_model_catalog.py [--output-dir PATH] 5 6 Fetches the LiteLLM model_prices_and_context_window.json from GitHub, 7 transforms it into the MLflow-native schema, and merges the results into the 8 per-provider catalog files in the output directory (default: mlflow/utils/model_catalog/). 9 10 Models present in the upstream source always take precedence over existing entries 11 (upstream wins). Models not present in the upstream source are preserved, allowing 12 community additions to coexist with automated upstream syncs. 13 Models with a deprecation_date in the past are dropped during conversion. 14 """ 15 16 import argparse 17 import json 18 import re 19 import urllib.request 20 from datetime import date, datetime 21 from pathlib import Path 22 from typing import Any 23 24 SCHEMA_VERSION = "1.0" 25 26 # Modes that MLflow cares about for gateway / cost tracking 27 _SUPPORTED_MODES = {"chat", "completion", "embedding"} 28 29 # Providers that should be consolidated into a canonical name 30 _PROVIDER_CONSOLIDATION = { 31 "vertex_ai-anthropic": "vertex_ai", 32 "vertex_ai-llama_models": "vertex_ai", 33 "vertex_ai-mistral": "vertex_ai", 34 "vertex_ai-chat-models": "vertex_ai", 35 "vertex_ai-text-models": "vertex_ai", 36 "vertex_ai-code-chat-models": "vertex_ai", 37 "vertex_ai-code-text-models": "vertex_ai", 38 "vertex_ai-embedding-models": "vertex_ai", 39 "vertex_ai-vision-models": "vertex_ai", 40 "bedrock_converse": "bedrock", 41 } 42 43 44 def _normalize_provider(provider: str) -> str: 45 if provider in _PROVIDER_CONSOLIDATION: 46 return _PROVIDER_CONSOLIDATION[provider] 47 if provider.startswith("vertex_ai-"): 48 return "vertex_ai" 49 return provider 50 51 52 _PER_MILLION = 1_000_000 53 _PER_THOUSAND = 1_000 54 55 56 def _to_per_million(cost_per_token: float) -> float: 57 return round(cost_per_token * _PER_MILLION, 10) 58 59 60 def _extract_base_pricing(info: dict[str, Any]) -> dict[str, Any]: 61 """Extract base pricing fields from a LiteLLM entry (converted to per-million-tokens).""" 62 pricing = {} 63 if (v := info.get("input_cost_per_token")) is not None: 64 pricing["input_per_million_tokens"] = _to_per_million(v) 65 if (v := info.get("output_cost_per_token")) is not None: 66 pricing["output_per_million_tokens"] = _to_per_million(v) 67 if (v := info.get("cache_read_input_token_cost")) is not None: 68 pricing["cache_read_per_million_tokens"] = _to_per_million(v) 69 if (v := info.get("cache_creation_input_token_cost")) is not None: 70 pricing["cache_write_per_million_tokens"] = _to_per_million(v) 71 return pricing 72 73 74 _MODALITY_INPUT = re.compile(r"^input_cost_per_([a-z0-9_]+)_token$") 75 _MODALITY_OUTPUT = re.compile(r"^output_cost_per_([a-z0-9_]+)_token$") 76 _MODALITY_CACHE_READ = re.compile(r"^cache_read_input_([a-z0-9_]+)_token_cost$") 77 _MODALITY_CACHE_WRITE = re.compile(r"^cache_creation_input_([a-z0-9_]+)_token_cost$") 78 _MODALITY_CACHE_READ_ALT = re.compile(r"^cache_read_input_token_cost_per_([a-z0-9_]+)_token$") 79 _EXCLUDED_MODALITIES = {"reasoning"} 80 81 82 def _extract_modality_pricing(info: dict[str, Any]) -> dict[str, dict[str, float]]: 83 """Extract modality-specific pricing (audio/image/etc) as per-million-token rates.""" 84 modalities: dict[str, dict[str, float]] = {} 85 for k, v in info.items(): 86 if m := _MODALITY_INPUT.match(k): 87 modality = m.group(1) 88 if modality in _EXCLUDED_MODALITIES: 89 continue 90 modalities.setdefault(modality, 
{})["input_per_million_tokens"] = _to_per_million(v) 91 elif m := _MODALITY_OUTPUT.match(k): 92 modality = m.group(1) 93 if modality in _EXCLUDED_MODALITIES: 94 continue 95 modalities.setdefault(modality, {})["output_per_million_tokens"] = _to_per_million(v) 96 elif m := _MODALITY_CACHE_READ.match(k): 97 modality = m.group(1) 98 if modality in _EXCLUDED_MODALITIES: 99 continue 100 modality_entry = modalities.setdefault(modality, {}) 101 modality_entry["cache_read_per_million_tokens"] = _to_per_million(v) 102 elif m := _MODALITY_CACHE_WRITE.match(k): 103 modality = m.group(1) 104 if modality in _EXCLUDED_MODALITIES: 105 continue 106 modality_entry = modalities.setdefault(modality, {}) 107 modality_entry["cache_write_per_million_tokens"] = _to_per_million(v) 108 elif m := _MODALITY_CACHE_READ_ALT.match(k): 109 modality = m.group(1) 110 if modality in _EXCLUDED_MODALITIES: 111 continue 112 modality_entry = modalities.setdefault(modality, {}) 113 modality_entry["cache_read_per_million_tokens"] = _to_per_million(v) 114 115 return modalities 116 117 118 def _extract_tool_pricing(info: dict[str, Any]) -> dict[str, Any]: 119 """Extract tool-related pricing and tool-use token overhead fields.""" 120 tool_pricing: dict[str, Any] = {} 121 122 if (v := info.get("computer_use_input_cost_per_1k_tokens")) is not None: 123 tool_pricing.setdefault("computer_use", {})["input_per_million_tokens"] = round( 124 v * _PER_THOUSAND, 10 125 ) 126 if (v := info.get("computer_use_output_cost_per_1k_tokens")) is not None: 127 tool_pricing.setdefault("computer_use", {})["output_per_million_tokens"] = round( 128 v * _PER_THOUSAND, 10 129 ) 130 if (v := info.get("search_context_cost_per_query")) is not None: 131 tool_pricing["search_context_per_query"] = v 132 if (v := info.get("tool_use_system_prompt_tokens")) is not None: 133 tool_pricing["tool_use_system_prompt_tokens"] = v 134 135 return tool_pricing 136 137 138 # LiteLLM uses suffixes like _batches, _batch_requests, _flex, _priority 139 _TIER_PATTERNS = { 140 "batch": re.compile(r"^(input|output)_cost_per_token_(batches|batch_requests)$"), 141 "flex": re.compile(r"^(input|output)_cost_per_token_flex$"), 142 "priority": re.compile(r"^(input|output)_cost_per_token_priority$"), 143 } 144 145 _TIER_CACHE_PATTERNS = { 146 "batch": re.compile(r"^cache_read_input_token_cost_(batches|batch_requests)$"), 147 "flex": re.compile(r"^cache_read_input_token_cost_flex$"), 148 "priority": re.compile(r"^cache_read_input_token_cost_priority$"), 149 } 150 151 152 def _extract_service_tiers(info: dict[str, Any]) -> dict[str, dict[str, float]]: 153 """Extract service tier pricing overrides (batch, flex, priority).""" 154 tiers: dict[str, dict[str, float]] = {} 155 for tier_name, pattern in _TIER_PATTERNS.items(): 156 for k, v in info.items(): 157 if m := pattern.match(k): 158 direction = m.group(1) # "input" or "output" 159 tiers.setdefault(tier_name, {})[f"{direction}_per_million_tokens"] = ( 160 _to_per_million(v) 161 ) 162 163 for tier_name, pattern in _TIER_CACHE_PATTERNS.items(): 164 for k, v in info.items(): 165 if pattern.match(k): 166 tiers.setdefault(tier_name, {})["cache_read_per_million_tokens"] = _to_per_million( 167 v 168 ) 169 170 return tiers 171 172 173 # Matches keys like input_cost_per_token_above_200k_tokens or 174 # cache_read_input_token_cost_above_128k_tokens 175 _LONG_CTX_INPUT = re.compile(r"^input_cost_per_token_above_(\d+[km]?)_tokens$") 176 _LONG_CTX_OUTPUT = re.compile(r"^output_cost_per_token_above_(\d+[km]?)_tokens$") 177 _LONG_CTX_CACHE_READ = 
re.compile(r"^cache_read_input_token_cost_above_(\d+[km]?)_tokens$") 178 _LONG_CTX_CACHE_WRITE = re.compile(r"^cache_creation_input_token_cost_above_(\d+[km]?)_tokens$") 179 180 181 def _parse_threshold(s: str) -> int: 182 """Convert threshold string like '200k', '128k', or '1m' to token count.""" 183 s = s.lower() 184 if s.endswith("m"): 185 return int(s[:-1]) * 1_000_000 186 if s.endswith("k"): 187 return int(s[:-1]) * 1_000 188 return int(s) 189 190 191 def _extract_long_context_pricing(info: dict[str, Any]) -> list[dict[str, Any]]: 192 """Extract long-context pricing tiers as a list of threshold overrides.""" 193 # Group by threshold 194 thresholds: dict[int, dict[str, Any]] = {} 195 for k, v in info.items(): 196 if m := _LONG_CTX_INPUT.match(k): 197 t = _parse_threshold(m.group(1)) 198 thresholds.setdefault(t, {"threshold_tokens": t})["input_per_million_tokens"] = ( 199 _to_per_million(v) 200 ) 201 elif m := _LONG_CTX_OUTPUT.match(k): 202 t = _parse_threshold(m.group(1)) 203 thresholds.setdefault(t, {"threshold_tokens": t})["output_per_million_tokens"] = ( 204 _to_per_million(v) 205 ) 206 elif m := _LONG_CTX_CACHE_READ.match(k): 207 t = _parse_threshold(m.group(1)) 208 thresholds.setdefault(t, {"threshold_tokens": t})["cache_read_per_million_tokens"] = ( 209 _to_per_million(v) 210 ) 211 elif m := _LONG_CTX_CACHE_WRITE.match(k): 212 t = _parse_threshold(m.group(1)) 213 thresholds.setdefault(t, {"threshold_tokens": t})["cache_write_per_million_tokens"] = ( 214 _to_per_million(v) 215 ) 216 217 return sorted(thresholds.values(), key=lambda x: x["threshold_tokens"]) 218 219 220 def _is_deprecated(info: dict[str, Any]) -> bool: 221 """Return True if the model's deprecation_date is in the past.""" 222 dep = info.get("deprecation_date") 223 if not dep: 224 return False 225 try: 226 return datetime.strptime(dep, "%Y-%m-%d").date() < date.today() 227 except ValueError: 228 return False 229 230 231 def _transform_entry(info: dict[str, Any]) -> dict[str, Any] | None: 232 """Transform a single LiteLLM model entry into MLflow-native schema.""" 233 mode = info.get("mode") 234 if mode not in _SUPPORTED_MODES: 235 return None 236 237 if _is_deprecated(info): 238 return None 239 240 pricing = _extract_base_pricing(info) 241 242 if service_tiers := _extract_service_tiers(info): 243 pricing["service_tiers"] = service_tiers 244 245 if long_context := _extract_long_context_pricing(info): 246 pricing["long_context"] = long_context 247 248 if modality_pricing := _extract_modality_pricing(info): 249 pricing["modality"] = modality_pricing 250 251 if tool_pricing := _extract_tool_pricing(info): 252 pricing["tooling"] = tool_pricing 253 254 capabilities = { 255 "function_calling": info.get("supports_function_calling", False), 256 "vision": info.get("supports_vision", False), 257 "reasoning": info.get("supports_reasoning", False), 258 "prompt_caching": info.get("supports_prompt_caching", False), 259 "response_schema": info.get("supports_response_schema", False), 260 } 261 262 context_window = {} 263 if (v := info.get("max_input_tokens")) is not None: 264 context_window["max_input"] = v 265 if (v := info.get("max_output_tokens")) is not None: 266 context_window["max_output"] = v 267 if (v := info.get("max_tokens")) is not None: 268 context_window["max_tokens"] = v 269 270 entry = {"mode": mode} 271 if context_window: 272 entry["context_window"] = context_window 273 if pricing: 274 entry["pricing"] = pricing 275 entry["capabilities"] = capabilities 276 if dep := info.get("deprecation_date"): 277 
entry["deprecation_date"] = dep 278 279 return entry 280 281 282 _LEGACY_PRICING_KEY_MAP = { 283 "input_per_token": "input_per_million_tokens", 284 "output_per_token": "output_per_million_tokens", 285 "cache_read_per_token": "cache_read_per_million_tokens", 286 "cache_write_per_token": "cache_write_per_million_tokens", 287 } 288 289 290 def _migrate_pricing_block(pricing: dict[str, Any]) -> dict[str, Any]: 291 """Convert legacy *_per_token keys to *_per_million_tokens in a flat pricing block.""" 292 result = {} 293 for k, v in pricing.items(): 294 if k in _LEGACY_PRICING_KEY_MAP: 295 result[_LEGACY_PRICING_KEY_MAP[k]] = _to_per_million(v) 296 else: 297 result[k] = v 298 return result 299 300 301 def _migrate_legacy_pricing(entry: dict[str, Any]) -> dict[str, Any]: 302 """Migrate legacy *_per_token pricing keys to *_per_million_tokens in a catalog entry. 303 304 Applies the migration at the top level of the pricing block and recursively within 305 service_tiers, long_context, and modality sub-sections. 306 """ 307 if "pricing" not in entry: 308 return entry 309 310 entry = {**entry} 311 pricing = _migrate_pricing_block(entry["pricing"]) 312 313 if "service_tiers" in pricing: 314 pricing["service_tiers"] = { 315 tier: _migrate_pricing_block(tier_data) 316 for tier, tier_data in pricing["service_tiers"].items() 317 } 318 319 if "long_context" in pricing: 320 pricing["long_context"] = [_migrate_pricing_block(ctx) for ctx in pricing["long_context"]] 321 322 if "modality" in pricing: 323 pricing["modality"] = { 324 mod: _migrate_pricing_block(mod_data) for mod, mod_data in pricing["modality"].items() 325 } 326 327 entry["pricing"] = pricing 328 return entry 329 330 331 _LITELLM_URL = ( 332 "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json" 333 ) 334 335 336 def _fetch_litellm_catalog() -> dict[str, Any]: 337 """Download the latest LiteLLM model catalog from GitHub.""" 338 print(f"Fetching {_LITELLM_URL} ...") 339 with urllib.request.urlopen(_LITELLM_URL, timeout=30) as resp: 340 data: dict[str, Any] = json.loads(resp.read().decode("utf-8")) 341 return data 342 343 344 def convert(raw: dict[str, Any], output_dir: Path) -> dict[str, int]: 345 """Convert upstream catalog dict to per-provider MLflow catalog files. 346 347 Returns a dict mapping provider names to model counts. 
348 """ 349 350 today = date.today().isoformat() 351 352 # Load existing catalog files first so we can detect which entries have changed 353 output_dir.mkdir(parents=True, exist_ok=True) 354 existing_catalogs: dict[str, dict[str, Any]] = {} 355 for provider_file in output_dir.glob("*.json"): 356 try: 357 existing = json.loads(provider_file.read_text(encoding="utf-8")) 358 existing_catalogs[provider_file.stem] = existing.get("models", {}) 359 except (json.JSONDecodeError, OSError) as e: 360 print(f" Warning: could not read existing {provider_file.name}: {e}") 361 362 # Group by provider 363 providers: dict[str, dict[str, dict[str, Any]]] = {} 364 seen: set[tuple[str, str]] = set() 365 366 for key, info in raw.items(): 367 if key == "sample_spec": 368 continue 369 370 provider = info.get("litellm_provider") 371 if not provider: 372 continue 373 provider = _normalize_provider(provider) 374 model_name = key.split("/", 1)[-1] 375 376 # Skip fine-tuned variants 377 if model_name.startswith("ft:"): 378 continue 379 380 # Dedupe by (provider, model_name) 381 dedup_key = (provider, model_name) 382 if dedup_key in seen: 383 continue 384 seen.add(dedup_key) 385 386 entry = _transform_entry(info) 387 if entry is None: 388 continue 389 390 # Determine last_updated_at: carry over existing date if entry is unchanged; 391 # set today if no existing date (first-time backfill or new entry) 392 existing_entry = existing_catalogs.get(provider, {}).get(model_name) 393 if existing_entry is not None: 394 existing_without_last_updated_at = { 395 k: v for k, v in existing_entry.items() if k != "last_updated_at" 396 } 397 if entry == existing_without_last_updated_at: 398 # Entry is unchanged; preserve existing last_updated_at or set today if absent 399 entry["last_updated_at"] = existing_entry.get("last_updated_at") or today 400 else: 401 entry["last_updated_at"] = today 402 else: 403 entry["last_updated_at"] = today 404 405 providers.setdefault(provider, {})[model_name] = entry 406 407 # Merge with existing catalogs: preserve models not in upstream (community additions) 408 for provider, existing_models in existing_catalogs.items(): 409 if provider not in providers: 410 providers[provider] = {} 411 for model_name, entry in existing_models.items(): 412 if model_name not in providers.get(provider, {}): 413 migrated = _migrate_legacy_pricing(entry) 414 if "last_updated_at" not in migrated: 415 migrated = {**migrated, "last_updated_at": today} 416 providers.setdefault(provider, {})[model_name] = migrated 417 418 stats = {} 419 for provider, models in sorted(providers.items()): 420 if not models: 421 continue 422 catalog = { 423 "schema_version": SCHEMA_VERSION, 424 "models": dict(sorted(models.items())), 425 } 426 out_path = output_dir / f"{provider}.json" 427 out_path.write_text(json.dumps(catalog, indent=2) + "\n", encoding="utf-8") 428 stats[provider] = len(models) 429 430 return stats 431 432 433 def main() -> None: 434 parser = argparse.ArgumentParser(description=__doc__) 435 parser.add_argument( 436 "--output-dir", 437 type=Path, 438 default=Path("mlflow/utils/model_catalog"), 439 help="Output directory for per-provider JSON files", 440 ) 441 args = parser.parse_args() 442 443 raw = _fetch_litellm_catalog() 444 stats = convert(raw, args.output_dir) 445 total = sum(stats.values()) 446 print(f"Converted {total} models across {len(stats)} providers:") 447 for provider, count in sorted(stats.items(), key=lambda x: -x[1]): 448 print(f" {provider}: {count}") 449 450 451 if __name__ == "__main__": 452 main()