# test_providers.py
import contextlib
import json
from unittest import mock

import pytest

from mlflow.exceptions import MlflowException
from mlflow.utils.providers import (
    _fetch_remote_provider,
    _flatten_catalog_entry,
    _get_remote_cache,
    _list_provider_names,
    _load_bundled_provider,
    _load_provider,
    _normalize_provider,
    cost_per_token,
    get_all_providers,
    get_models,
    get_provider_config_response,
)


def test_normalize_provider_normalizes_vertex_ai_variants():
    assert _normalize_provider("vertex_ai") == "vertex_ai"
    assert _normalize_provider("vertex_ai-anthropic") == "vertex_ai"
    assert _normalize_provider("vertex_ai-llama_models") == "vertex_ai"
    assert _normalize_provider("vertex_ai-mistral") == "vertex_ai"


def test_normalize_provider_does_not_normalize_other_providers():
    assert _normalize_provider("openai") == "openai"
    assert _normalize_provider("anthropic") == "anthropic"
    assert _normalize_provider("bedrock") == "bedrock"
    assert _normalize_provider("gemini") == "gemini"


def test_list_provider_names_returns_bundled_providers():
    _list_provider_names.cache_clear()
    providers = _list_provider_names()
    assert len(providers) > 0
    assert "openai" in providers
    assert "anthropic" in providers
    assert "bedrock" in providers


def test_list_provider_names_excludes_non_json():
    _list_provider_names.cache_clear()
    providers = _list_provider_names()
    # __init__.py should not appear
    assert "__init__" not in providers
    for p in providers:
        assert not p.endswith(".py")


def test_load_provider_returns_models(monkeypatch):
    monkeypatch.setenv("MLFLOW_MODEL_CATALOG_URI", "")
    _load_bundled_provider.cache_clear()
    models = _load_provider("openai")
    assert len(models) > 0
    assert "gpt-4o" in models
    info = models["gpt-4o"]
    assert info["mode"] == "chat"
    assert "input_cost_per_token" in info
    assert info["input_cost_per_token"] > 0


def test_load_provider_returns_empty_for_unknown(monkeypatch):
    monkeypatch.setenv("MLFLOW_MODEL_CATALOG_URI", "")
    _load_bundled_provider.cache_clear()
    assert _load_provider("nonexistent_provider_xyz") == {}


def test_load_provider_flattens_pricing(monkeypatch):
    monkeypatch.setenv("MLFLOW_MODEL_CATALOG_URI", "")
    _load_bundled_provider.cache_clear()
    models = _load_provider("anthropic")
    model = next(iter(models.values()))
    # Should have flat ModelInfo keys, not nested pricing/capabilities
    assert "input_cost_per_token" in model or "mode" in model
    assert "pricing" not in model
    assert "context_window" not in model


@contextlib.contextmanager
def _mock_catalog(provider_data):
    """Context manager that mocks the per-provider catalog with the given data.

    ``provider_data`` is a dict mapping provider names to ``{model_name: info}`` dicts.

    Implemented as a single context manager (previously it returned a tuple of two
    patchers, forcing callers to construct it twice and activate one patcher from
    each throwaway tuple) so callers can simply write ``with _mock_catalog(data):``.
    """
    # Patch both the per-provider loader and the provider-name listing so the
    # code under test sees only the fake catalog.
    with (
        mock.patch(
            "mlflow.utils.providers._load_provider",
            side_effect=lambda p: provider_data.get(p, {}),
        ),
        mock.patch(
            "mlflow.utils.providers._list_provider_names",
            return_value=list(provider_data.keys()),
        ),
    ):
        yield


def test_get_all_providers_consolidates_vertex_ai_variants():
    data = {
        "openai": {"gpt-4o": {"mode": "chat"}},
        "anthropic": {"claude-3-5-sonnet": {"mode": "chat"}},
        "vertex_ai": {"gemini-1.5-pro": {"mode": "chat"}},
        "vertex_ai-llama_models": {"meta/llama-4-scout": {"mode": "chat"}},
        "vertex_ai-anthropic": {"claude-3-5-sonnet": {"mode": "chat"}},
    }
    with _mock_catalog(data):
        providers = get_all_providers()

    assert "vertex_ai" in providers
    assert "vertex_ai-llama_models" not in providers
    assert "vertex_ai-anthropic" not in providers
    assert "openai" in providers
    assert "anthropic" in providers


