# mlflow/server/fastapi_app.py
  1  """
  2  FastAPI application wrapper for MLflow server.
  3  
  4  This module provides a FastAPI application that wraps the existing Flask application
  5  using WSGIMiddleware to maintain 100% API compatibility while enabling future migration
  6  to FastAPI endpoints.
  7  """
  8  
  9  import json
 10  import time
 11  import typing
 12  
 13  import anyio
 14  from fastapi import FastAPI, Request
 15  from fastapi.responses import JSONResponse
 16  from flask import Flask
 17  from starlette.middleware.wsgi import WSGIResponder, build_environ
 18  from starlette.types import Receive, Scope, Send
 19  
 20  from mlflow.exceptions import MlflowException
 21  from mlflow.gateway.constants import MLFLOW_GATEWAY_DURATION_HEADER, MLFLOW_GATEWAY_OVERHEAD_HEADER
 22  from mlflow.gateway.providers.utils import provider_call_duration_ms
 23  from mlflow.server import app as flask_app
 24  from mlflow.server.assistant.api import assistant_router
 25  from mlflow.server.fastapi_security import init_fastapi_security
 26  from mlflow.server.gateway_api import gateway_router
 27  from mlflow.server.job_api import job_api_router
 28  from mlflow.server.otel_api import otel_router
 29  from mlflow.server.workspace_helpers import (
 30      WORKSPACE_HEADER_NAME,
 31      resolve_workspace_for_request_if_enabled,
 32  )
 33  from mlflow.utils.workspace_context import (
 34      clear_server_request_workspace,
 35      set_server_request_workspace,
 36  )
 37  from mlflow.version import VERSION
 38  
 39  
 40  class _EfficientWSGIResponder(WSGIResponder):
 41      """WSGIResponder with O(n) body buffering instead of O(n^2) concatenation.
 42  
 43      Starlette's WSGIMiddleware is deprecated and upstream has declined to fix the
 44      quadratic body buffering (see https://github.com/Kludex/starlette/pull/2450,
 45      closed in favor of deprecating the module entirely).
 46  
 47      Ref: https://github.com/Kludex/starlette/blob/0e88e92b592bfa11fd92e331869a8d49ba34b541/starlette/middleware/wsgi.py#L98-L117
 48      """
 49  
 50      async def __call__(self, receive: Receive, send: Send) -> None:
 51          # >>> Changed from original: use list + join instead of body += chunk
 52          chunks: list[bytes] = []
 53          more_body = True
 54          while more_body:
 55              message = await receive()
 56              if chunk := message.get("body", b""):
 57                  chunks.append(chunk)
 58              more_body = message.get("more_body", False)
 59          body = b"".join(chunks)
 60          del chunks  # Free chunk list before build_environ copies body into BytesIO
 61          # <<< End of change
 62          environ = build_environ(self.scope, body)
 63  
 64          async with anyio.create_task_group() as task_group:
 65              task_group.start_soon(self.sender, send)
 66              async with self.stream_send:
 67                  await anyio.to_thread.run_sync(self.wsgi, environ, self.start_response)
 68          if self.exc_info is not None:
 69              raise self.exc_info[0].with_traceback(self.exc_info[1], self.exc_info[2])
 70  
 71  
 72  class _EfficientWSGIMiddleware:
 73      """Drop-in replacement for starlette's WSGIMiddleware that avoids O(n^2) body buffering."""
 74  
 75      def __init__(self, app: typing.Callable[..., typing.Any]) -> None:
 76          self.app = app
 77  
 78      async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
 79          assert scope["type"] == "http"
 80          responder = _EfficientWSGIResponder(self.app, scope)
 81          await responder(receive, send)
 82  
 83  
 84  def add_fastapi_workspace_middleware(fastapi_app: FastAPI) -> None:
 85      if getattr(fastapi_app.state, "workspace_middleware_added", False):
 86          return
 87  
 88      @fastapi_app.middleware("http")
 89      async def workspace_context_middleware(request: Request, call_next):
 90          try:
 91              workspace = resolve_workspace_for_request_if_enabled(
 92                  request.url.path,
 93                  request.headers.get(WORKSPACE_HEADER_NAME),
 94              )
 95          except MlflowException as e:
 96              return JSONResponse(
 97                  status_code=e.get_http_status_code(),
 98                  content=json.loads(e.serialize_as_json()),
 99              )
