# fastapi_app.py
"""
FastAPI application wrapper for MLflow server.

This module provides a FastAPI application that wraps the existing Flask application
using WSGIMiddleware to maintain 100% API compatibility while enabling future migration
to FastAPI endpoints.
"""

import json
import time
import typing

import anyio
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from flask import Flask
from starlette.middleware.wsgi import WSGIResponder, build_environ
from starlette.types import Receive, Scope, Send

from mlflow.exceptions import MlflowException
from mlflow.gateway.constants import MLFLOW_GATEWAY_DURATION_HEADER, MLFLOW_GATEWAY_OVERHEAD_HEADER
from mlflow.gateway.providers.utils import provider_call_duration_ms
from mlflow.server import app as flask_app
from mlflow.server.assistant.api import assistant_router
from mlflow.server.fastapi_security import init_fastapi_security
from mlflow.server.gateway_api import gateway_router
from mlflow.server.job_api import job_api_router
from mlflow.server.otel_api import otel_router
from mlflow.server.workspace_helpers import (
    WORKSPACE_HEADER_NAME,
    resolve_workspace_for_request_if_enabled,
)
from mlflow.utils.workspace_context import (
    clear_server_request_workspace,
    set_server_request_workspace,
)
from mlflow.version import VERSION


class _EfficientWSGIResponder(WSGIResponder):
    """WSGIResponder with O(n) body buffering instead of O(n^2) concatenation.

    Starlette's WSGIMiddleware is deprecated and upstream has declined to fix the
    quadratic body buffering (see https://github.com/Kludex/starlette/pull/2450,
    closed in favor of deprecating the module entirely).

    Ref: https://github.com/Kludex/starlette/blob/0e88e92b592bfa11fd92e331869a8d49ba34b541/starlette/middleware/wsgi.py#L98-L117
    """

    async def __call__(self, receive: Receive, send: Send) -> None:
        # Drain the ASGI request body into a single bytes object, then hand it
        # to the WSGI app (run in a worker thread so it doesn't block the loop).
        # >>> Changed from original: use list + join instead of body += chunk,
        # turning O(n^2) repeated bytes concatenation into a single O(n) join.
        chunks: list[bytes] = []
        more_body = True
        while more_body:
            message = await receive()
            # Skip empty chunks so the final join only touches real data.
            if chunk := message.get("body", b""):
                chunks.append(chunk)
            more_body = message.get("more_body", False)
        body = b"".join(chunks)
        del chunks  # Free chunk list before build_environ copies body into BytesIO
        # <<< End of change
        environ = build_environ(self.scope, body)

        # Run the WSGI app in a thread while self.sender forwards its output
        # to the ASGI `send` callable concurrently (both inherited from
        # starlette's WSGIResponder).
        async with anyio.create_task_group() as task_group:
            task_group.start_soon(self.sender, send)
            async with self.stream_send:
                await anyio.to_thread.run_sync(self.wsgi, environ, self.start_response)
        # Re-raise any exception the WSGI app reported via start_response.
        if self.exc_info is not None:
            raise self.exc_info[0].with_traceback(self.exc_info[1], self.exc_info[2])


class _EfficientWSGIMiddleware:
    """Drop-in replacement for starlette's WSGIMiddleware that avoids O(n^2) body buffering."""

    def __init__(self, app: typing.Callable[..., typing.Any]) -> None:
        # `app` is a WSGI callable (here, the Flask application).
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        # Only plain HTTP requests are supported (no websocket/lifespan scopes).
        assert scope["type"] == "http"
        responder = _EfficientWSGIResponder(self.app, scope)
        await responder(receive, send)


def add_fastapi_workspace_middleware(fastapi_app: FastAPI) -> None:
    """Register middleware that binds the request's workspace to a server-side context.

    The workspace is resolved from the request path and the workspace header;
    resolution errors (``MlflowException``) are returned to the client as JSON
    with the exception's HTTP status code. Idempotent: a flag on
    ``fastapi_app.state`` prevents double registration.
    """
    if getattr(fastapi_app.state, "workspace_middleware_added", False):
        return

    @fastapi_app.middleware("http")
    async def workspace_context_middleware(request: Request, call_next):
        try:
            workspace = resolve_workspace_for_request_if_enabled(
                request.url.path,
                request.headers.get(WORKSPACE_HEADER_NAME),
            )
        except MlflowException as e:
            # Surface workspace-resolution failures as structured JSON errors.
            return JSONResponse(
                status_code=e.get_http_status_code(),
                content=json.loads(e.serialize_as_json()),
            )

        # `workspace` may be None (feature disabled / no workspace applies).
        set_server_request_workspace(workspace.name if workspace else None)
        try:
            response = await call_next(request)
        finally:
            # Always clear, even if the handler raised, so the context never
            # leaks into an unrelated request.
            clear_server_request_workspace()
        return response

    fastapi_app.state.workspace_middleware_added = True


