provider_filter.py
# Allow-list filtering for gateway providers.
#
# The allow-list is read from the MLFLOW_GATEWAY_ALLOWED_PROVIDERS setting as a
# comma-separated string. When that value yields no provider names, no
# restriction is applied and every provider is considered allowed.

import logging
import threading

from cachetools import LRUCache
from cachetools.func import cached

from mlflow.environment_variables import MLFLOW_GATEWAY_ALLOWED_PROVIDERS

_logger = logging.getLogger(__name__)

# Single source of truth for provider name aliases (string-level).
# Keys are alternate spellings; values are the canonical provider names.
_PROVIDER_ALIASES: dict[str, str] = {
    "amazon-bedrock": "bedrock",
    "databricks-model-serving": "databricks",
}

# Memoization store (and its guard) for parsed allow-lists, so the raw setting
# string is not re-parsed on every lookup.
_provider_filter_cache = LRUCache(maxsize=16)
_provider_filter_cache_lock = threading.RLock()


def normalize_provider_name(name: str) -> str:
    """Map an alias to its canonical provider name.

    The lookup is an exact match against ``_PROVIDER_ALIASES``; no case
    folding or whitespace trimming happens here, so callers do that first.

    Args:
        name: Provider name, possibly an alias.

    Returns:
        The canonical name if ``name`` is a known alias, otherwise ``name``
        unchanged.
    """
    return _PROVIDER_ALIASES[name] if name in _PROVIDER_ALIASES else name


def _parse_provider_list(value: str | None) -> frozenset[str]:
    """Parse a comma-separated provider string into canonical names.

    Each entry is stripped of surrounding whitespace, lower-cased and
    alias-normalized. Entries that are empty after stripping are skipped.

    Args:
        value: Raw comma-separated string, or ``None``.

    Returns:
        A frozenset of canonical provider names; empty when ``value`` is
        ``None`` or an empty string.
    """
    if not value:
        return frozenset()
    tokens = (piece.strip() for piece in value.split(","))
    return frozenset(normalize_provider_name(token.lower()) for token in tokens if token)


@cached(cache=_provider_filter_cache, lock=_provider_filter_cache_lock)
def _parse_allowed_providers(allowed_raw: str | None) -> frozenset[str] | None:
    """Parse the raw allow-list setting, memoized via the module-level cache.

    Args:
        allowed_raw: Raw value of the allow-list setting, or ``None``.

    Returns:
        The set of allowed canonical provider names, or ``None`` when the
        value contains no provider names (meaning "no restriction").
    """
    parsed = _parse_provider_list(allowed_raw)
    return parsed if parsed else None


def _get_allowed_providers() -> frozenset[str] | None:
    """Return the current allow-list, or ``None`` if no restriction is set.

    The setting is read on every call; only the parsing step is memoized.
    """
    raw_setting = MLFLOW_GATEWAY_ALLOWED_PROVIDERS.get()
    return _parse_allowed_providers(raw_setting)


def is_provider_allowed(provider_name: str) -> bool:
    """Check whether a provider passes the configured allow-list.

    Args:
        provider_name: Provider name to check; compared case-insensitively
            after alias normalization.

    Returns:
        ``True`` when no allow-list is configured or when the canonical name
        is in the allow-list, ``False`` otherwise.
    """
    allowed = _get_allowed_providers()
    # Short-circuit keeps the name untouched when there is no restriction.
    return allowed is None or normalize_provider_name(provider_name.lower()) in allowed


def filter_providers(providers: list[str]) -> list[str]:
    """Keep only the providers that pass the configured allow-list.

    Args:
        providers: Provider names to filter.

    Returns:
        When no allow-list is configured, the input list object itself is
        returned unchanged. Otherwise a new list with the allowed entries, in
        their original order and original spelling.
    """
    allowed = _get_allowed_providers()
    if allowed is None:
        return providers

    kept = []
    for candidate in providers:
        canonical = normalize_provider_name(candidate.lower())
        if canonical in allowed:
            kept.append(candidate)
        else:
            _logger.debug("Provider '%s' is not in MLFLOW_GATEWAY_ALLOWED_PROVIDERS", candidate)
    return kept