# tests/utils/test_providers.py
  1  import json
  2  from unittest import mock
  3  
  4  import pytest
  5  
  6  from mlflow.exceptions import MlflowException
  7  from mlflow.utils.providers import (
  8      _fetch_remote_provider,
  9      _flatten_catalog_entry,
 10      _get_remote_cache,
 11      _list_provider_names,
 12      _load_bundled_provider,
 13      _load_provider,
 14      _normalize_provider,
 15      cost_per_token,
 16      get_all_providers,
 17      get_models,
 18      get_provider_config_response,
 19  )
 20  
 21  
 22  def test_normalize_provider_normalizes_vertex_ai_variants():
 23      assert _normalize_provider("vertex_ai") == "vertex_ai"
 24      assert _normalize_provider("vertex_ai-anthropic") == "vertex_ai"
 25      assert _normalize_provider("vertex_ai-llama_models") == "vertex_ai"
 26      assert _normalize_provider("vertex_ai-mistral") == "vertex_ai"
 27  
 28  
 29  def test_normalize_provider_does_not_normalize_other_providers():
 30      assert _normalize_provider("openai") == "openai"
 31      assert _normalize_provider("anthropic") == "anthropic"
 32      assert _normalize_provider("bedrock") == "bedrock"
 33      assert _normalize_provider("gemini") == "gemini"
 34  
 35  
 36  def test_list_provider_names_returns_bundled_providers():
 37      _list_provider_names.cache_clear()
 38      providers = _list_provider_names()
 39      assert len(providers) > 0
 40      assert "openai" in providers
 41      assert "anthropic" in providers
 42      assert "bedrock" in providers
 43  
 44  
 45  def test_list_provider_names_excludes_non_json():
 46      _list_provider_names.cache_clear()
 47      providers = _list_provider_names()
 48      # __init__.py should not appear
 49      assert "__init__" not in providers
 50      for p in providers:
 51          assert not p.endswith(".py")
 52  
 53  
 54  def test_load_provider_returns_models(monkeypatch):
 55      monkeypatch.setenv("MLFLOW_MODEL_CATALOG_URI", "")
 56      _load_bundled_provider.cache_clear()
 57      models = _load_provider("openai")
 58      assert len(models) > 0
 59      assert "gpt-4o" in models
 60      info = models["gpt-4o"]
 61      assert info["mode"] == "chat"
 62      assert "input_cost_per_token" in info
 63      assert info["input_cost_per_token"] > 0
 64  
 65  
 66  def test_load_provider_returns_empty_for_unknown(monkeypatch):
 67      monkeypatch.setenv("MLFLOW_MODEL_CATALOG_URI", "")
 68      _load_bundled_provider.cache_clear()
 69      assert _load_provider("nonexistent_provider_xyz") == {}
 70  
 71  
 72  def test_load_provider_flattens_pricing(monkeypatch):
 73      monkeypatch.setenv("MLFLOW_MODEL_CATALOG_URI", "")
 74      _load_bundled_provider.cache_clear()
 75      models = _load_provider("anthropic")
 76      model = next(iter(models.values()))
 77      # Should have flat ModelInfo keys, not nested pricing/capabilities
 78      assert "input_cost_per_token" in model or "mode" in model
 79      assert "pricing" not in model
 80      assert "context_window" not in model
 81  
 82  
def _mock_catalog(provider_data):
    """Build (but do not start) patchers stubbing the per-provider catalog.

    Note: despite its with-statement usage at call sites, this is NOT a
    context manager itself — it returns a 2-tuple of ``mock.patch`` objects
    (one for ``_load_provider``, one for ``_list_provider_names``) that the
    caller must enter, e.g. ``p1, p2 = _mock_catalog(data); with p1, p2: ...``.

    Args:
        provider_data: Dict mapping provider names to ``{model_name: info}`` dicts.

    Returns:
        Tuple of two unstarted ``mock.patch`` objects.
    """
    return (
        mock.patch(
            "mlflow.utils.providers._load_provider",
            side_effect=lambda p: provider_data.get(p, {}),
        ),
        mock.patch(
            "mlflow.utils.providers._list_provider_names",
            return_value=list(provider_data.keys()),
        ),
    )
 98  
 99  
