fastapi_security.py
import logging
from http import HTTPStatus

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from starlette.types import ASGIApp

from mlflow.environment_variables import (
    MLFLOW_SERVER_DISABLE_SECURITY_MIDDLEWARE,
    MLFLOW_SERVER_X_FRAME_OPTIONS,
)
from mlflow.server.security_utils import (
    CORS_BLOCKED_MSG,
    HEALTH_ENDPOINTS,
    INVALID_HOST_MSG,
    LOCALHOST_ORIGIN_PATTERNS,
    get_allowed_hosts_from_env,
    get_allowed_origins_from_env,
    get_default_allowed_hosts,
    is_allowed_host_header,
    is_api_endpoint,
    should_block_cors_request,
)
from mlflow.tracing.constant import TRACE_RENDERER_ASSET_PATH

_logger = logging.getLogger(__name__)


def _get_request_header(scope, name: bytes) -> str:
    """Return the first value of request header ``name`` in an ASGI ``scope``, or ``""``.

    ``name`` must be given in lowercase. The first occurrence wins, which matches how
    Starlette's ``Headers`` resolves duplicated headers, so the value validated here is
    the same value the application sees later. Values are decoded as latin-1, which
    accepts every byte. As a result, a malformed header cannot raise
    ``UnicodeDecodeError`` and surface as a 500 instead of a clean rejection.
    """
    for key, value in scope.get("headers", []):
        if key.lower() == name:
            return value.decode("latin-1")
    return ""


async def _send_plain_text(send, status: int, body: str) -> None:
    """Emit a complete ``text/plain`` response directly through the ASGI ``send`` callable."""
    await send({
        "type": "http.response.start",
        "status": status,
        "headers": [(b"content-type", b"text/plain")],
    })
    await send({"type": "http.response.body", "body": body.encode()})


class HostValidationMiddleware:
    """Pure-ASGI middleware that rejects HTTP requests whose Host header is not allowed.

    The Host header is checked with ``is_allowed_host_header`` against ``allowed_hosts``.
    Requests to ``HEALTH_ENDPOINTS`` are exempt. Non-HTTP scopes pass straight through.
    Rejected requests receive a 403 with ``INVALID_HOST_MSG`` as the body.
    """

    def __init__(self, app: ASGIApp, allowed_hosts: list[str]):
        self.app = app
        self.allowed_hosts = allowed_hosts

    async def __call__(self, scope, receive, send):
        if scope["type"] != "http" or scope["path"] in HEALTH_ENDPOINTS:
            return await self.app(scope, receive, send)

        host = _get_request_header(scope, b"host")
        if is_allowed_host_header(self.allowed_hosts, host):
            return await self.app(scope, receive, send)

        # %r escapes control characters in the attacker-controlled value.
        _logger.warning("Rejected request with invalid Host header: %r", host)
        await _send_plain_text(send, HTTPStatus.FORBIDDEN, INVALID_HOST_MSG)


class SecurityHeadersMiddleware:
    """Middleware that adds security headers to responses produced by the wrapped app.

    Every ``http.response.start`` message gets ``X-Content-Type-Options: nosniff``.
    ``X-Frame-Options`` is set to the upper-cased ``MLFLOW_SERVER_X_FRAME_OPTIONS``
    value, except in three cases:

    - the value is empty;
    - the value is ``"NONE"``;
    - the request path starts with ``TRACE_RENDERER_ASSET_PATH``.

    A 200 response to an ``OPTIONS`` request on an API endpoint has its status
    rewritten to 204.
    """

    def __init__(self, app: ASGIApp):
        self.app = app
        self.x_frame_options = MLFLOW_SERVER_X_FRAME_OPTIONS.get()

    async def __call__(self, scope, receive, send):
        if scope["type"] != "http":
            return await self.app(scope, receive, send)

        path = scope.get("path", "")
        frame_options = (self.x_frame_options or "").upper()
        # Skip X-Frame-Options for notebook renderer to allow iframe embedding in notebooks
        set_frame_options = frame_options not in ("", "NONE") and not path.startswith(
            TRACE_RENDERER_ASSET_PATH
        )
        # Header names this middleware owns; any app-provided values are replaced.
        replaced = {b"x-content-type-options"}
        if set_frame_options:
            replaced.add(b"x-frame-options")

        async def send_wrapper(message):
            if message["type"] == "http.response.start":
                # Filter and append rather than round-tripping through a dict: a dict
                # would silently drop repeated headers such as multiple Set-Cookie lines.
                headers = [
                    (key, value)
                    for key, value in message.get("headers", [])
                    if key.lower() not in replaced
                ]
                headers.append((b"x-content-type-options", b"nosniff"))
                if set_frame_options:
                    headers.append((b"x-frame-options", frame_options.encode()))

                if (
                    scope["method"] == "OPTIONS"
                    and message.get("status") == 200
                    and is_api_endpoint(scope["path"])
                ):
                    message["status"] = HTTPStatus.NO_CONTENT

                message["headers"] = headers
            await send(message)

        await self.app(scope, receive, send_wrapper)