def add_gateway_timing_middleware(fastapi_app: FastAPI) -> None:
    """Register middleware that adds timing headers to ``/gateway/`` responses.

    Sets ``MLFLOW_GATEWAY_DURATION_HEADER`` to the total handling time in ms,
    and ``MLFLOW_GATEWAY_OVERHEAD_HEADER`` to (total - provider call time) when
    a provider duration was recorded. Idempotent via a flag on
    ``fastapi_app.state``.
    """
    if getattr(fastapi_app.state, "gateway_timing_middleware_added", False):
        return

    @fastapi_app.middleware("http")
    async def gateway_timing_middleware(request: Request, call_next):
        # Non-gateway traffic passes through untouched.
        if not request.url.path.startswith("/gateway/"):
            return await call_next(request)

        # Reset the ContextVar so the handler task starts at 0. The handler task
        # inherits a copy of this context (Starlette's call_next uses copy_context),
        # so the reset is visible to send_request inside the handler.
        provider_call_duration_ms.set(0.0)
        start = time.perf_counter()
        response = await call_next(request)
        duration_ms = int((time.perf_counter() - start) * 1000)
        # Read provider duration relayed via request.state by _record_gateway_invocation.
        # We can't read the ContextVar directly here because the handler runs in a
        # separate task and ContextVar mutations don't propagate back.
        provider_duration_ms = int(getattr(request.state, "gateway_provider_duration_ms", 0))

        # For non-streaming responses, duration_ms covers the full round-trip.
        # For streaming responses, duration_ms covers only gateway setup time
        # (until the StreamingResponse object is returned, before the stream body
        # is iterated), so it reflects time-to-first-stream rather than total
        # streaming duration.
        response.headers[MLFLOW_GATEWAY_DURATION_HEADER] = str(duration_ms)
        if provider_duration_ms > 0:
            # Overhead = gateway time minus provider time, clamped at 0 in case
            # of clock/rounding skew between the two measurements.
            response.headers[MLFLOW_GATEWAY_OVERHEAD_HEADER] = str(
                max(0, duration_ms - provider_duration_ms)
            )
        return response

    fastapi_app.state.gateway_timing_middleware_added = True


def create_fastapi_app(flask_app: Flask = flask_app) -> FastAPI:
    """
    Create a FastAPI application that wraps the existing Flask app.

    Args:
        flask_app: The WSGI (Flask) application to mount at the root path.
            Defaults to the module-level MLflow server app.

    Returns:
        FastAPI application instance with the Flask app mounted via WSGIMiddleware.
    """
    # Create FastAPI app with metadata
    fastapi_app = FastAPI(
        title="MLflow Tracking Server",
        description="MLflow Tracking Server API",
        version=VERSION,
        # TODO: Enable API documentation when we have native FastAPI endpoints
        # For now, disable docs since we only have Flask routes via WSGI
        docs_url=None,
        redoc_url=None,
        openapi_url=None,
    )

    # Initialize security middleware BEFORE adding routes
    init_fastapi_security(fastapi_app)

    add_fastapi_workspace_middleware(fastapi_app)
    add_gateway_timing_middleware(fastapi_app)

    # Include OpenTelemetry API router BEFORE mounting Flask app
    # This ensures FastAPI routes take precedence over the catch-all Flask mount
    fastapi_app.include_router(otel_router)

    fastapi_app.include_router(job_api_router)

    # Include Gateway API router for database-backed endpoints
    # This provides /gateway/{endpoint_name}/mlflow/invocations routes
    fastapi_app.include_router(gateway_router)

    # Include Assistant API router for AI-powered trace analysis
    # This provides /ajax-api/3.0/mlflow/assistant/* endpoints (localhost only)
    fastapi_app.include_router(assistant_router)

    # Mount the entire Flask application at the root path
    # This ensures compatibility with existing APIs
    # NOTE: This must come AFTER include_router to avoid Flask catching all requests
    fastapi_app.mount("/", _EfficientWSGIMiddleware(flask_app))

    return fastapi_app


# Create the app instance that can be used by ASGI servers
app = create_fastapi_app()