def test_get_all_providers_consolidates_vertex_ai_variants():
    """vertex_ai-* variants must collapse into a single "vertex_ai" entry."""
    data = {
        "openai": {"gpt-4o": {"mode": "chat"}},
        "anthropic": {"claude-3-5-sonnet": {"mode": "chat"}},
        "vertex_ai": {"gemini-1.5-pro": {"mode": "chat"}},
        "vertex_ai-llama_models": {"meta/llama-4-scout": {"mode": "chat"}},
        "vertex_ai-anthropic": {"claude-3-5-sonnet": {"mode": "chat"}},
    }
    # Call _mock_catalog once and unpack, instead of calling it twice and
    # discarding half of each returned patcher pair.
    load_patch, names_patch = _mock_catalog(data)
    with load_patch, names_patch:
        providers = get_all_providers()

        assert "vertex_ai" in providers
        assert "vertex_ai-llama_models" not in providers
        assert "vertex_ai-anthropic" not in providers
        assert "openai" in providers
        assert "anthropic" in providers
116  
117  
def test_get_models_normalizes_vertex_ai_provider_and_strips_prefix():
    """Querying "vertex_ai" must return models from all vertex_ai-* variants."""
    data = {
        "vertex_ai-llama_models": {
            "meta/llama-4-scout-17b-16e-instruct-maas": {
                "mode": "chat",
                "supports_function_calling": True,
            }
        },
        "vertex_ai-anthropic": {
            "claude-3-5-sonnet": {"mode": "chat", "supports_function_calling": True}
        },
        "vertex_ai": {"gemini-1.5-pro": {"mode": "chat", "supports_function_calling": True}},
    }
    # Single _mock_catalog call (previously called twice, wasting a patcher pair).
    load_patch, names_patch = _mock_catalog(data)
    with load_patch, names_patch:
        models = get_models(provider="vertex_ai")

        # All three variants are reported under the consolidated provider name.
        assert len(models) == 3
        for model in models:
            assert model["provider"] == "vertex_ai"

        model_names = [m["model"] for m in models]
        assert "meta/llama-4-scout-17b-16e-instruct-maas" in model_names
        assert "claude-3-5-sonnet" in model_names
        assert "gemini-1.5-pro" in model_names
142  
143  
def test_get_models_filters_by_consolidated_provider():
    """Filtering by provider works against the post-consolidation names."""
    data = {
        "openai": {"gpt-4o": {"mode": "chat"}},
        "vertex_ai-llama_models": {"meta/llama-4-scout": {"mode": "chat"}},
    }
    # Single _mock_catalog call (previously called twice, wasting a patcher pair).
    load_patch, names_patch = _mock_catalog(data)
    with load_patch, names_patch:
        vertex_models = get_models(provider="vertex_ai")
        assert len(vertex_models) == 1
        assert vertex_models[0]["model"] == "meta/llama-4-scout"

        openai_models = get_models(provider="openai")
        assert len(openai_models) == 1
        assert openai_models[0]["model"] == "gpt-4o"
157  
158  
def test_get_models_does_not_modify_other_providers():
    """Non-vertex providers pass through consolidation untouched."""
    data = {
        "openai": {"gpt-4o": {"mode": "chat", "supports_function_calling": True}},
        "anthropic": {"claude-3-5-sonnet": {"mode": "chat", "supports_function_calling": True}},
    }
    # Single _mock_catalog call (previously called twice, wasting a patcher pair).
    load_patch, names_patch = _mock_catalog(data)
    with load_patch, names_patch:
        models = get_models()

        openai_model = next(m for m in models if m["provider"] == "openai")
        assert openai_model["model"] == "gpt-4o"

        anthropic_model = next(m for m in models if m["provider"] == "anthropic")
        assert anthropic_model["model"] == "claude-3-5-sonnet"
172  
173  
def test_get_models_dedupes_models_after_normalization():
    """The same model listed under two vertex_ai variants appears only once."""
    data = {
        "vertex_ai": {
            "gemini-3-flash-preview": {"mode": "chat", "supports_function_calling": True}
        },
        "vertex_ai-chat-models": {
            "gemini-3-flash-preview": {"mode": "chat", "supports_function_calling": True}
        },
    }
    # Single _mock_catalog call (previously called twice, wasting a patcher pair).
    load_patch, names_patch = _mock_catalog(data)
    with load_patch, names_patch:
        models = get_models(provider="vertex_ai")

        model_names = [m["model"] for m in models]
        assert model_names.count("gemini-3-flash-preview") == 1
        assert len(models) == 1
