# dev/update_model_catalog.py
  1  """Update the MLflow model catalog from upstream data sources.
  2  
  3  Usage:
  4      uv run python dev/update_model_catalog.py [--output-dir PATH]
  5  
  6  Fetches the LiteLLM model_prices_and_context_window.json from GitHub,
  7  transforms it into the MLflow-native schema, and merges the results into the
  8  per-provider catalog files in the output directory (default: mlflow/utils/model_catalog/).
  9  
 10  Models present in the upstream source always take precedence over existing entries
 11  (upstream wins). Models not present in the upstream source are preserved, allowing
 12  community additions to coexist with automated upstream syncs.
 13  Models with a deprecation_date in the past are dropped during conversion.
 14  """
 15  
 16  import argparse
 17  import json
 18  import re
 19  import urllib.request
 20  from datetime import date, datetime
 21  from pathlib import Path
 22  from typing import Any
 23  
# Schema version stamped into every generated per-provider catalog file.
SCHEMA_VERSION = "1.0"

# Modes that MLflow cares about for gateway / cost tracking; entries whose
# LiteLLM "mode" is anything else are dropped by _transform_entry.
_SUPPORTED_MODES = {"chat", "completion", "embedding"}

# Providers that should be consolidated into a canonical name.
# LiteLLM splits some providers into sub-families (vertex_ai-*) or API variants
# (bedrock_converse); the MLflow catalog keys everything by the canonical name.
_PROVIDER_CONSOLIDATION = {
    "vertex_ai-anthropic": "vertex_ai",
    "vertex_ai-llama_models": "vertex_ai",
    "vertex_ai-mistral": "vertex_ai",
    "vertex_ai-chat-models": "vertex_ai",
    "vertex_ai-text-models": "vertex_ai",
    "vertex_ai-code-chat-models": "vertex_ai",
    "vertex_ai-code-text-models": "vertex_ai",
    "vertex_ai-embedding-models": "vertex_ai",
    "vertex_ai-vision-models": "vertex_ai",
    "bedrock_converse": "bedrock",
}
 42  
 43  
 44  def _normalize_provider(provider: str) -> str:
 45      if provider in _PROVIDER_CONSOLIDATION:
 46          return _PROVIDER_CONSOLIDATION[provider]
 47      if provider.startswith("vertex_ai-"):
 48          return "vertex_ai"
 49      return provider
 50  
 51  
# Scale factors: upstream quotes prices per single token (or per 1k tokens for
# a few computer_use fields), while the MLflow schema uses per-million tokens.
_PER_MILLION = 1_000_000
_PER_THOUSAND = 1_000
 54  
 55  
 56  def _to_per_million(cost_per_token: float) -> float:
 57      return round(cost_per_token * _PER_MILLION, 10)
 58  
 59  
 60  def _extract_base_pricing(info: dict[str, Any]) -> dict[str, Any]:
 61      """Extract base pricing fields from a LiteLLM entry (converted to per-million-tokens)."""
 62      pricing = {}
 63      if (v := info.get("input_cost_per_token")) is not None:
 64          pricing["input_per_million_tokens"] = _to_per_million(v)
 65      if (v := info.get("output_cost_per_token")) is not None:
 66          pricing["output_per_million_tokens"] = _to_per_million(v)
 67      if (v := info.get("cache_read_input_token_cost")) is not None:
 68          pricing["cache_read_per_million_tokens"] = _to_per_million(v)
 69      if (v := info.get("cache_creation_input_token_cost")) is not None:
 70          pricing["cache_write_per_million_tokens"] = _to_per_million(v)
 71      return pricing
 72  
 73  