def test_get_models_normalizes_vertex_ai_provider_and_strips_prefix():
    data = {
        "vertex_ai-llama_models": {
            "meta/llama-4-scout-17b-16e-instruct-maas": {
                "mode": "chat",
                "supports_function_calling": True,
            }
        },
        "vertex_ai-anthropic": {
            "claude-3-5-sonnet": {"mode": "chat", "supports_function_calling": True}
        },
        "vertex_ai": {"gemini-1.5-pro": {"mode": "chat", "supports_function_calling": True}},
    }
    with _mock_catalog(data):
        models = get_models(provider="vertex_ai")

    assert len(models) == 3
    for model in models:
        assert model["provider"] == "vertex_ai"

    model_names = [m["model"] for m in models]
    assert "meta/llama-4-scout-17b-16e-instruct-maas" in model_names
    assert "claude-3-5-sonnet" in model_names
    assert "gemini-1.5-pro" in model_names


def test_get_models_filters_by_consolidated_provider():
    data = {
        "openai": {"gpt-4o": {"mode": "chat"}},
        "vertex_ai-llama_models": {"meta/llama-4-scout": {"mode": "chat"}},
    }
    with _mock_catalog(data):
        vertex_models = get_models(provider="vertex_ai")
        assert len(vertex_models) == 1
        assert vertex_models[0]["model"] == "meta/llama-4-scout"

        openai_models = get_models(provider="openai")
        assert len(openai_models) == 1
        assert openai_models[0]["model"] == "gpt-4o"


def test_get_models_does_not_modify_other_providers():
    data = {
        "openai": {"gpt-4o": {"mode": "chat", "supports_function_calling": True}},
        "anthropic": {"claude-3-5-sonnet": {"mode": "chat", "supports_function_calling": True}},
    }
    with _mock_catalog(data):
        models = get_models()

    openai_model = next(m for m in models if m["provider"] == "openai")
    assert openai_model["model"] == "gpt-4o"

    anthropic_model = next(m for m in models if m["provider"] == "anthropic")
    assert anthropic_model["model"] == "claude-3-5-sonnet"


def test_get_models_dedupes_models_after_normalization():
    data = {
        "vertex_ai": {
            "gemini-3-flash-preview": {"mode": "chat", "supports_function_calling": True}
        },
        "vertex_ai-chat-models": {
            "gemini-3-flash-preview": {"mode": "chat", "supports_function_calling": True}
        },
    }
    with _mock_catalog(data):
        models = get_models(provider="vertex_ai")

    model_names = [m["model"] for m in models]
    assert model_names.count("gemini-3-flash-preview") == 1
    assert len(models) == 1


def test_get_all_providers_with_allowed_filter(monkeypatch):
    data = {
        "openai": {"gpt-4o": {"mode": "chat"}},
        "anthropic": {"claude-3-5-sonnet": {"mode": "chat"}},
        "gemini": {"gemini-1.5-pro": {"mode": "chat"}},
    }
    with _mock_catalog(data):
        monkeypatch.setenv("MLFLOW_GATEWAY_ALLOWED_PROVIDERS", "openai,anthropic")
        providers = get_all_providers()
        assert "openai" in providers
        assert "anthropic" in providers
        assert "gemini" not in providers


def test_get_models_filters_with_allowed_providers(monkeypatch):
    data = {
        "openai": {"gpt-4o": {"mode": "chat", "supports_function_calling": True}},
        "anthropic": {"claude-3-5-sonnet": {"mode": "chat", "supports_function_calling": True}},
        "gemini": {"gemini-1.5-pro": {"mode": "chat", "supports_function_calling": True}},
    }
    with _mock_catalog(data):
        monkeypatch.setenv("MLFLOW_GATEWAY_ALLOWED_PROVIDERS", "openai")
        models = get_models()
        providers_in_result = {m["provider"] for m in models}
        assert providers_in_result == {"openai"}