189  
190  
def test_get_all_providers_with_allowed_filter(monkeypatch):
    """MLFLOW_GATEWAY_ALLOWED_PROVIDERS restricts the provider listing."""
    data = {
        "openai": {"gpt-4o": {"mode": "chat"}},
        "anthropic": {"claude-3-5-sonnet": {"mode": "chat"}},
        "gemini": {"gemini-1.5-pro": {"mode": "chat"}},
    }
    # Single _mock_catalog call (previously called twice, wasting a patcher pair).
    load_patch, names_patch = _mock_catalog(data)
    with load_patch, names_patch:
        monkeypatch.setenv("MLFLOW_GATEWAY_ALLOWED_PROVIDERS", "openai,anthropic")
        providers = get_all_providers()
        assert "openai" in providers
        assert "anthropic" in providers
        assert "gemini" not in providers
203  
204  
def test_get_models_filters_with_allowed_providers(monkeypatch):
    """get_models only returns models from allow-listed providers."""
    data = {
        "openai": {"gpt-4o": {"mode": "chat", "supports_function_calling": True}},
        "anthropic": {"claude-3-5-sonnet": {"mode": "chat", "supports_function_calling": True}},
        "gemini": {"gemini-1.5-pro": {"mode": "chat", "supports_function_calling": True}},
    }
    # Single _mock_catalog call (previously called twice, wasting a patcher pair).
    load_patch, names_patch = _mock_catalog(data)
    with load_patch, names_patch:
        monkeypatch.setenv("MLFLOW_GATEWAY_ALLOWED_PROVIDERS", "openai")
        models = get_models()
        providers_in_result = {m["provider"] for m in models}
        assert providers_in_result == {"openai"}
216  
217  
def test_get_provider_config_rejects_provider_not_in_allowed_list(monkeypatch):
    # Only anthropic is allow-listed, so asking for openai must be rejected.
    monkeypatch.setenv("MLFLOW_GATEWAY_ALLOWED_PROVIDERS", "anthropic")
    with pytest.raises(MlflowException, match="not allowed"):
        get_provider_config_response("openai")
222  
223  
def test_get_provider_config_bedrock_has_default_chain():
    auth_modes = get_provider_config_response("bedrock")["auth_modes"]
    assert any(m["mode"] == "default_chain" for m in auth_modes)

    default_chain = next(m for m in auth_modes if m["mode"] == "default_chain")
    assert default_chain["display_name"] == "Default Credential Chain"
    assert default_chain["secret_fields"] == []
    # Default-chain auth must not demand any mandatory config values.
    assert all(not field["required"] for field in default_chain["config_fields"])
233  
234  
def test_get_provider_config_sagemaker_has_default_chain():
    # SageMaker, like the other AWS providers, offers default-chain auth.
    auth_modes = get_provider_config_response("sagemaker")["auth_modes"]
    assert any(m["mode"] == "default_chain" for m in auth_modes)
239  
240  
def test_get_provider_config_vertex_ai_has_default_chain():
    auth_modes = get_provider_config_response("vertex_ai")["auth_modes"]
    assert any(m["mode"] == "default_chain" for m in auth_modes)

    default_chain = next(m for m in auth_modes if m["mode"] == "default_chain")
    assert default_chain["display_name"] == "Application Default Credentials"
    assert default_chain["secret_fields"] == []
    # The GCP project must always be supplied even with default-chain auth.
    project_field = next(f for f in default_chain["config_fields"] if f["name"] == "vertex_project")
    assert project_field["required"] is True
251  
252  
253  _MOCK_PROVIDER_DATA = {
254      "test_provider": {
255          "test-model": {
256              "input_cost_per_token": 1e-6,
257              "output_cost_per_token": 2e-6,
258              "cache_read_input_token_cost": 5e-7,
259              "cache_creation_input_token_cost": 3e-6,
260          },
261      },
262      "openai": {
263          "test-provider-model": {
264              "input_cost_per_token": 1e-6,
265              "output_cost_per_token": 2e-6,
266          },
267      },
268  }
269  
270  
271  def _mock_load_provider(provider):
272      return _MOCK_PROVIDER_DATA.get(provider, {})
273  
274  
@pytest.fixture
def mock_model_cost():
    # Patch both the consolidated loader and the bundled fallback so every
    # lookup path resolves models from _MOCK_PROVIDER_DATA, and patch the
    # provider listing to match. Yields the _load_provider mock so tests can
    # assert on its calls if needed.
    with (
        mock.patch(
            "mlflow.utils.providers._load_provider", side_effect=_mock_load_provider
        ) as m_load,
        mock.patch(
            "mlflow.utils.providers._load_bundled_provider", side_effect=_mock_load_provider
        ),
        mock.patch(
            "mlflow.utils.providers._list_provider_names",
            return_value=list(_MOCK_PROVIDER_DATA.keys()),
        ),
    ):
        yield m_load
