/ mlflow / utils / provider_filter.py
provider_filter.py
 1  import logging
 2  import threading
 3  
 4  from cachetools import LRUCache
 5  from cachetools.func import cached
 6  
 7  from mlflow.environment_variables import MLFLOW_GATEWAY_ALLOWED_PROVIDERS
 8  
 9  _logger = logging.getLogger(__name__)
10  
11  # Single source of truth for provider name aliases (string-level).
12  _PROVIDER_ALIASES: dict[str, str] = {
13      "amazon-bedrock": "bedrock",
14      "databricks-model-serving": "databricks",
15  }
16  
17  _provider_filter_cache = LRUCache(maxsize=16)
18  _provider_filter_cache_lock = threading.RLock()
19  
20  
def normalize_provider_name(name: str) -> str:
    """Return the canonical provider name for ``name``.

    The lookup is an exact (case-sensitive) match against ``_PROVIDER_ALIASES``;
    a name that is not a known alias is returned unchanged. Callers in this
    module lowercase the name before calling.
    """
    if name in _PROVIDER_ALIASES:
        return _PROVIDER_ALIASES[name]
    return name
23  
24  
def _parse_provider_list(value: str | None) -> frozenset[str]:
    """Parse a comma-separated provider list into a set of canonical names.

    Each comma-separated entry is stripped of surrounding whitespace,
    lowercased, and passed through ``normalize_provider_name``. Entries that
    are blank after stripping are dropped. ``None`` or an empty string yields
    an empty frozenset.
    """
    if not value:
        return frozenset()

    names: set[str] = set()
    for raw_entry in value.split(","):
        entry = raw_entry.strip()
        if not entry:
            continue
        names.add(normalize_provider_name(entry.lower()))
    return frozenset(names)
31  
32  
@cached(cache=_provider_filter_cache, lock=_provider_filter_cache_lock)
def _parse_allowed_providers(allowed_raw: str | None) -> frozenset[str] | None:
    """Parse the raw allow-list value, memoized per distinct raw string.

    Returns the frozenset of canonical provider names, or ``None`` when parsing
    produces no names (unset, empty, or only separators/whitespace). Results
    are stored in ``_provider_filter_cache`` under ``_provider_filter_cache_lock``.
    """
    parsed = _parse_provider_list(allowed_raw)
    if not parsed:
        return None
    return parsed
36  
37  
def _get_allowed_providers() -> frozenset[str] | None:
    """Return the parsed ``MLFLOW_GATEWAY_ALLOWED_PROVIDERS`` allow-list.

    The environment value is re-read on every call; parsing is memoized by
    ``_parse_allowed_providers``. ``None`` means no provider names were given.
    """
    raw_value = MLFLOW_GATEWAY_ALLOWED_PROVIDERS.get()
    return _parse_allowed_providers(raw_value)
40  
41  
def is_provider_allowed(provider_name: str) -> bool:
    """Report whether ``provider_name`` passes the gateway allow-list.

    When no allow-list is configured every provider is allowed. Otherwise the
    name is lowercased and alias-normalized before the membership check, which
    mirrors how allow-list entries are parsed.
    """
    allowed = _get_allowed_providers()
    # Short-circuit: with no allow-list, the name is never inspected.
    return allowed is None or normalize_provider_name(provider_name.lower()) in allowed
48  
49  
def filter_providers(providers: list[str]) -> list[str]:
    """Return the providers permitted by the gateway allow-list.

    Each name is lowercased and alias-normalized before being checked against
    the parsed ``MLFLOW_GATEWAY_ALLOWED_PROVIDERS`` value. The result keeps the
    caller's original spelling and order. Rejected names are logged at DEBUG.

    A new list is always returned, even when no allow-list is configured, so
    mutating the result never affects ``providers``.

    Args:
        providers: Provider names to filter.

    Returns:
        A fresh list holding the allowed subset of ``providers``.
    """
    allowed = _get_allowed_providers()
    if allowed is None:
        # Copy so this path has the same ownership semantics (and return type)
        # as the filtering path below, which always builds a fresh list.
        return list(providers)

    result = []
    for p in providers:
        if normalize_provider_name(p.lower()) in allowed:
            result.append(p)
        else:
            _logger.debug("Provider '%s' is not in MLFLOW_GATEWAY_ALLOWED_PROVIDERS", p)
    return result