security_utils.py
"""
Shared security utilities for MLflow server middleware.

This module contains common functions used by both Flask and FastAPI
security implementations.
"""

import fnmatch
import ipaddress
import re
from urllib.parse import urlparse

from mlflow.environment_variables import (
    MLFLOW_SERVER_ALLOWED_HOSTS,
    MLFLOW_SERVER_CORS_ALLOWED_ORIGINS,
)

# Security response messages
INVALID_HOST_MSG = "Invalid Host header - possible DNS rebinding attack detected"
CORS_BLOCKED_MSG = "Cross-origin request blocked"

# HTTP methods that modify state
STATE_CHANGING_METHODS = ["POST", "PUT", "DELETE", "PATCH"]

# Paths exempt from host validation
HEALTH_ENDPOINTS = ["/health", "/version"]

# API path prefixes for MLflow endpoints
API_PATH_PREFIX = "/api/"
AJAX_API_PATH_PREFIX = "/ajax-api/"

# Test-only endpoints that should not have CORS blocking
TEST_ENDPOINTS = ["/test", "/api/test"]

# Localhost addresses
LOCALHOST_VARIANTS = ["localhost", "127.0.0.1", "[::1]", "0.0.0.0"]
# Hostnames treated as localhost for CORS. urlparse().hostname strips IPv6
# brackets, so "::1" is the form actually compared; "[::1]" is kept as-is.
CORS_LOCALHOST_HOSTS = ["localhost", "127.0.0.1", "[::1]", "::1"]

# Second-octet bounds for 172.16.0.0/12, used with range(): the start is
# inclusive and the end exclusive, i.e. 172.16.* through 172.31.*.
PRIVATE_172_RANGE_START = 16
PRIVATE_172_RANGE_END = 32

# Regex patterns for localhost origins
LOCALHOST_ORIGIN_PATTERNS = [
    r"^http://localhost(:[0-9]+)?$",
    r"^http://127\.0\.0\.1(:[0-9]+)?$",
    r"^http://\[::1\](:[0-9]+)?$",
]

# Matches (via fullmatch) the literal text before the first "*" of a host
# pattern when that text is a dotted numeric IPv4 prefix such as "10." or
# "192.168.". Such patterns are treated as matching IPv4 literals only.
_IPV4_PREFIX_RE = re.compile(r"(?:[0-9]+\.)+")


def get_localhost_addresses() -> list[str]:
    """Get localhost/loopback addresses.

    Returns:
        A new list, so callers cannot accidentally mutate the module-level
        ``LOCALHOST_VARIANTS`` constant.
    """
    return list(LOCALHOST_VARIANTS)


def get_private_ip_patterns() -> list[str]:
    """
    Generate wildcard patterns for private IP ranges.

    These are the standard RFC-defined private address ranges:
    - RFC 1918 (IPv4): 10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16
      https://datatracker.ietf.org/doc/html/rfc1918
    - RFC 4193 (IPv6): fc00::/7
      https://datatracker.ietf.org/doc/html/rfc4193

    The patterns are fnmatch-style and are meant to be consumed by
    ``is_allowed_host_header``, which only lets the IPv4 patterns match hosts
    that are literal IPv4 addresses (so "10.attacker.example" does not match
    "10.*").

    IPv6 literals are bracketed in a Host header (e.g. "[fd12::1]:5000"), so a
    bracketed pattern covering every first hextet in fc00::/7 (fc00-fdff) is
    included; "[" is written as "[[]" because it is special in fnmatch. The
    unbracketed "fc00:*" / "fd00:*" entries are kept for backward compatibility.

    Additional references:
    - IANA IPv4 Special-Purpose Address Registry:
      https://www.iana.org/assignments/iana-ipv4-special-registry/
    - IANA IPv6 Special-Purpose Address Registry:
      https://www.iana.org/assignments/iana-ipv6-special-registry/
    """
    return [
        "192.168.*",
        "10.*",
        *[f"172.{i}.*" for i in range(PRIVATE_172_RANGE_START, PRIVATE_172_RANGE_END)],
        "fc00:*",
        "fd00:*",
        # "[" + first hextet fcXX/fdXX (exactly four hex digits) + ":" ...
        "[[]f[cd][0-9a-f][0-9a-f]:*",
    ]


def get_allowed_hosts_from_env() -> list[str] | None:
    """Get allowed hosts from environment variable.

    Returns:
        The comma-separated entries of ``MLFLOW_SERVER_ALLOWED_HOSTS`` with
        surrounding whitespace stripped, or None if the variable is unset/empty.
    """
    if allowed_hosts_env := MLFLOW_SERVER_ALLOWED_HOSTS.get():
        return [host.strip() for host in allowed_hosts_env.split(",")]
    return None