290  
291  
def test_cost_per_token_basic(mock_model_cost):
    prompt_cost, completion_cost = cost_per_token(
        model="test-model", prompt_tokens=1000, completion_tokens=500
    )
    # 1000 * 1e-6 == 0.001 on input; 500 * 2e-6 == 0.001 on output.
    assert prompt_cost == pytest.approx(0.001)
    assert completion_cost == pytest.approx(0.001)
299  
300  
def test_cost_per_token_with_provider_prefix(mock_model_cost):
    # "test-provider-model" is registered only under openai, which forces the
    # provider-scoped lookup path to be exercised.
    prompt_cost, completion_cost = cost_per_token(
        model="test-provider-model",
        prompt_tokens=1000,
        completion_tokens=500,
        custom_llm_provider="openai",
    )
    assert prompt_cost == pytest.approx(0.001)
    assert completion_cost == pytest.approx(0.001)
311  
312  
def test_cost_per_token_strips_provider_prefix(mock_model_cost):
    # The "openai/" prefix should be stripped, resolving to plain "test-model".
    prompt_cost, completion_cost = cost_per_token(
        model="openai/test-model", prompt_tokens=1000, completion_tokens=500
    )
    assert prompt_cost == pytest.approx(0.001)
    assert completion_cost == pytest.approx(0.001)
320  
321  
def test_cost_per_token_cache_read_tokens(mock_model_cost):
    prompt_cost, completion_cost = cost_per_token(
        model="test-model",
        prompt_tokens=1000,
        completion_tokens=500,
        cache_read_input_tokens=200,
    )
    # 800 uncached tokens at 1e-6 plus 200 cached reads at the cheaper 5e-7.
    expected_input = 800 * 1e-6 + 200 * 5e-7
    assert prompt_cost == pytest.approx(expected_input)
    assert completion_cost == pytest.approx(500 * 2e-6)
333  
334  
def test_cost_per_token_cache_creation_tokens(mock_model_cost):
    prompt_cost, completion_cost = cost_per_token(
        model="test-model",
        prompt_tokens=1000,
        completion_tokens=500,
        cache_creation_input_tokens=300,
    )
    # 700 regular tokens at 1e-6 plus 300 cache writes at the pricier 3e-6.
    expected_input = 700 * 1e-6 + 300 * 3e-6
    assert prompt_cost == pytest.approx(expected_input)
    assert completion_cost == pytest.approx(500 * 2e-6)
346  
347  
def test_cost_per_token_zero_tokens(mock_model_cost):
    # Zero usage on both sides must yield exactly zero cost.
    prompt_cost, completion_cost = cost_per_token(
        model="test-model", prompt_tokens=0, completion_tokens=0
    )
    assert prompt_cost == 0.0
    assert completion_cost == 0.0
354  
355  
def test_cost_per_token_unknown_model_returns_none(mock_model_cost):
    # A model missing from every catalog yields None rather than a cost pair.
    result = cost_per_token(model="totally-unknown-model", prompt_tokens=100)
    assert result is None
358  
359  
def test_cost_per_token_unknown_model_with_provider_returns_none(mock_model_cost):
    # Even with an explicit provider hint, an unknown model resolves to None.
    result = cost_per_token(
        model="totally-unknown-model",
        prompt_tokens=100,
        custom_llm_provider="unknown-provider",
    )
    assert result is None
369  
370  
def test_cost_per_token_no_cache_cost_falls_back_to_input_rate():
    # Catalog whose model defines no cache_read_input_token_cost at all.
    catalog = {
        "nocache_provider": {
            "test-model": {
                "input_cost_per_token": 1e-6,
                "output_cost_per_token": 2e-6,
            }
        }
    }

    def fake_load(provider):
        return catalog.get(provider, {})

    with (
        mock.patch("mlflow.utils.providers._load_provider", side_effect=fake_load),
        mock.patch("mlflow.utils.providers._load_bundled_provider", side_effect=fake_load),
        mock.patch(
            "mlflow.utils.providers._list_provider_names",
            return_value=list(catalog.keys()),
        ),
    ):
        prompt_cost, completion_cost = cost_per_token(
            model="test-model",
            prompt_tokens=1000,
            completion_tokens=500,
            cache_read_input_tokens=200,
        )
        # Without a dedicated cache-read rate, cached tokens bill at the
        # regular input rate: 800 * 1e-6 + 200 * 1e-6 == 0.001.
        assert prompt_cost == pytest.approx(0.001)
        assert completion_cost == pytest.approx(0.001)