class CORSBlockingMiddleware:
    """Middleware that actively rejects disallowed cross-origin requests to API endpoints.

    The decision is delegated to ``should_block_cors_request(origin, method,
    allowed_origins)``. Given the ``method`` argument, this presumably targets
    state-changing requests; see that helper to confirm. Blocked requests receive a
    403 with ``CORS_BLOCKED_MSG`` as the body. Non-HTTP scopes and non-API paths pass
    through untouched.
    """

    def __init__(self, app: ASGIApp, allowed_origins: list[str]):
        self.app = app
        self.allowed_origins = allowed_origins

    async def __call__(self, scope, receive, send):
        if scope["type"] != "http" or not is_api_endpoint(scope["path"]):
            return await self.app(scope, receive, send)

        origin = _get_request_header(scope, b"origin")
        if should_block_cors_request(origin, scope["method"], self.allowed_origins):
            _logger.warning("Blocked cross-origin request from %r", origin)
            await _send_plain_text(send, HTTPStatus.FORBIDDEN, CORS_BLOCKED_MSG)
            return

        await self.app(scope, receive, send)


def get_allowed_hosts() -> list[str]:
    """Get list of allowed hosts from environment or defaults."""
    return get_allowed_hosts_from_env() or get_default_allowed_hosts()


def get_allowed_origins() -> list[str]:
    """Get list of allowed CORS origins from environment, or an empty list."""
    return get_allowed_origins_from_env() or []


def init_fastapi_security(app: FastAPI) -> None:
    """
    Initialize security middleware for FastAPI application.

    Nothing is installed when MLFLOW_SERVER_DISABLE_SECURITY_MIDDLEWARE is "true".
    Otherwise this configures:
    - Host header validation (DNS rebinding protection) via HostValidationMiddleware,
      unless the allowed hosts contain "*"
    - CORS responses via Starlette's CORSMiddleware. When "*" is not an allowed
      origin, CORSBlockingMiddleware also rejects disallowed cross-origin API requests.
    - Security headers via SecurityHeadersMiddleware

    Args:
        app: FastAPI application instance.
    """
    if MLFLOW_SERVER_DISABLE_SECURITY_MIDDLEWARE.get() == "true":
        return

    # NOTE: Starlette's ``add_middleware`` inserts each middleware at the front of the
    # stack, so the one added LAST is outermost. The resulting request order is:
    #
    #     HostValidation -> CORSMiddleware -> (CORSBlocking) -> SecurityHeaders -> app
    #
    # Responses short-circuited by an outer layer (403s, CORS preflights) therefore do
    # not pass through SecurityHeadersMiddleware.
    app.add_middleware(SecurityHeadersMiddleware)

    allowed_origins = get_allowed_origins()

    if allowed_origins and "*" in allowed_origins:
        # Explicit operator opt-in: any origin is accepted, with credentials.
        app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
            expose_headers=["*"],
        )
    else:
        # Use CORSBlockingMiddleware for blocking CORS requests on the server side,
        # and CORSMiddleware for responding to OPTIONS requests.
        app.add_middleware(CORSBlockingMiddleware, allowed_origins=allowed_origins)
        app.add_middleware(
            CORSMiddleware,
            allow_origins=allowed_origins,
            allow_origin_regex="|".join(LOCALHOST_ORIGIN_PATTERNS),
            allow_credentials=True,
            allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"],
            allow_headers=["*"],
            expose_headers=["*"],
        )

    allowed_hosts = get_allowed_hosts()

    if allowed_hosts and "*" not in allowed_hosts:
        app.add_middleware(HostValidationMiddleware, allowed_hosts=allowed_hosts)