# Per-modality price keys, e.g. "input_cost_per_audio_token" or
# "cache_creation_input_audio_token_cost". Group 1 captures the modality name.
_MODALITY_INPUT = re.compile(r"^input_cost_per_([a-z0-9_]+)_token$")
_MODALITY_OUTPUT = re.compile(r"^output_cost_per_([a-z0-9_]+)_token$")
_MODALITY_CACHE_READ = re.compile(r"^cache_read_input_([a-z0-9_]+)_token_cost$")
_MODALITY_CACHE_WRITE = re.compile(r"^cache_creation_input_([a-z0-9_]+)_token_cost$")
# Alternate upstream spelling for per-modality cache-read prices.
_MODALITY_CACHE_READ_ALT = re.compile(r"^cache_read_input_token_cost_per_([a-z0-9_]+)_token$")
# Matching keys whose modality is listed here are skipped entirely
# (reasoning-token prices are not treated as a modality in this schema).
_EXCLUDED_MODALITIES = {"reasoning"}
 80  
 81  
 82  def _extract_modality_pricing(info: dict[str, Any]) -> dict[str, dict[str, float]]:
 83      """Extract modality-specific pricing (audio/image/etc) as per-million-token rates."""
 84      modalities: dict[str, dict[str, float]] = {}
 85      for k, v in info.items():
 86          if m := _MODALITY_INPUT.match(k):
 87              modality = m.group(1)
 88              if modality in _EXCLUDED_MODALITIES:
 89                  continue
 90              modalities.setdefault(modality, {})["input_per_million_tokens"] = _to_per_million(v)
 91          elif m := _MODALITY_OUTPUT.match(k):
 92              modality = m.group(1)
 93              if modality in _EXCLUDED_MODALITIES:
 94                  continue
 95              modalities.setdefault(modality, {})["output_per_million_tokens"] = _to_per_million(v)
 96          elif m := _MODALITY_CACHE_READ.match(k):
 97              modality = m.group(1)
 98              if modality in _EXCLUDED_MODALITIES:
 99                  continue
100              modality_entry = modalities.setdefault(modality, {})
101              modality_entry["cache_read_per_million_tokens"] = _to_per_million(v)
102          elif m := _MODALITY_CACHE_WRITE.match(k):
103              modality = m.group(1)
104              if modality in _EXCLUDED_MODALITIES:
105                  continue
106              modality_entry = modalities.setdefault(modality, {})
107              modality_entry["cache_write_per_million_tokens"] = _to_per_million(v)
108          elif m := _MODALITY_CACHE_READ_ALT.match(k):
109              modality = m.group(1)
110              if modality in _EXCLUDED_MODALITIES:
111                  continue
112              modality_entry = modalities.setdefault(modality, {})
113              modality_entry["cache_read_per_million_tokens"] = _to_per_million(v)
114  
115      return modalities
116  
117  
def _extract_tool_pricing(info: dict[str, Any]) -> dict[str, Any]:
    """Extract tool-related pricing and tool-use token overhead fields."""
    result: dict[str, Any] = {}

    # computer_use prices are quoted per 1k tokens; x1000 converts to per-million.
    for source_key, target_key in (
        ("computer_use_input_cost_per_1k_tokens", "input_per_million_tokens"),
        ("computer_use_output_cost_per_1k_tokens", "output_per_million_tokens"),
    ):
        value = info.get(source_key)
        if value is not None:
            result.setdefault("computer_use", {})[target_key] = round(value * _PER_THOUSAND, 10)

    # Copied through without unit conversion (per-query price / token overhead).
    for source_key, target_key in (
        ("search_context_cost_per_query", "search_context_per_query"),
        ("tool_use_system_prompt_tokens", "tool_use_system_prompt_tokens"),
    ):
        value = info.get(source_key)
        if value is not None:
            result[target_key] = value

    return result
136  
137  
# LiteLLM uses suffixes like _batches, _batch_requests, _flex, _priority
# on the base cost keys to express service-tier price overrides.
# Group 1 of each input/output pattern captures the direction ("input"/"output").
_TIER_PATTERNS = {
    "batch": re.compile(r"^(input|output)_cost_per_token_(batches|batch_requests)$"),
    "flex": re.compile(r"^(input|output)_cost_per_token_flex$"),
    "priority": re.compile(r"^(input|output)_cost_per_token_priority$"),
}