def test_get_provider_config_rejects_provider_not_in_allowed_list(monkeypatch):
    monkeypatch.setenv("MLFLOW_GATEWAY_ALLOWED_PROVIDERS", "anthropic")
    with pytest.raises(MlflowException, match="not allowed"):
        get_provider_config_response("openai")


def test_get_provider_config_bedrock_has_default_chain():
    config = get_provider_config_response("bedrock")
    modes = {m["mode"] for m in config["auth_modes"]}
    assert "default_chain" in modes

    default_chain = next(m for m in config["auth_modes"] if m["mode"] == "default_chain")
    assert default_chain["display_name"] == "Default Credential Chain"
    assert default_chain["secret_fields"] == []
    assert all(not f["required"] for f in default_chain["config_fields"])


def test_get_provider_config_sagemaker_has_default_chain():
    config = get_provider_config_response("sagemaker")
    modes = {m["mode"] for m in config["auth_modes"]}
    assert "default_chain" in modes


def test_get_provider_config_vertex_ai_has_default_chain():
    config = get_provider_config_response("vertex_ai")
    modes = {m["mode"] for m in config["auth_modes"]}
    assert "default_chain" in modes

    default_chain = next(m for m in config["auth_modes"] if m["mode"] == "default_chain")
    assert default_chain["display_name"] == "Application Default Credentials"
    assert default_chain["secret_fields"] == []
    project_field = next(f for f in default_chain["config_fields"] if f["name"] == "vertex_project")
    assert project_field["required"] is True


# Fixed catalog used by the cost_per_token tests below: "test-model" carries both
# cache-read and cache-creation rates; "test-provider-model" exists only under
# the "openai" provider so provider-scoped lookup paths are exercised.
_MOCK_PROVIDER_DATA = {
    "test_provider": {
        "test-model": {
            "input_cost_per_token": 1e-6,
            "output_cost_per_token": 2e-6,
            "cache_read_input_token_cost": 5e-7,
            "cache_creation_input_token_cost": 3e-6,
        },
    },
    "openai": {
        "test-provider-model": {
            "input_cost_per_token": 1e-6,
            "output_cost_per_token": 2e-6,
        },
    },
}


def _mock_load_provider(provider):
    """Stand-in for the catalog loaders: look up ``provider`` in the fixed data."""
    return _MOCK_PROVIDER_DATA.get(provider, {})


@pytest.fixture
def mock_model_cost():
    """Patch all catalog entry points to serve ``_MOCK_PROVIDER_DATA``."""
    with (
        mock.patch(
            "mlflow.utils.providers._load_provider", side_effect=_mock_load_provider
        ) as m_load,
        mock.patch(
            "mlflow.utils.providers._load_bundled_provider", side_effect=_mock_load_provider
        ),
        mock.patch(
            "mlflow.utils.providers._list_provider_names",
            return_value=list(_MOCK_PROVIDER_DATA.keys()),
        ),
    ):
        yield m_load


def test_cost_per_token_basic(mock_model_cost):
    input_cost, output_cost = cost_per_token(
        model="test-model", prompt_tokens=1000, completion_tokens=500
    )
    # input: 1000 * 1e-6 = 0.001, output: 500 * 2e-6 = 0.001
    assert input_cost == pytest.approx(0.001)
    assert output_cost == pytest.approx(0.001)


def test_cost_per_token_with_provider_prefix(mock_model_cost):
    # "test-provider-model" only exists under "openai/" prefix, so provider lookup is exercised
    input_cost, output_cost = cost_per_token(
        model="test-provider-model",
        prompt_tokens=1000,
        completion_tokens=500,
        custom_llm_provider="openai",
    )
    assert input_cost == pytest.approx(0.001)
    assert output_cost == pytest.approx(0.001)


