/ mlflow / server / fastapi_security.py
fastapi_security.py
  1  import logging
  2  from http import HTTPStatus
  3  
  4  from fastapi import FastAPI
  5  from fastapi.middleware.cors import CORSMiddleware
  6  from starlette.types import ASGIApp
  7  
  8  from mlflow.environment_variables import (
  9      MLFLOW_SERVER_DISABLE_SECURITY_MIDDLEWARE,
 10      MLFLOW_SERVER_X_FRAME_OPTIONS,
 11  )
 12  from mlflow.server.security_utils import (
 13      CORS_BLOCKED_MSG,
 14      HEALTH_ENDPOINTS,
 15      INVALID_HOST_MSG,
 16      LOCALHOST_ORIGIN_PATTERNS,
 17      get_allowed_hosts_from_env,
 18      get_allowed_origins_from_env,
 19      get_default_allowed_hosts,
 20      is_allowed_host_header,
 21      is_api_endpoint,
 22      should_block_cors_request,
 23  )
 24  from mlflow.tracing.constant import TRACE_RENDERER_ASSET_PATH
 25  
 26  _logger = logging.getLogger(__name__)
 27  
 28  
 29  class HostValidationMiddleware:
 30      """Middleware to validate Host headers using fnmatch patterns."""
 31  
 32      def __init__(self, app: ASGIApp, allowed_hosts: list[str]):
 33          self.app = app
 34          self.allowed_hosts = allowed_hosts
 35  
 36      async def __call__(self, scope, receive, send):
 37          if scope["type"] != "http":
 38              return await self.app(scope, receive, send)
 39  
 40          if scope["path"] in HEALTH_ENDPOINTS:
 41              return await self.app(scope, receive, send)
 42  
 43          headers = dict(scope.get("headers", []))
 44          host = headers.get(b"host", b"").decode("utf-8")
 45  
 46          if not is_allowed_host_header(self.allowed_hosts, host):
 47              _logger.warning(f"Rejected request with invalid Host header: {host}")
 48  
 49              async def send_403(message):
 50                  if message["type"] == "http.response.start":
 51                      message["status"] = 403
 52                      message["headers"] = [(b"content-type", b"text/plain")]
 53                  await send(message)
 54  
 55              await send_403({"type": "http.response.start", "status": 403, "headers": []})
 56              await send({"type": "http.response.body", "body": INVALID_HOST_MSG.encode()})
 57              return
 58  
 59          return await self.app(scope, receive, send)
 60  
 61  
 62  class SecurityHeadersMiddleware:
 63      """Middleware to add security headers to all responses."""
 64  
 65      def __init__(self, app: ASGIApp):
 66          self.app = app
 67          self.x_frame_options = MLFLOW_SERVER_X_FRAME_OPTIONS.get()
 68  
 69      async def __call__(self, scope, receive, send):
 70          if scope["type"] != "http":
 71              return await self.app(scope, receive, send)
 72  
 73          async def send_wrapper(message):
 74              if message["type"] == "http.response.start":
 75                  headers = dict(message.get("headers", []))
 76                  headers[b"x-content-type-options"] = b"nosniff"
 77  
 78                  # Skip X-Frame-Options for notebook renderer to allow iframe embedding in notebooks
 79                  path = scope.get("path", "")
 80                  is_notebook_renderer = path.startswith(TRACE_RENDERER_ASSET_PATH)
 81  
 82                  if (
 83                      self.x_frame_options
 84                      and self.x_frame_options.upper() != "NONE"
 85                      and not is_notebook_renderer
 86                  ):
 87                      headers[b"x-frame-options"] = self.x_frame_options.upper().encode()
 88  
 89                  if (
 90                      scope["method"] == "OPTIONS"
 91                      and message.get("status") == 200
 92                      and is_api_endpoint(scope["path"])
 93                  ):
 94                      message["status"] = HTTPStatus.NO_CONTENT
 95  
 96                  message["headers"] = list(headers.items())
 97              await send(message)
 98  
 99          await self.app(scope, receive, send_wrapper)