# Cache-read overrides per tier; same suffix convention, no direction group.
_TIER_CACHE_PATTERNS = {
    "batch": re.compile(r"^cache_read_input_token_cost_(batches|batch_requests)$"),
    "flex": re.compile(r"^cache_read_input_token_cost_flex$"),
    "priority": re.compile(r"^cache_read_input_token_cost_priority$"),
}
150  
151  
def _extract_service_tiers(info: dict[str, Any]) -> dict[str, dict[str, float]]:
    """Extract service tier pricing overrides (batch, flex, priority).

    Iterates tiers in _TIER_PATTERNS order so that tier insertion order (and
    therefore the serialized JSON order) stays stable.
    """
    tiers: dict[str, dict[str, float]] = {}

    for tier_name, pattern in _TIER_PATTERNS.items():
        for key, cost in info.items():
            m = pattern.match(key)
            if m is None:
                continue
            field = f"{m.group(1)}_per_million_tokens"  # group(1): "input"/"output"
            tiers.setdefault(tier_name, {})[field] = _to_per_million(cost)

    for tier_name, pattern in _TIER_CACHE_PATTERNS.items():
        for key, cost in info.items():
            if pattern.match(key):
                tier_entry = tiers.setdefault(tier_name, {})
                tier_entry["cache_read_per_million_tokens"] = _to_per_million(cost)

    return tiers
171  
172  
# Matches keys like input_cost_per_token_above_200k_tokens or
# cache_read_input_token_cost_above_128k_tokens.
# Group 1 captures the threshold ("200k", "128k", "1m", or a bare number)
# which _parse_threshold converts to a token count.
_LONG_CTX_INPUT = re.compile(r"^input_cost_per_token_above_(\d+[km]?)_tokens$")
_LONG_CTX_OUTPUT = re.compile(r"^output_cost_per_token_above_(\d+[km]?)_tokens$")
_LONG_CTX_CACHE_READ = re.compile(r"^cache_read_input_token_cost_above_(\d+[km]?)_tokens$")
_LONG_CTX_CACHE_WRITE = re.compile(r"^cache_creation_input_token_cost_above_(\d+[km]?)_tokens$")
179  
180  
181  def _parse_threshold(s: str) -> int:
182      """Convert threshold string like '200k', '128k', or '1m' to token count."""
183      s = s.lower()
184      if s.endswith("m"):
185          return int(s[:-1]) * 1_000_000
186      if s.endswith("k"):
187          return int(s[:-1]) * 1_000
188      return int(s)
189  
190  
def _extract_long_context_pricing(info: dict[str, Any]) -> list[dict[str, Any]]:
    """Extract long-context pricing tiers as a list of threshold overrides.

    Fields for the same token threshold are merged into one record containing
    "threshold_tokens", and records are returned sorted by ascending threshold.
    """
    # (pattern, catalog field) dispatch table replaces four duplicated branches.
    field_by_pattern = (
        (_LONG_CTX_INPUT, "input_per_million_tokens"),
        (_LONG_CTX_OUTPUT, "output_per_million_tokens"),
        (_LONG_CTX_CACHE_READ, "cache_read_per_million_tokens"),
        (_LONG_CTX_CACHE_WRITE, "cache_write_per_million_tokens"),
    )

    # Group fields by their parsed threshold.
    thresholds: dict[int, dict[str, Any]] = {}
    for key, value in info.items():
        for pattern, field in field_by_pattern:
            if m := pattern.match(key):
                t = _parse_threshold(m.group(1))
                record = thresholds.setdefault(t, {"threshold_tokens": t})
                record[field] = _to_per_million(value)
                break

    return sorted(thresholds.values(), key=lambda r: r["threshold_tokens"])
218  
219  
220  def _is_deprecated(info: dict[str, Any]) -> bool:
221      """Return True if the model's deprecation_date is in the past."""
222      dep = info.get("deprecation_date")
223      if not dep:
224          return False
225      try:
226          return datetime.strptime(dep, "%Y-%m-%d").date() < date.today()
227      except ValueError:
228          return False
229  
230  
def _transform_entry(info: dict[str, Any]) -> dict[str, Any] | None:
    """Transform a single LiteLLM model entry into MLflow-native schema.

    Returns None for unsupported modes and for already-deprecated models.
    Key insertion order is deliberate: it fixes the serialized JSON layout.
    """
    mode = info.get("mode")
    if mode not in _SUPPORTED_MODES or _is_deprecated(info):
        return None

    pricing = _extract_base_pricing(info)
    # Optional pricing sections, attached only when non-empty.
    for section, extracted in (
        ("service_tiers", _extract_service_tiers(info)),
        ("long_context", _extract_long_context_pricing(info)),
        ("modality", _extract_modality_pricing(info)),
        ("tooling", _extract_tool_pricing(info)),
    ):
        if extracted:
            pricing[section] = extracted

    # Capability flags map 1:1 onto upstream "supports_<name>" booleans.
    capabilities = {
        name: info.get(f"supports_{name}", False)
        for name in (
            "function_calling",
            "vision",
            "reasoning",
            "prompt_caching",
            "response_schema",
        )
    }

    context_window: dict[str, Any] = {}
    for source_key, target_key in (
        ("max_input_tokens", "max_input"),
        ("max_output_tokens", "max_output"),
        ("max_tokens", "max_tokens"),
    ):
        value = info.get(source_key)
        if value is not None:
            context_window[target_key] = value

    entry: dict[str, Any] = {"mode": mode}
    if context_window:
        entry["context_window"] = context_window
    if pricing:
        entry["pricing"] = pricing
    entry["capabilities"] = capabilities
    # Future deprecation dates are kept so consumers can warn ahead of time.
    if deprecation := info.get("deprecation_date"):
        entry["deprecation_date"] = deprecation

    return entry
