# test_provider_filter.py
"""Tests for ``mlflow.utils.provider_filter``.

The filter is driven by the ``MLFLOW_GATEWAY_ALLOWED_PROVIDERS`` environment
variable, a comma-separated list of provider names.
"""

import pytest

from mlflow.utils.provider_filter import (
    _parse_provider_list,
    filter_providers,
    is_provider_allowed,
)

# Single source of truth for the env var name used throughout these tests.
_ALLOWED_PROVIDERS_ENV = "MLFLOW_GATEWAY_ALLOWED_PROVIDERS"


@pytest.fixture(autouse=True)
def _clear_allowed_providers(monkeypatch):
    """Start every test with the allow-list variable unset.

    The "no filter" tests assume the variable is absent; without this, a value
    inherited from the developer's shell or CI would make them fail. Tests that
    need a value set it explicitly via ``monkeypatch.setenv``.
    """
    monkeypatch.delenv(_ALLOWED_PROVIDERS_ENV, raising=False)


@pytest.mark.parametrize(
    ("input_value", "expected"),
    [
        (None, frozenset()),
        ("", frozenset()),
        ("openai", frozenset({"openai"})),
        ("openai,anthropic", frozenset({"openai", "anthropic"})),
        ("openai, anthropic, gemini", frozenset({"openai", "anthropic", "gemini"})),
        (" openai , anthropic ", frozenset({"openai", "anthropic"})),
        ("OpenAI,Anthropic", frozenset({"openai", "anthropic"})),
        ("openai,,anthropic", frozenset({"openai", "anthropic"})),
        (" , , ", frozenset()),
        # "amazon-bedrock" is expected to be normalized to "bedrock".
        ("amazon-bedrock", frozenset({"bedrock"})),
        ("amazon-bedrock,openai", frozenset({"bedrock", "openai"})),
    ],
)
def test_parse_provider_list(input_value, expected):
    """Parsing trims whitespace, lowercases, drops empties, and normalizes aliases."""
    assert _parse_provider_list(input_value) == expected


def test_filter_providers_no_filter():
    """With the variable unset, every provider passes through unchanged."""
    providers = ["openai", "anthropic", "gemini"]
    assert filter_providers(providers) == providers


def test_filter_providers_with_allowed_list(monkeypatch):
    """Only allow-listed providers survive, in their original input order."""
    monkeypatch.setenv(_ALLOWED_PROVIDERS_ENV, "openai,anthropic")
    result = filter_providers(["openai", "anthropic", "gemini", "bedrock"])
    assert result == ["openai", "anthropic"]


def test_filter_providers_case_insensitive(monkeypatch):
    """Allow-list entries match regardless of case."""
    monkeypatch.setenv(_ALLOWED_PROVIDERS_ENV, "OpenAI,ANTHROPIC")
    result = filter_providers(["openai", "anthropic", "gemini"])
    assert result == ["openai", "anthropic"]


def test_is_provider_allowed_no_filter():
    """With the variable unset, any provider name is allowed."""
    assert is_provider_allowed("openai") is True
    assert is_provider_allowed("litellm") is True


def test_is_provider_allowed_with_allowed_list(monkeypatch):
    """Providers outside the allow-list are rejected."""
    monkeypatch.setenv(_ALLOWED_PROVIDERS_ENV, "openai,anthropic")
    assert is_provider_allowed("openai") is True
    assert is_provider_allowed("anthropic") is True
    assert is_provider_allowed("gemini") is False
    assert is_provider_allowed("litellm") is False


def test_is_provider_allowed_case_insensitive_with_allowed_list(monkeypatch):
    """The queried provider name is matched case-insensitively."""
    monkeypatch.setenv(_ALLOWED_PROVIDERS_ENV, "openai")
    assert is_provider_allowed("OpenAI") is True
    assert is_provider_allowed("OPENAI") is True
    assert is_provider_allowed("openai") is True


def test_is_provider_allowed_bedrock_in_allowed_list(monkeypatch):
    """Allowing "bedrock" also allows its "amazon-bedrock" alias."""
    monkeypatch.setenv(_ALLOWED_PROVIDERS_ENV, "bedrock")
    assert is_provider_allowed("bedrock") is True
    assert is_provider_allowed("amazon-bedrock") is True
    assert is_provider_allowed("openai") is False


def test_filter_providers_normalizes_bedrock_alias(monkeypatch):
    """Aliases match after normalization, but output keeps the input spelling."""
    monkeypatch.setenv(_ALLOWED_PROVIDERS_ENV, "bedrock,openai")
    result = filter_providers(["openai", "bedrock", "amazon-bedrock", "anthropic"])
    assert result == ["openai", "bedrock", "amazon-bedrock"]