100  
101          set_server_request_workspace(workspace.name if workspace else None)
102          try:
103              response = await call_next(request)
104          finally:
105              clear_server_request_workspace()
106          return response
107  
108      fastapi_app.state.workspace_middleware_added = True
109  
110  
def add_gateway_timing_middleware(fastapi_app: FastAPI) -> None:
    """Register middleware that stamps timing headers on ``/gateway/`` responses.

    Idempotent: a flag on ``fastapi_app.state`` prevents double registration.
    """
    if getattr(fastapi_app.state, "gateway_timing_middleware_added", False):
        return

    @fastapi_app.middleware("http")
    async def gateway_timing_middleware(request: Request, call_next):
        # Non-gateway traffic passes through untouched.
        if not request.url.path.startswith("/gateway/"):
            return await call_next(request)

        # Reset the ContextVar so the handler task starts at 0. The handler
        # task inherits a copy of this context (Starlette's call_next uses
        # copy_context), so the reset is visible to send_request inside the
        # handler.
        provider_call_duration_ms.set(0.0)
        started = time.perf_counter()
        response = await call_next(request)
        elapsed_ms = int((time.perf_counter() - started) * 1000)

        # The provider duration is relayed via request.state by
        # _record_gateway_invocation. The ContextVar cannot be read directly
        # here because the handler runs in a separate task and ContextVar
        # mutations don't propagate back to this one.
        provider_ms = int(getattr(request.state, "gateway_provider_duration_ms", 0))

        # For non-streaming responses, elapsed_ms covers the full round-trip.
        # For streaming responses it covers only gateway setup time (until
        # the StreamingResponse object is returned, before the stream body is
        # iterated), i.e. time-to-first-stream rather than total streaming
        # duration.
        response.headers[MLFLOW_GATEWAY_DURATION_HEADER] = str(elapsed_ms)
        if provider_ms > 0:
            overhead = max(0, elapsed_ms - provider_ms)
            response.headers[MLFLOW_GATEWAY_OVERHEAD_HEADER] = str(overhead)
        return response

    fastapi_app.state.gateway_timing_middleware_added = True
145  
146  
def create_fastapi_app(flask_app: Flask = flask_app):
    """Build the FastAPI application that fronts the MLflow Flask app.

    Native FastAPI routers are registered before the catch-all Flask mount
    at "/" so that they take precedence, preserving full compatibility with
    the existing Flask-served APIs.

    Returns:
        FastAPI application instance with the Flask app mounted via a WSGI
        adapter.
    """
    fastapi_app = FastAPI(
        title="MLflow Tracking Server",
        description="MLflow Tracking Server API",
        version=VERSION,
        # TODO: Enable API documentation when we have native FastAPI endpoints.
        # For now, disable docs since we only have Flask routes via WSGI.
        docs_url=None,
        redoc_url=None,
        openapi_url=None,
    )

    # Security middleware must be initialized before any routes are added.
    init_fastapi_security(fastapi_app)

    add_fastapi_workspace_middleware(fastapi_app)
    add_gateway_timing_middleware(fastapi_app)

    # Native routers, in registration order:
    #   - otel_router: OpenTelemetry ingestion endpoints
    #   - job_api_router: job management endpoints
    #   - gateway_router: /gateway/{endpoint_name}/mlflow/invocations routes
    #   - assistant_router: /ajax-api/3.0/mlflow/assistant/* (localhost only)
    for router in (otel_router, job_api_router, gateway_router, assistant_router):
        fastapi_app.include_router(router)

    # Mount the Flask application last, at the root path, so anything not
    # matched by the routers above falls through to the existing Flask APIs.
    fastapi_app.mount("/", _EfficientWSGIMiddleware(flask_app))

    return fastapi_app
192  
193  
# Module-level application instance consumed by ASGI servers
# (e.g. `uvicorn mlflow.server.fastapi_app:app`).
app = create_fastapi_app()