/ mlflow / server / security_utils.py
security_utils.py
  1  """
  2  Shared security utilities for MLflow server middleware.
  3  
  4  This module contains common functions used by both Flask and FastAPI
  5  security implementations.
  6  """
  7  
  8  import fnmatch
  9  from urllib.parse import urlparse
 10  
 11  from mlflow.environment_variables import (
 12      MLFLOW_SERVER_ALLOWED_HOSTS,
 13      MLFLOW_SERVER_CORS_ALLOWED_ORIGINS,
 14  )
 15  
# Security response messages
# Response texts used when a request is rejected by the host check or the CORS check.
INVALID_HOST_MSG = "Invalid Host header - possible DNS rebinding attack detected"
CORS_BLOCKED_MSG = "Cross-origin request blocked"

# HTTP methods that modify state.
# should_block_cors_request only ever blocks requests using one of these methods;
# requests with any other method are never blocked there.
STATE_CHANGING_METHODS = ["POST", "PUT", "DELETE", "PATCH"]

# Paths exempt from host validation.
# NOTE(review): not referenced by the functions in this module; the exemption is
# presumably applied by the Flask/FastAPI middleware -- confirm.
HEALTH_ENDPOINTS = ["/health", "/version"]

# API path prefixes for MLflow endpoints (checked by is_api_endpoint).
API_PATH_PREFIX = "/api/"
AJAX_API_PATH_PREFIX = "/ajax-api/"

# Test-only endpoints that should not have CORS blocking.
# is_api_endpoint compares against these as exact paths, not prefixes.
TEST_ENDPOINTS = ["/test", "/api/test"]

# Localhost addresses.
# LOCALHOST_VARIANTS: Host-header forms trusted by default (returned by
# get_localhost_addresses and expanded with ":*" port variants in
# get_default_allowed_hosts). IPv6 is listed in bracketed form, as it is written
# in a Host header. "0.0.0.0" is included here but not in CORS_LOCALHOST_HOSTS.
LOCALHOST_VARIANTS = ["localhost", "127.0.0.1", "[::1]", "0.0.0.0"]
# CORS_LOCALHOST_HOSTS: compared against urlparse(origin).hostname in
# is_localhost_origin. urlparse strips the brackets from IPv6 literals, so "::1" is
# the entry that actually matches; "[::1]" is kept but never matches a parsed hostname.
CORS_LOCALHOST_HOSTS = ["localhost", "127.0.0.1", "[::1]", "::1"]

# Private IP range start values for 172.16.0.0/12.
# Used as range(PRIVATE_172_RANGE_START, PRIVATE_172_RANGE_END): the end bound is
# exclusive, so second octets 16..31 are generated.
PRIVATE_172_RANGE_START = 16
PRIVATE_172_RANGE_END = 32

# Regex patterns for localhost origins (http scheme only, optional numeric port).
# NOTE(review): not referenced by the functions in this module -- is_localhost_origin
# parses origins with urlparse instead. Presumably kept for external callers; confirm
# before removing.
LOCALHOST_ORIGIN_PATTERNS = [
    r"^http://localhost(:[0-9]+)?$",
    r"^http://127\.0\.0\.1(:[0-9]+)?$",
    r"^http://\[::1\](:[0-9]+)?$",
]
 47  
 48  
 49  def get_localhost_addresses() -> list[str]:
 50      """Get localhost/loopback addresses."""
 51      return LOCALHOST_VARIANTS
 52  
 53  
 54  def get_private_ip_patterns() -> list[str]:
 55      """
 56      Generate wildcard patterns for private IP ranges.
 57  
 58      These are the standard RFC-defined private address ranges:
 59      - RFC 1918 (IPv4): 10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16
 60        https://datatracker.ietf.org/doc/html/rfc1918
 61      - RFC 4193 (IPv6): fc00::/7
 62        https://datatracker.ietf.org/doc/html/rfc4193
 63  
 64      Additional references:
 65      - IANA IPv4 Special-Purpose Address Registry:
 66        https://www.iana.org/assignments/iana-ipv4-special-registry/
 67      - IANA IPv6 Special-Purpose Address Registry:
 68        https://www.iana.org/assignments/iana-ipv6-special-registry/
 69      """
 70      return [
 71          "192.168.*",
 72          "10.*",
 73          *[f"172.{i}.*" for i in range(PRIVATE_172_RANGE_START, PRIVATE_172_RANGE_END)],
 74          "fc00:*",
 75          "fd00:*",
 76      ]
 77  
 78  
 79  def get_allowed_hosts_from_env() -> list[str] | None:
 80      """Get allowed hosts from environment variable."""
 81      if allowed_hosts_env := MLFLOW_SERVER_ALLOWED_HOSTS.get():
 82          return [host.strip() for host in allowed_hosts_env.split(",")]
 83      return None
 84  
 85  
 86  def get_allowed_origins_from_env() -> list[str] | None:
 87      """Get allowed CORS origins from environment variable."""
 88      if allowed_origins_env := MLFLOW_SERVER_CORS_ALLOWED_ORIGINS.get():
 89          return [origin.strip() for origin in allowed_origins_env.split(",")]
 90      return None
 91  
 92  
 93  def is_localhost_origin(origin: str) -> bool:
 94      """Check if an origin is from localhost."""
 95      if not origin:
 96          return False
 97  
 98      try:
 99          parsed = urlparse(origin)