280  
281  
# Legacy catalog key -> current key. Legacy values were per-single-token and
# are rescaled to per-million-tokens during migration (_migrate_pricing_block).
_LEGACY_PRICING_KEY_MAP = {
    "input_per_token": "input_per_million_tokens",
    "output_per_token": "output_per_million_tokens",
    "cache_read_per_token": "cache_read_per_million_tokens",
    "cache_write_per_token": "cache_write_per_million_tokens",
}
288  
289  
def _migrate_pricing_block(pricing: dict[str, Any]) -> dict[str, Any]:
    """Convert legacy *_per_token keys to *_per_million_tokens in a flat pricing block.

    Non-legacy keys pass through untouched; key order is preserved.
    """
    return {
        _LEGACY_PRICING_KEY_MAP.get(key, key): (
            _to_per_million(value) if key in _LEGACY_PRICING_KEY_MAP else value
        )
        for key, value in pricing.items()
    }
299  
300  
def _migrate_legacy_pricing(entry: dict[str, Any]) -> dict[str, Any]:
    """Migrate legacy *_per_token pricing keys to *_per_million_tokens in a catalog entry.

    Applies the migration at the top level of the pricing block and recursively
    within service_tiers, long_context, and modality sub-sections. The input
    entry is never mutated; a shallow-copied entry is returned.
    """
    if "pricing" not in entry:
        return entry

    migrated = _migrate_pricing_block(entry["pricing"])

    if "service_tiers" in migrated:
        migrated["service_tiers"] = {
            name: _migrate_pricing_block(block)
            for name, block in migrated["service_tiers"].items()
        }
    if "long_context" in migrated:
        migrated["long_context"] = [
            _migrate_pricing_block(block) for block in migrated["long_context"]
        ]
    if "modality" in migrated:
        migrated["modality"] = {
            name: _migrate_pricing_block(block)
            for name, block in migrated["modality"].items()
        }

    # "pricing" already exists in entry, so this keeps its original position.
    return {**entry, "pricing": migrated}
329  
330  
# Upstream data source: LiteLLM's community-maintained pricing/context file.
_LITELLM_URL = (
    "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
)
334  
335  
def _fetch_litellm_catalog() -> dict[str, Any]:
    """Download the latest LiteLLM model catalog from GitHub.

    Raises urllib.error.URLError / HTTPError on network failure.
    """
    print(f"Fetching {_LITELLM_URL} ...")
    with urllib.request.urlopen(_LITELLM_URL, timeout=30) as resp:
        payload = resp.read().decode("utf-8")
    catalog: dict[str, Any] = json.loads(payload)
    return catalog