405  
406  
def test_flatten_catalog_entry():
    raw_entry = {
        "mode": "chat",
        "context_window": {"max_input": 128000, "max_output": 16384},
        "pricing": {
            "input_per_million_tokens": 2.5,
            "output_per_million_tokens": 10.0,
            "cache_read_per_million_tokens": 1.25,
            "cache_write_per_million_tokens": 5.0,
        },
        "capabilities": {
            "function_calling": True,
            "vision": True,
            "reasoning": False,
            "prompt_caching": True,
            "response_schema": True,
        },
        "deprecation_date": "2026-01-01",
    }

    info = _flatten_catalog_entry(raw_entry)

    assert info["mode"] == "chat"
    # Context-window limits map onto the flat max-token keys.
    assert (info["max_input_tokens"], info["max_output_tokens"]) == (128000, 16384)
    # Per-million prices are converted to per-token rates.
    for key, per_token_rate in [
        ("input_cost_per_token", 2.5e-6),
        ("output_cost_per_token", 1e-5),
        ("cache_read_input_token_cost", 1.25e-6),
        ("cache_creation_input_token_cost", 5e-6),
    ]:
        assert info[key] == pytest.approx(per_token_rate)
    # Capabilities flatten to supports_* booleans.
    assert info["supports_function_calling"] is True
    assert info["supports_vision"] is True
    assert info["supports_reasoning"] is False
    assert info["deprecation_date"] == "2026-01-01"
438  
439  
def test_flatten_catalog_entry_with_last_updated_at():
    caps = dict.fromkeys(
        ("function_calling", "vision", "reasoning", "prompt_caching", "response_schema"),
        False,
    )
    entry = {"mode": "chat", "capabilities": caps, "last_updated_at": "2025-01-15"}
    # The timestamp passes through to the flattened info unchanged.
    assert _flatten_catalog_entry(entry)["last_updated_at"] == "2025-01-15"
454  
455  
def test_flatten_catalog_entry_without_last_updated_at():
    caps = dict.fromkeys(
        ("function_calling", "vision", "reasoning", "prompt_caching", "response_schema"),
        False,
    )
    flattened = _flatten_catalog_entry({"mode": "chat", "capabilities": caps})
    # No timestamp in the source entry means no timestamp key in the output.
    assert "last_updated_at" not in flattened
469  
470  
def test_load_bundled_provider_returns_data():
    _load_bundled_provider.cache_clear()
    catalog = _load_bundled_provider("openai")
    # The bundled openai catalog is non-empty and includes gpt-4o with pricing.
    assert len(catalog) > 0
    assert "gpt-4o" in catalog
    model_info = catalog["gpt-4o"]
    assert model_info["mode"] == "chat"
    assert "input_cost_per_token" in model_info
479  
480  
def test_load_provider_uses_remote_when_available():
    remote_models = {"test-model": {"mode": "chat", "input_cost_per_token": 1e-6}}
    # When the remote fetch succeeds, its payload is returned verbatim.
    with mock.patch(
        "mlflow.utils.providers._fetch_remote_provider", return_value=remote_models
    ) as fetch_mock:
        assert _load_provider("openai") is remote_models
        fetch_mock.assert_called_once_with("openai")
489  
490  
def test_load_provider_falls_back_to_bundled_when_remote_fails():
    # A remote fetch returning None must fall through to the bundled catalog.
    with mock.patch(
        "mlflow.utils.providers._fetch_remote_provider", return_value=None
    ) as fetch_mock:
        bundled = _load_provider("openai")
        fetch_mock.assert_called_once_with("openai")
        assert len(bundled) > 0
        assert "gpt-4o" in bundled
499  
500  
def test_fetch_remote_provider_disabled_when_url_empty(monkeypatch):
    # An empty catalog URI disables remote fetching entirely.
    monkeypatch.setenv("MLFLOW_MODEL_CATALOG_URI", "")
    result = _fetch_remote_provider("openai")
    assert result is None
504  
505  
def test_fetch_remote_provider_supports_file_url(tmp_path, monkeypatch):
    payload = {
        "schema_version": "1.0",
        "models": {"test-model": {"mode": "chat", "pricing": {"input_per_million_tokens": 1.0}}},
    }
    catalog_file = tmp_path / "test_provider.json"
    catalog_file.write_text(json.dumps(payload))
    # Point the catalog at a local file:// URI and clear any cached results.
    monkeypatch.setenv("MLFLOW_MODEL_CATALOG_URI", tmp_path.as_uri())
    _get_remote_cache().clear()
    fetched = _fetch_remote_provider("test_provider")
    assert fetched is not None
    assert "test-model" in fetched