100          hostname = parsed.hostname
101          return hostname in CORS_LOCALHOST_HOSTS
102      except Exception:
103          return False
104  
105  
def should_block_cors_request(origin: str, method: str, allowed_origins: list[str] | None) -> bool:
    """
    Decide whether a cross-origin request has to be rejected.

    Only requests that carry an origin and use a method from
    ``STATE_CHANGING_METHODS`` are candidates; everything else passes. Localhost
    origins always pass.

    Any other origin passes only when it is listed in *allowed_origins*: a bare
    ``"*"`` admits every origin, entries containing ``*`` are ``fnmatch``
    patterns, and all other entries must match exactly. When no allowed origins
    are configured (``None`` or empty), such requests are blocked.
    """
    needs_check = bool(origin) and method in STATE_CHANGING_METHODS
    if not needs_check or is_localhost_origin(origin):
        return False

    if not allowed_origins:
        return True

    if "*" in allowed_origins:
        return False

    for candidate in allowed_origins:
        if "*" in candidate:
            matched = fnmatch.fnmatch(origin, candidate)
        else:
            matched = origin == candidate
        if matched:
            return False
    return True
125  
126  
def is_api_endpoint(path: str) -> bool:
    """
    Report whether *path* is an MLflow API route that gets CORS/OPTIONS handling.

    A path qualifies when it begins with ``API_PATH_PREFIX`` or
    ``AJAX_API_PATH_PREFIX``, unless it is exactly one of the ``TEST_ENDPOINTS``.
    """
    if path in TEST_ENDPOINTS:
        return False
    return path.startswith((API_PATH_PREFIX, AJAX_API_PATH_PREFIX))
132  
133  
def is_allowed_host_header(allowed_hosts: list[str], host: str) -> bool:
    """
    Validate if the host header matches allowed patterns.

    Host names are case-insensitive (RFC 3986, section 3.2.2), so the header value
    and every allowed entry are compared in lowercase. Glob matching uses
    ``fnmatch.fnmatchcase`` so the result no longer depends on the server OS:
    ``fnmatch.fnmatch`` applies ``os.path.normcase``, which folds case on Windows
    only.

    Only entries containing ``*`` are treated as glob patterns. Every other entry
    is compared literally, so a value such as ``"[::1]"`` is not misread as a
    character class.

    Args:
        allowed_hosts: Exact host values and/or ``fnmatch`` patterns. A bare ``"*"``
            entry allows every non-empty host.
        host: The Host header value to check, possibly including a ``:port``
            suffix.

    Returns:
        True if *host* is non-empty and matches at least one entry.
    """
    if not host:
        return False

    # If wildcard "*" is in the list, allow all hosts
    if "*" in allowed_hosts:
        return True

    normalized_host = host.lower()
    for allowed in allowed_hosts:
        pattern = allowed.lower()
        if "*" in pattern:
            if fnmatch.fnmatchcase(normalized_host, pattern):
                return True
        elif normalized_host == pattern:
            return True
    return False
147  
148  
def get_default_allowed_hosts() -> list[str]:
    """
    Assemble the default Host-header allow list.

    The result contains, in this order:
      1. every address from ``get_localhost_addresses`` as an exact entry;
      2. the same addresses followed by ``:*`` (any port);
      3. the private-network patterns from ``get_private_ip_patterns``.

    In the port variants, a bracketed IPv6 address has its opening ``[`` rewritten
    as ``[[]``, because fnmatch would otherwise treat ``[`` as the start of a
    character class.
    """

    def _with_any_port(address: str) -> str:
        # Escape only the leading bracket of an IPv6 literal so fnmatch reads it literally.
        prefix = address.replace("[", "[[]", 1) if address.startswith("[") else address
        return prefix + ":*"

    exact_hosts = get_localhost_addresses()
    port_patterns = [_with_any_port(address) for address in exact_hosts]
    return exact_hosts + port_patterns + get_private_ip_patterns()
159  
160      return get_localhost_addresses() + wildcard_hosts + get_private_ip_patterns()