def test_cost_per_token_strips_provider_prefix(mock_model_cost):
    # "openai/test-model" should resolve to "test-model" by stripping the prefix
    input_cost, output_cost = cost_per_token(
        model="openai/test-model", prompt_tokens=1000, completion_tokens=500
    )
    assert input_cost == pytest.approx(0.001)
    assert output_cost == pytest.approx(0.001)


def test_cost_per_token_cache_read_tokens(mock_model_cost):
    input_cost, output_cost = cost_per_token(
        model="test-model",
        prompt_tokens=1000,
        completion_tokens=500,
        cache_read_input_tokens=200,
    )
    # regular: (1000-200) * 1e-6 = 0.0008
    # cache_read: 200 * 5e-7 = 0.0001
    assert input_cost == pytest.approx(0.0009)
    assert output_cost == pytest.approx(0.001)


def test_cost_per_token_cache_creation_tokens(mock_model_cost):
    input_cost, output_cost = cost_per_token(
        model="test-model",
        prompt_tokens=1000,
        completion_tokens=500,
        cache_creation_input_tokens=300,
    )
    # regular: (1000-300) * 1e-6 = 0.0007
    # cache_creation: 300 * 3e-6 = 0.0009
    assert input_cost == pytest.approx(0.0016)
    assert output_cost == pytest.approx(0.001)
def test_cost_per_token_zero_tokens(mock_model_cost):
    """Zero prompt and completion tokens must cost exactly nothing."""
    prompt_cost, completion_cost = cost_per_token(
        model="test-model", prompt_tokens=0, completion_tokens=0
    )
    assert prompt_cost == 0.0
    assert completion_cost == 0.0


def test_cost_per_token_unknown_model_returns_none(mock_model_cost):
    """A model absent from every provider yields no cost information."""
    result = cost_per_token(model="totally-unknown-model", prompt_tokens=100)
    assert result is None


def test_cost_per_token_unknown_model_with_provider_returns_none(mock_model_cost):
    """An unknown model stays unknown even with an (unknown) provider hint."""
    result = cost_per_token(
        model="totally-unknown-model",
        prompt_tokens=100,
        custom_llm_provider="unknown-provider",
    )
    assert result is None


def test_cost_per_token_no_cache_cost_falls_back_to_input_rate():
    """Without a cache-read rate, cached tokens are billed at the regular input rate."""
    catalog = {
        "nocache_provider": {
            "test-model": {
                "input_cost_per_token": 1e-6,
                "output_cost_per_token": 2e-6,
            }
        }
    }
    loader_patch = mock.patch(
        "mlflow.utils.providers._load_provider",
        side_effect=lambda p: catalog.get(p, {}),
    )
    bundled_patch = mock.patch(
        "mlflow.utils.providers._load_bundled_provider",
        side_effect=lambda p: catalog.get(p, {}),
    )
    names_patch = mock.patch(
        "mlflow.utils.providers._list_provider_names",
        return_value=list(catalog.keys()),
    )
    with loader_patch, bundled_patch, names_patch:
        prompt_cost, completion_cost = cost_per_token(
            model="test-model",
            prompt_tokens=1000,
            completion_tokens=500,
            cache_read_input_tokens=200,
        )
    # No cache_read_input_token_cost, falls back to input_cost_per_token
    # regular: 800 * 1e-6 = 0.0008
    # cache_read: 200 * 1e-6 = 0.0002 (same rate as regular)
    assert prompt_cost == pytest.approx(0.001)
    assert completion_cost == pytest.approx(0.001)