def get_allowed_origins_from_env() -> list[str] | None:
    """Get allowed CORS origins from environment variable.

    Returns:
        The comma-separated entries of ``MLFLOW_SERVER_CORS_ALLOWED_ORIGINS``
        with surrounding whitespace stripped, or None if unset/empty.
    """
    if allowed_origins_env := MLFLOW_SERVER_CORS_ALLOWED_ORIGINS.get():
        return [origin.strip() for origin in allowed_origins_env.split(",")]
    return None


def is_localhost_origin(origin: str) -> bool:
    """Check if an origin is from localhost.

    Args:
        origin: The value of an ``Origin`` header, e.g. "http://localhost:3000".

    Returns:
        True if the parsed hostname is one of ``CORS_LOCALHOST_HOSTS``; False
        for empty or malformed origins.
    """
    if not origin:
        return False

    try:
        hostname = urlparse(origin).hostname
    except ValueError:
        # urlparse raises ValueError for malformed netlocs (e.g. an unclosed
        # IPv6 bracket); treat those as non-localhost.
        return False
    return hostname in CORS_LOCALHOST_HOSTS


def should_block_cors_request(origin: str, method: str, allowed_origins: list[str] | None) -> bool:
    """Determine if a CORS request should be blocked.

    Only state-changing requests (``STATE_CHANGING_METHODS``) that carry an
    ``Origin`` header are candidates for blocking. Localhost origins are always
    permitted. Otherwise the origin must equal an entry of ``allowed_origins``
    or match it as an fnmatch pattern (entries containing "*"); a bare "*"
    entry allows every origin. With no allowed origins configured, non-local
    state-changing cross-origin requests are blocked.

    Returns:
        True if the request should be rejected.
    """
    if not origin or method not in STATE_CHANGING_METHODS:
        return False

    if is_localhost_origin(origin):
        return False

    if allowed_origins:
        # If wildcard "*" is in the list, allow all origins
        if "*" in allowed_origins:
            return False

        return not any(
            fnmatch.fnmatch(origin, allowed) if "*" in allowed else origin == allowed
            for allowed in allowed_origins
        )

    return True


def is_api_endpoint(path: str) -> bool:
    """Check if a path is an API endpoint that should have CORS/OPTIONS handling."""
    return (
        path.startswith(API_PATH_PREFIX) or path.startswith(AJAX_API_PATH_PREFIX)
    ) and path not in TEST_ENDPOINTS


def _is_ipv4_literal(hostname: str) -> bool:
    """Return True if ``hostname`` parses as a dotted-decimal IPv4 address."""
    try:
        ipaddress.IPv4Address(hostname)
    except ValueError:
        return False
    return True


def _matches_host_pattern(host: str, pattern: str) -> bool:
    """Return True if ``host`` matches a single allowed-host entry.

    Entries without "*" must match exactly; entries with "*" are fnmatch
    patterns.
    """
    if "*" not in pattern:
        return host == pattern
    if not fnmatch.fnmatch(host, pattern):
        return False
    # fnmatch's "*" also matches ".", so a pattern like "10.*" would accept an
    # attacker-controlled DNS name such as "10.attacker.example:5000" that can
    # be rebound to a local address -- exactly the attack host validation is
    # meant to stop. When the pattern's literal prefix is a dotted numeric IPv4
    # prefix, require the host (with any ":port" removed) to be an IPv4 literal.
    literal_prefix = pattern.split("*", 1)[0]
    if _IPV4_PREFIX_RE.fullmatch(literal_prefix):
        return _is_ipv4_literal(host.partition(":")[0])
    return True


def is_allowed_host_header(allowed_hosts: list[str], host: str) -> bool:
    """Validate if the host header matches allowed patterns.

    Args:
        allowed_hosts: Exact host values or fnmatch-style patterns (an entry is
            treated as a pattern only if it contains "*"). A bare "*" entry
            allows every host. Patterns beginning with a numeric IPv4 prefix
            such as "10.*" or "192.168.*" only match literal IPv4 hosts.
        host: The raw ``Host`` header value, possibly including ":port".

    Returns:
        True if ``host`` is non-empty and matches at least one entry.
    """
    if not host:
        return False

    # If wildcard "*" is in the list, allow all hosts
    if "*" in allowed_hosts:
        return True

    return any(_matches_host_pattern(host, allowed) for allowed in allowed_hosts)


def get_default_allowed_hosts() -> list[str]:
    """Get default allowed hosts patterns.

    Returns:
        The localhost addresses, each localhost address followed by ":*" (any
        port), and the private-network patterns from ``get_private_ip_patterns``.
    """
    wildcard_hosts = []
    for host in get_localhost_addresses():
        if host.startswith("["):
            # IPv6: escape opening bracket for fnmatch
            escaped = host.replace("[", "[[]", 1)
            wildcard_hosts.append(f"{escaped}:*")
        else:
            wildcard_hosts.append(f"{host}:*")

    return get_localhost_addresses() + wildcard_hosts + get_private_ip_patterns()