342  
343  
def _load_existing_catalogs(output_dir: Path) -> dict[str, dict[str, Any]]:
    """Read existing per-provider catalog files; returns {provider: models}.

    Creates output_dir if needed. Unreadable or invalid files are skipped with
    a warning rather than aborting the sync.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    existing_catalogs: dict[str, dict[str, Any]] = {}
    for provider_file in output_dir.glob("*.json"):
        try:
            existing = json.loads(provider_file.read_text(encoding="utf-8"))
            existing_catalogs[provider_file.stem] = existing.get("models", {})
        except (json.JSONDecodeError, OSError) as e:
            print(f"  Warning: could not read existing {provider_file.name}: {e}")
    return existing_catalogs


def _resolve_last_updated_at(
    entry: dict[str, Any], existing_entry: dict[str, Any] | None, today: str
) -> str:
    """Carry over the stored last_updated_at when the entry is unchanged; else today.

    `entry` must not yet contain a last_updated_at key, since it is compared
    against the existing entry with that key stripped.
    """
    if existing_entry is None:
        return today
    stripped = {k: v for k, v in existing_entry.items() if k != "last_updated_at"}
    if entry != stripped:
        return today
    # Unchanged: keep the existing date, or backfill today if it was absent.
    return existing_entry.get("last_updated_at") or today


def _build_upstream_providers(
    raw: dict[str, Any], existing_catalogs: dict[str, dict[str, Any]], today: str
) -> dict[str, dict[str, dict[str, Any]]]:
    """Transform upstream entries into {provider: {model_name: entry}}."""
    providers: dict[str, dict[str, dict[str, Any]]] = {}
    seen: set[tuple[str, str]] = set()

    for key, info in raw.items():
        # "sample_spec" is upstream documentation, not a model.
        if key == "sample_spec":
            continue

        provider = info.get("litellm_provider")
        if not provider:
            continue
        provider = _normalize_provider(provider)
        model_name = key.split("/", 1)[-1]

        # Skip fine-tuned variants.
        if model_name.startswith("ft:"):
            continue

        # Dedupe by (provider, model_name) — first occurrence wins.
        dedup_key = (provider, model_name)
        if dedup_key in seen:
            continue
        seen.add(dedup_key)

        entry = _transform_entry(info)
        if entry is None:
            continue

        existing_entry = existing_catalogs.get(provider, {}).get(model_name)
        entry["last_updated_at"] = _resolve_last_updated_at(entry, existing_entry, today)
        providers.setdefault(provider, {})[model_name] = entry

    return providers


def _merge_community_models(
    providers: dict[str, dict[str, dict[str, Any]]],
    existing_catalogs: dict[str, dict[str, Any]],
    today: str,
) -> None:
    """Preserve on-disk models absent upstream (community additions), in place.

    Preserved entries get legacy pricing keys migrated and a last_updated_at
    backfilled when missing.
    """
    for provider, existing_models in existing_catalogs.items():
        merged = providers.setdefault(provider, {})
        for model_name, entry in existing_models.items():
            if model_name in merged:
                continue  # upstream version wins
            migrated = _migrate_legacy_pricing(entry)
            if "last_updated_at" not in migrated:
                migrated = {**migrated, "last_updated_at": today}
            merged[model_name] = migrated


def _write_catalogs(
    providers: dict[str, dict[str, dict[str, Any]]], output_dir: Path
) -> dict[str, int]:
    """Write one sorted JSON file per non-empty provider; returns model counts."""
    stats: dict[str, int] = {}
    for provider, models in sorted(providers.items()):
        if not models:
            continue
        catalog = {
            "schema_version": SCHEMA_VERSION,
            "models": dict(sorted(models.items())),
        }
        out_path = output_dir / f"{provider}.json"
        out_path.write_text(json.dumps(catalog, indent=2) + "\n", encoding="utf-8")
        stats[provider] = len(models)
    return stats


def convert(raw: dict[str, Any], output_dir: Path) -> dict[str, int]:
    """Convert upstream catalog dict to per-provider MLflow catalog files.

    Upstream entries always take precedence over existing on-disk entries;
    on-disk models absent upstream are preserved (community additions).
    Returns a dict mapping provider names to model counts.
    """
    today = date.today().isoformat()
    existing_catalogs = _load_existing_catalogs(output_dir)
    providers = _build_upstream_providers(raw, existing_catalogs, today)
    _merge_community_models(providers, existing_catalogs, today)
    return _write_catalogs(providers, output_dir)
431  
432  
def main() -> None:
    """CLI entry point: fetch the upstream catalog and rewrite provider files."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=Path("mlflow/utils/model_catalog"),
        help="Output directory for per-provider JSON files",
    )
    args = parser.parse_args()

    stats = convert(_fetch_litellm_catalog(), args.output_dir)
    total = sum(stats.values())
    print(f"Converted {total} models across {len(stats)} providers:")
    # Report providers by descending model count.
    for provider, count in sorted(stats.items(), key=lambda item: -item[1]):
        print(f"  {provider}: {count}")
449  
450  
# Allow running directly: `uv run python dev/update_model_catalog.py`.
if __name__ == "__main__":
    main()