def test_flatten_catalog_entry():
    """A fully populated nested catalog entry flattens into ModelInfo keys."""
    catalog_entry = {
        "mode": "chat",
        "context_window": {"max_input": 128000, "max_output": 16384},
        "pricing": {
            "input_per_million_tokens": 2.5,
            "output_per_million_tokens": 10.0,
            "cache_read_per_million_tokens": 1.25,
            "cache_write_per_million_tokens": 5.0,
        },
        "capabilities": {
            "function_calling": True,
            "vision": True,
            "reasoning": False,
            "prompt_caching": True,
            "response_schema": True,
        },
        "deprecation_date": "2026-01-01",
    }
    flat = _flatten_catalog_entry(catalog_entry)
    assert flat["mode"] == "chat"
    assert flat["max_input_tokens"] == 128000
    assert flat["max_output_tokens"] == 16384
    # per-million prices become per-token rates
    assert flat["input_cost_per_token"] == pytest.approx(2.5e-6)
    assert flat["output_cost_per_token"] == pytest.approx(1e-5)
    assert flat["cache_read_input_token_cost"] == pytest.approx(1.25e-6)
    assert flat["cache_creation_input_token_cost"] == pytest.approx(5e-6)
    assert flat["supports_function_calling"] is True
    assert flat["supports_vision"] is True
    assert flat["supports_reasoning"] is False
    assert flat["deprecation_date"] == "2026-01-01"


def test_flatten_catalog_entry_with_last_updated_at():
    """``last_updated_at`` is carried through to the flattened entry when present."""
    catalog_entry = {
        "mode": "chat",
        "capabilities": {
            "function_calling": False,
            "vision": False,
            "reasoning": False,
            "prompt_caching": False,
            "response_schema": False,
        },
        "last_updated_at": "2025-01-15",
    }
    flat = _flatten_catalog_entry(catalog_entry)
    assert flat["last_updated_at"] == "2025-01-15"


def test_flatten_catalog_entry_without_last_updated_at():
    """``last_updated_at`` is omitted from the flattened entry when absent."""
    catalog_entry = {
        "mode": "chat",
        "capabilities": {
            "function_calling": False,
            "vision": False,
            "reasoning": False,
            "prompt_caching": False,
            "response_schema": False,
        },
    }
    flat = _flatten_catalog_entry(catalog_entry)
    assert "last_updated_at" not in flat


def test_load_bundled_provider_returns_data():
    """The bundled openai catalog loads and exposes flattened model info."""
    _load_bundled_provider.cache_clear()
    bundled = _load_bundled_provider("openai")
    assert len(bundled) > 0
    assert "gpt-4o" in bundled
    model_info = bundled["gpt-4o"]
    assert model_info["mode"] == "chat"
    assert "input_cost_per_token" in model_info


def test_load_provider_uses_remote_when_available():
    """When the remote fetch succeeds, its data is returned verbatim."""
    remote_models = {"test-model": {"mode": "chat", "input_cost_per_token": 1e-6}}
    with mock.patch(
        "mlflow.utils.providers._fetch_remote_provider", return_value=remote_models
    ) as fetch_mock:
        loaded = _load_provider("openai")
    fetch_mock.assert_called_once_with("openai")
    assert loaded is remote_models


def test_load_provider_falls_back_to_bundled_when_remote_fails():
    """A failed remote fetch (None) falls back to the bundled catalog."""
    with mock.patch(
        "mlflow.utils.providers._fetch_remote_provider", return_value=None
    ) as fetch_mock:
        loaded = _load_provider("openai")
    fetch_mock.assert_called_once_with("openai")
    assert len(loaded) > 0
    assert "gpt-4o" in loaded


def test_fetch_remote_provider_disabled_when_url_empty(monkeypatch):
    """An empty catalog URI disables remote fetching entirely."""
    monkeypatch.setenv("MLFLOW_MODEL_CATALOG_URI", "")
    assert _fetch_remote_provider("openai") is None


def test_fetch_remote_provider_supports_file_url(tmp_path, monkeypatch):
    """A file:// catalog URI is read from disk like a remote catalog."""
    catalog = {
        "schema_version": "1.0",
        "models": {"test-model": {"mode": "chat", "pricing": {"input_per_million_tokens": 1.0}}},
    }
    (tmp_path / "test_provider.json").write_text(json.dumps(catalog))
    monkeypatch.setenv("MLFLOW_MODEL_CATALOG_URI", tmp_path.as_uri())
    _get_remote_cache().clear()
    fetched = _fetch_remote_provider("test_provider")
    assert fetched is not None
    assert "test-model" in fetched