tests/utils/test_provider_filter.py
test_provider_filter.py
 1  import pytest
 2  
 3  from mlflow.utils.provider_filter import (
 4      _parse_provider_list,
 5      filter_providers,
 6      is_provider_allowed,
 7  )
 8  
 9  
@pytest.mark.parametrize(
    ("input_value", "expected"),
    [
        # Missing or empty configuration yields an empty allow-list.
        (None, frozenset(())),
        ("", frozenset(())),
        # Single and comma-separated provider names.
        ("openai", frozenset(("openai",))),
        ("openai,anthropic", frozenset(("openai", "anthropic"))),
        ("openai, anthropic, gemini", frozenset(("openai", "anthropic", "gemini"))),
        # Surrounding whitespace is stripped and names are lower-cased.
        (" openai , anthropic ", frozenset(("openai", "anthropic"))),
        ("OpenAI,Anthropic", frozenset(("openai", "anthropic"))),
        # Blank entries between commas are dropped.
        ("openai,,anthropic", frozenset(("openai", "anthropic"))),
        ("  ,  ,  ", frozenset(())),
        # The "amazon-bedrock" alias is normalized to "bedrock".
        ("amazon-bedrock", frozenset(("bedrock",))),
        ("amazon-bedrock,openai", frozenset(("bedrock", "openai"))),
    ],
)
def test_parse_provider_list(input_value, expected):
    """_parse_provider_list turns a raw comma-separated value into a frozenset of names."""
    parsed = _parse_provider_list(input_value)
    assert parsed == expected
28  
29  
def test_filter_providers_no_filter(monkeypatch):
    """With no allow-list configured, filter_providers returns every provider."""
    # Clear the variable explicitly: an allow-list exported in the developer's
    # shell or CI environment would otherwise make this test fail spuriously.
    monkeypatch.delenv("MLFLOW_GATEWAY_ALLOWED_PROVIDERS", raising=False)
    providers = ["openai", "anthropic", "gemini"]
    assert filter_providers(providers) == providers
33  
34  
def test_filter_providers_with_allowed_list(monkeypatch):
    """Only providers named in the allow-list are returned."""
    monkeypatch.setenv("MLFLOW_GATEWAY_ALLOWED_PROVIDERS", "openai,anthropic")
    candidates = ["openai", "anthropic", "gemini", "bedrock"]
    expected = ["openai", "anthropic"]
    assert filter_providers(candidates) == expected
39  
40  
def test_filter_providers_case_insensitive(monkeypatch):
    """Allow-list entries match provider names regardless of letter case."""
    mixed_case_allow_list = "OpenAI,ANTHROPIC"
    monkeypatch.setenv("MLFLOW_GATEWAY_ALLOWED_PROVIDERS", mixed_case_allow_list)
    filtered = filter_providers(["openai", "anthropic", "gemini"])
    assert filtered == ["openai", "anthropic"]
45  
46  
def test_is_provider_allowed_no_filter(monkeypatch):
    """With no allow-list configured, every provider is allowed."""
    # Ensure an allow-list inherited from the surrounding environment cannot
    # leak into this test (e.g. one excluding "litellm").
    monkeypatch.delenv("MLFLOW_GATEWAY_ALLOWED_PROVIDERS", raising=False)
    assert is_provider_allowed("openai") is True
    assert is_provider_allowed("litellm") is True
50  
51  
def test_is_provider_allowed_with_allowed_list(monkeypatch):
    """Providers inside the allow-list pass; everything else is rejected."""
    monkeypatch.setenv("MLFLOW_GATEWAY_ALLOWED_PROVIDERS", "openai,anthropic")
    expectations = [
        ("openai", True),
        ("anthropic", True),
        ("gemini", False),
        ("litellm", False),
    ]
    for provider, should_allow in expectations:
        assert is_provider_allowed(provider) is should_allow
58  
59  
def test_is_provider_allowed_case_insensitive_with_allowed_list(monkeypatch):
    """A lower-case allow-list entry accepts the provider name in any casing."""
    monkeypatch.setenv("MLFLOW_GATEWAY_ALLOWED_PROVIDERS", "openai")
    for spelling in ("OpenAI", "OPENAI", "openai"):
        assert is_provider_allowed(spelling) is True
65  
66  
def test_is_provider_allowed_bedrock_in_allowed_list(monkeypatch):
    """Allowing "bedrock" also admits its "amazon-bedrock" alias, but nothing else."""
    monkeypatch.setenv("MLFLOW_GATEWAY_ALLOWED_PROVIDERS", "bedrock")
    for alias in ("bedrock", "amazon-bedrock"):
        assert is_provider_allowed(alias) is True
    assert is_provider_allowed("openai") is False
72  
73  
def test_filter_providers_normalizes_bedrock_alias(monkeypatch):
    """Both "bedrock" and its "amazon-bedrock" alias survive a "bedrock" allow-list entry."""
    monkeypatch.setenv("MLFLOW_GATEWAY_ALLOWED_PROVIDERS", "bedrock,openai")
    offered = ["openai", "bedrock", "amazon-bedrock", "anthropic"]
    kept = filter_providers(offered)
    assert kept == ["openai", "bedrock", "amazon-bedrock"]