100  
101  
class CORSBlockingMiddleware:
    """Middleware to actively block cross-origin state-changing requests.

    Only HTTP requests to API endpoints (per ``is_api_endpoint``) are inspected.
    Other scopes and paths are passed through. Whether a request is blocked is
    decided by ``should_block_cors_request`` from the request's ``Origin`` header,
    its method, and the configured allowed origins. Blocked requests receive
    ``403 Forbidden`` and never reach the wrapped app.
    """

    def __init__(self, app: ASGIApp, allowed_origins: list[str]):
        """
        Args:
            app: The downstream ASGI application.
            allowed_origins: Explicitly allowed CORS origins.
        """
        self.app = app
        self.allowed_origins = allowed_origins

    async def __call__(self, scope, receive, send):
        # Short-circuit keeps the original order: ``path`` is only read for HTTP scopes.
        if scope["type"] != "http" or not is_api_endpoint(scope["path"]):
            return await self.app(scope, receive, send)

        method = scope["method"]
        headers = dict(scope.get("headers", []))
        # A malformed (non-UTF-8) Origin must not crash the request with a 500.
        origin = headers.get(b"origin", b"").decode("utf-8", errors="replace")

        if should_block_cors_request(origin, method, self.allowed_origins):
            _logger.warning("Blocked cross-origin request from %s", origin)
            await send({
                "type": "http.response.start",
                "status": HTTPStatus.FORBIDDEN,
                "headers": [(b"content-type", b"text/plain")],
            })
            await send({"type": "http.response.body", "body": CORS_BLOCKED_MSG.encode()})
            return

        await self.app(scope, receive, send)
134  
135  
def get_allowed_hosts() -> list[str]:
    """Return the allowed Host patterns.

    The environment-configured list wins when it is non-empty; otherwise the
    built-in defaults are returned.
    """
    env_hosts = get_allowed_hosts_from_env()
    if env_hosts:
        return env_hosts
    return get_default_allowed_hosts()
139  
140  
def get_allowed_origins() -> list[str]:
    """Return the CORS origins configured via the environment, or ``[]`` when none are set."""
    env_origins = get_allowed_origins_from_env()
    return env_origins if env_origins else []
144  
145  
def init_fastapi_security(app: FastAPI) -> None:
    """
    Initialize security middleware for FastAPI application.

    This configures:
    - Security headers via ``SecurityHeadersMiddleware``.
    - CORS handling via ``CORSMiddleware``. If ``"*"`` is an allowed origin, every
      origin is permitted. Otherwise, explicitly allowed origins plus the localhost
      patterns are permitted, and ``CORSBlockingMiddleware`` also rejects other
      cross-origin requests on the server side.
    - Host header validation (DNS rebinding protection) via
      ``HostValidationMiddleware``. This is skipped when the allowed-host list is
      empty or contains ``"*"``.

    Nothing is installed when ``MLFLOW_SERVER_DISABLE_SECURITY_MIDDLEWARE`` is
    ``"true"``.

    Starlette's ``add_middleware`` wraps the middleware added before it, so the
    middleware added last is the outermost one and runs first on each request.

    Args:
        app: FastAPI application instance.
    """
    if MLFLOW_SERVER_DISABLE_SECURITY_MIDDLEWARE.get() == "true":
        return

    app.add_middleware(SecurityHeadersMiddleware)

    allowed_origins = get_allowed_origins()

    if "*" in allowed_origins:
        # Wildcard origin: allow everything and skip server-side CORS blocking.
        app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
            expose_headers=["*"],
        )
    else:
        # Use CORSBlockingMiddleware for blocking CORS requests on the server side,
        # and CORSMiddleware for responding to OPTIONS requests.
        app.add_middleware(CORSBlockingMiddleware, allowed_origins=allowed_origins)
        app.add_middleware(
            CORSMiddleware,
            allow_origins=allowed_origins,
            allow_origin_regex="|".join(LOCALHOST_ORIGIN_PATTERNS),
            allow_credentials=True,
            allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"],
            allow_headers=["*"],
            expose_headers=["*"],
        )

    allowed_hosts = get_allowed_hosts()

    # Added last so Host validation is the outermost check; "*" disables it.
    if allowed_hosts and "*" not in allowed_hosts:
        app.add_middleware(HostValidationMiddleware, allowed_hosts=allowed_hosts)