   1  """
   2  Database-backed Gateway API endpoints for MLflow Server.
   3  
   4  This module provides dynamic gateway endpoints that are configured from the database
   5  rather than from a static YAML configuration file. It integrates the AI Gateway
   6  functionality directly into the MLflow tracking server.
   7  """
   8  
   9  import functools
  10  import logging
  11  import sys
  12  import time
  13  from collections.abc import Callable
  14  from typing import Any
  15  
  16  from fastapi import APIRouter, HTTPException, Request
  17  from fastapi.responses import StreamingResponse
  18  
  19  from mlflow.entities.gateway_endpoint import GatewayModelLinkageType
  20  from mlflow.exceptions import MlflowException
  21  from mlflow.gateway.budget import check_budget_limit, make_budget_on_complete
  22  from mlflow.gateway.config import (
  23      AmazonBedrockConfig,
  24      AnthropicConfig,
  25      EndpointConfig,
  26      EndpointType,
  27      GatewayRequestType,
  28      GeminiConfig,
  29      LiteLLMConfig,
  30      MistralConfig,
  31      OpenAIAPIType,
  32      OpenAIConfig,
  33      Provider,
  34      VertexAIConfig,
  35      _AuthConfigKey,
  36      _OpenAICompatibleConfig,
  37  )
  38  from mlflow.gateway.constants import MLFLOW_GATEWAY_CALLER_HEADER, GatewayCaller
  39  from mlflow.gateway.guardrail_utils import (
  40      extract_auth_headers,
  41      load_guardrails,
  42      run_post_llm_guardrails,
  43      run_pre_llm_guardrails,
  44  )
  45  from mlflow.gateway.guardrails import (
  46      _SANITIZE_BYPASS_HEADER,
  47      GuardrailViolation,
  48      JudgeGuardrail,
  49  )
  50  from mlflow.gateway.providers import get_provider
  51  from mlflow.gateway.providers.base import (
  52      PASSTHROUGH_ROUTES,
  53      BaseProvider,
  54      FallbackProvider,
  55      PassthroughAction,
  56      TrafficRouteProvider,
  57  )
  58  from mlflow.gateway.providers.utils import provider_call_duration_ms
  59  from mlflow.gateway.schemas import chat, embeddings
  60  from mlflow.gateway.tracing_utils import (
  61      aggregate_anthropic_messages_stream_chunks,
  62      aggregate_chat_stream_chunks,
  63      aggregate_gemini_stream_generate_content_chunks,
  64      aggregate_openai_responses_stream_chunks,
  65      maybe_traced_gateway_call,
  66  )
  67  from mlflow.gateway.utils import safe_stream, to_sse_chunk, translate_http_exception
  68  from mlflow.protos.databricks_pb2 import RESOURCE_DOES_NOT_EXIST
  69  from mlflow.store.tracking.abstract_store import AbstractStore
  70  from mlflow.store.tracking.gateway.config_resolver import get_endpoint_config
  71  from mlflow.store.tracking.gateway.entities import (
  72      GatewayEndpointConfig,
  73      GatewayModelConfig,
  74      RoutingStrategy,
  75  )
  76  from mlflow.store.tracking.sqlalchemy_store import SqlAlchemyStore
  77  from mlflow.telemetry.events import GatewayInvocationEvent, GatewayInvocationType
  78  from mlflow.telemetry.track import _record_event
  79  from mlflow.tracing.constant import TraceMetadataKey
  80  from mlflow.tracking._tracking_service.utils import _get_store
  81  from mlflow.utils.provider_filter import is_provider_allowed, normalize_provider_name
  82  from mlflow.utils.workspace_context import get_request_workspace
  83  
  84  _logger = logging.getLogger(__name__)
  85  
  86  gateway_router = APIRouter(prefix="/gateway", tags=["gateway"])
  87  
  88  
  89  async def _get_request_body(request: Request) -> dict[str, Any]:
  90      """
  91      Get request body, using cached version if available.
  92  
  93      The auth middleware may have already parsed the request body for permission
  94      validation. Since Starlette request body can only be read once, we cache
  95      the parsed body in request.state.cached_body for reuse by route handlers.
  96  
  97      Args:
  98          request: The FastAPI Request object.
  99  
 100      Returns:
 101          Parsed JSON body as a dictionary.
 102  
 103      Raises:
 104          HTTPException: If the request body is not valid JSON.
 105      """
 106      # Check if body was already parsed by auth middleware
 107      cached_body = getattr(request.state, "cached_body", None)
 108      if isinstance(cached_body, dict):
 109          return cached_body
 110  
 111      # Otherwise parse it now
 112      try:
 113          return await request.json()
 114      except Exception as e:
 115          raise HTTPException(status_code=400, detail=f"Invalid JSON payload: {e!s}")
 116  
 117  
 118  def _get_user_metadata(request: Request) -> dict[str, Any]:
 119      """
 120      Extract user metadata from request state for tracing.
 121  
 122      The auth middleware stores the authenticated user's info in request.state.
 123  
 124      Args:
 125          request: The FastAPI Request object.
 126  
 127      Returns:
 128          Dictionary with user metadata (username and user_id if available).
 129      """
 130      metadata = {}
 131      if username := getattr(request.state, "username", None):
 132          metadata[TraceMetadataKey.AUTH_USERNAME] = username
 133      if user_id := getattr(request.state, "user_id", None):
 134          metadata[TraceMetadataKey.AUTH_USER_ID] = str(user_id)
 135      return metadata
 136  
 137  
 138  def _record_gateway_invocation(invocation_type: GatewayInvocationType) -> Callable[..., Any]:
 139      """
 140      Decorator for gateway invocation endpoints that records telemetry:
 141      success/failure status, duration, streaming mode, and caller.
 142  
 143      As a side effect, relays provider call duration to the gateway timing middleware by
 144      writing `request.state.gateway_provider_duration_ms`. This is required because
 145      Starlette's call_next() copies the ContextVar context for the handler task, so
 146      mutations to provider_call_duration_ms don't propagate back to the middleware.
 147  
 148      Timing headers (X-MLflow-Gateway-Duration-Ms, X-MLflow-Gateway-Overhead-Duration-Ms)
 149      are injected by gateway_timing_middleware in fastapi_app.py.
 150  
 151      Args:
 152          invocation_type: The type of invocation endpoint.
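
    Example (usage, as applied to the invocation handlers below):
        @_record_gateway_invocation(GatewayInvocationType.MLFLOW_INVOCATIONS)
        async def invocations(endpoint_name: str, request: Request): ...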
    """

    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            start_time = time.perf_counter()
            success = True
            result = None

            # Extract caller header from the Request object if present,
            # only accepting known caller values to avoid logging arbitrary input.
            caller = None
            request = next((a for a in (*args, *kwargs.values()) if isinstance(a, Request)), None)
            if request is not None:
                raw_caller = request.headers.get(MLFLOW_GATEWAY_CALLER_HEADER)
                if raw_caller in {e.value for e in GatewayCaller}:
                    caller = raw_caller

            try:
                result = await func(*args, **kwargs)
            except Exception:
                success = False
                raise
            finally:
                duration_ms = int((time.perf_counter() - start_time) * 1000)
                provider_duration = int(provider_call_duration_ms.get())
                is_streaming = isinstance(result, StreamingResponse)
                params = {
                    "is_streaming": is_streaming,
                    "invocation_type": invocation_type,
                }
                # provider_call_duration_ms is only updated by send_request()
                # (non-streaming); send_stream_request() never sets it, so
                # timing fields would always be 0 for streaming responses.
                if not is_streaming:
                    params["provider_duration_ms"] = provider_duration
                    params["gateway_overhead_ms"] = max(0, duration_ms - provider_duration)
                if caller:
                    params["caller"] = caller
                if request is not None:
                    params["has_traceparent"] = request.headers.get("traceparent") is not None
                    auth_mod = sys.modules.get("mlflow.server.auth")
                    params["auth_enabled"] = auth_mod.is_auth_enabled() if auth_mod else False
                    if endpoint_id := getattr(request.state, "endpoint_id", None):
                        params["endpoint_id"] = endpoint_id
                    # Prefer the actual provider from the response (set by
                    # BaseProvider after the call) over the endpoint config
                    # estimate, which may not reflect traffic-split/fallback.
                    actual_provider = getattr(result, "provider", None)
                    if provider := (actual_provider or getattr(request.state, "provider", None)):
                        params["provider"] = provider
                _record_event(
                    GatewayInvocationEvent,
                    params=params,
                    success=success,
                    duration_ms=duration_ms,
                )
                # Relay provider timing to the middleware via request.state.
                # ContextVar values set in the handler task don't propagate back
                # to the middleware task (Starlette copies the context for call_next).
                if request is not None:
                    request.state.gateway_provider_duration_ms = int(
                        provider_call_duration_ms.get()
                    )

            return result

        return wrapper

    return decorator

def _set_gateway_telemetry_state(
    request: Request, endpoint_config: GatewayEndpointConfig
) -> None:
    """Set endpoint_id and provider on request.state for telemetry attribution."""
    request.state.endpoint_id = endpoint_config.endpoint_id
    if endpoint_config.models:
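        # Prefer the model linked as PRIMARY; fall back to the first configured model.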
        primary_model = next(
            (
                m
                for m in endpoint_config.models
                if m.linkage_type == GatewayModelLinkageType.PRIMARY
            ),
            endpoint_config.models[0],
        )
        request.state.provider = str(primary_model.provider)

def _build_openai_compatible_config(
    model_config: GatewayModelConfig,
) -> _OpenAICompatibleConfig:
    """Build an _OpenAICompatibleConfig for providers that use the OpenAI API format."""
    auth_config = model_config.auth_config or {}
    return _OpenAICompatibleConfig(
        api_key=model_config.secret_value.get(_AuthConfigKey.API_KEY),
        api_base=auth_config.get(_AuthConfigKey.API_BASE),
    )


def _build_endpoint_config(
    endpoint_name: str,
    model_config: GatewayModelConfig,
    endpoint_type: EndpointType,
) -> EndpointConfig:
    """
    Build an EndpointConfig from model configuration.

    This function builds the provider-specific config and wraps it in an
    EndpointConfig in a single operation.

    Args:
        endpoint_name: The endpoint name.
        model_config: The model configuration object with decrypted secrets.
        endpoint_type: Endpoint type (chat or embeddings).

    Returns:
        EndpointConfig instance ready for provider instantiation.

    Raises:
        MlflowException: If provider configuration is invalid.
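
    Example (illustrative):
        A GatewayModelConfig with provider=Provider.ANTHROPIC and an api_key in
        secret_value yields an EndpointConfig whose model config is an
        AnthropicConfig built from that key.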
    """
    provider_name = model_config.provider
    if not is_provider_allowed(provider_name):
        _logger.debug(
            "Provider '%s' blocked by MLFLOW_GATEWAY_ALLOWED_PROVIDERS",
            provider_name,
        )
        raise MlflowException.invalid_parameter_value(
            f"Provider '{provider_name}' is not allowed by the current gateway provider policy."
        )

    provider_config = None

    if model_config.provider == Provider.OPENAI:
        auth_config = model_config.auth_config or {}
        openai_config = {
            "openai_api_key": model_config.secret_value.get(_AuthConfigKey.API_KEY),
        }

        # Check if this is Azure OpenAI (requires api_type, deployment_name, api_base, api_version)
        if "api_type" in auth_config and auth_config["api_type"] in ("azure", "azuread"):
            openai_config["openai_api_type"] = auth_config["api_type"]
            openai_config["openai_api_base"] = auth_config.get(_AuthConfigKey.API_BASE)
            openai_config["openai_deployment_name"] = auth_config.get("deployment_name")
            openai_config["openai_api_version"] = auth_config.get("api_version")
        else:
            # Standard OpenAI
            if _AuthConfigKey.API_BASE in auth_config:
                openai_config["openai_api_base"] = auth_config[_AuthConfigKey.API_BASE]
            if "organization" in auth_config:
                openai_config["openai_organization"] = auth_config["organization"]

        provider_config = OpenAIConfig(**openai_config)
    elif model_config.provider == Provider.AZURE:
        auth_config = model_config.auth_config or {}
        model_config.provider = Provider.OPENAI
        provider_config = OpenAIConfig(
            openai_api_type=OpenAIAPIType.AZURE,
            openai_api_key=model_config.secret_value.get(_AuthConfigKey.API_KEY),
            openai_api_base=auth_config.get(_AuthConfigKey.API_BASE),
            openai_deployment_name=model_config.model_name,
            openai_api_version=auth_config.get("api_version"),
        )
    elif model_config.provider == Provider.ANTHROPIC:
        anthropic_config = {
            "anthropic_api_key": model_config.secret_value.get(_AuthConfigKey.API_KEY),
        }
        if model_config.auth_config and "version" in model_config.auth_config:
            anthropic_config["anthropic_version"] = model_config.auth_config["version"]
        provider_config = AnthropicConfig(**anthropic_config)
    elif model_config.provider == Provider.MISTRAL:
        provider_config = MistralConfig(
            mistral_api_key=model_config.secret_value.get(_AuthConfigKey.API_KEY),
        )
    elif model_config.provider == Provider.GEMINI:
        provider_config = GeminiConfig(
            gemini_api_key=model_config.secret_value.get(_AuthConfigKey.API_KEY),
        )
    elif model_config.provider in {
        Provider.GROQ,
        Provider.DEEPSEEK,
        Provider.XAI,
        Provider.OPENROUTER,
        Provider.OLLAMA,
        Provider.PORTKEY,
    }:
        provider_config = _build_openai_compatible_config(model_config)
    elif normalize_provider_name(model_config.provider) == Provider.DATABRICKS:
        from mlflow.gateway.providers.databricks import DatabricksConfig

        auth_config = model_config.auth_config or {}
        auth_mode = auth_config.get(_AuthConfigKey.AUTH_MODE, "pat_token")
        config_kwargs = {}
        if api_base := auth_config.get(_AuthConfigKey.API_BASE):
            config_kwargs["host"] = api_base
        if auth_mode == "oauth_m2m":
            config_kwargs["client_id"] = auth_config.get("client_id")
            config_kwargs["client_secret"] = model_config.secret_value.get("client_secret")
        else:
            config_kwargs["token"] = model_config.secret_value.get(_AuthConfigKey.API_KEY)
        provider_config = DatabricksConfig(**config_kwargs)
        model_config.provider = Provider.DATABRICKS
    elif normalize_provider_name(model_config.provider) == Provider.BEDROCK:
        auth_config = model_config.auth_config or {}
        auth_mode = auth_config.get(_AuthConfigKey.AUTH_MODE, "api_key")
        if auth_mode == "api_key":
            # Bearer token auth; bypasses boto3 SigV4 signing
            provider_config = AmazonBedrockConfig(
                aws_config={
                    "aws_bearer_token": model_config.secret_value.get(_AuthConfigKey.API_KEY),
                    "aws_region": auth_config.get("aws_region_name"),
                }
            )
        elif auth_mode == "access_keys":
            provider_config = AmazonBedrockConfig(
                aws_config={
                    "aws_access_key_id": model_config.secret_value.get("aws_access_key_id"),
                    "aws_secret_access_key": model_config.secret_value.get("aws_secret_access_key"),
                    "aws_region": auth_config.get("aws_region_name"),
                }
            )
        elif auth_mode == "iam_role":
            provider_config = AmazonBedrockConfig(
                aws_config={
                    "aws_role_arn": auth_config.get("aws_role_name"),
                    "aws_region": auth_config.get("aws_region_name"),
                }
            )
        else:
            # default_chain: boto3 resolves credentials from the
            # environment (env vars, ~/.aws/credentials, instance profile, etc.)
            aws_config = {"aws_region": auth_config.get("aws_region_name")}
            if role_arn := auth_config.get("aws_role_name"):
                aws_config["aws_role_arn"] = role_arn
            provider_config = AmazonBedrockConfig(aws_config=aws_config)
        model_config.provider = Provider.BEDROCK
    elif model_config.provider == Provider.VERTEX_AI:
        auth_config = model_config.auth_config or {}
        provider_config = VertexAIConfig(
            vertex_project=auth_config.get("vertex_project"),
            vertex_location=auth_config.get("vertex_location"),
            vertex_credentials=model_config.secret_value.get("vertex_credentials"),
        )
    else:
        # Use LiteLLM as fallback for unsupported providers
        # Store the original provider name for LiteLLM's provider/model format
        original_provider = model_config.provider
        auth_config = model_config.auth_config or {}
        # Merge auth_config with secret_value (secret_value contains api_key and other secrets)
        litellm_config = {
            "litellm_provider": original_provider,
            "litellm_auth_config": auth_config | model_config.secret_value,
        }
        provider_config = LiteLLMConfig(**litellm_config)
        model_config.provider = Provider.LITELLM

    # Build and return EndpointConfig
    return EndpointConfig(
        name=endpoint_name,
        endpoint_type=endpoint_type,
        model={
            "name": model_config.model_name,
            "provider": model_config.provider,
            "config": provider_config.model_dump(),
        },
    )

def _create_provider(
    endpoint_config: GatewayEndpointConfig,
    endpoint_type: EndpointType,
    enable_tracing: bool = False,
) -> BaseProvider:
    """
    Create a provider instance based on the endpoint's routing strategy.

    Fallback is independent of routing strategy: if fallback_config is present,
    the provider is wrapped with FallbackProvider.

    Args:
        endpoint_config: The endpoint configuration with model details and routing config.
        endpoint_type: Endpoint type (chat or embeddings).
        enable_tracing: If True, enables MLflow tracing for provider calls.

    Returns:
        Provider instance (standard provider, TrafficRouteProvider, or FallbackProvider).

    Raises:
        MlflowException: If endpoint configuration is invalid or has no models.
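
    Example (illustrative):
        An endpoint whose routing_strategy is REQUEST_BASED_TRAFFIC_SPLIT with two
        PRIMARY models yields a TrafficRouteProvider; if fallback_config and
        FALLBACK models are also present, that provider is wrapped in a
        FallbackProvider.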
    """
    # Get PRIMARY models
    primary_models = [
        model
        for model in endpoint_config.models
        if model.linkage_type == GatewayModelLinkageType.PRIMARY
    ]

    if not primary_models:
        raise MlflowException(
            f"Endpoint '{endpoint_config.endpoint_name}' has no PRIMARY models configured",
            error_code=RESOURCE_DOES_NOT_EXIST,
        )

    # Create base provider based on routing strategy
    if endpoint_config.routing_strategy == RoutingStrategy.REQUEST_BASED_TRAFFIC_SPLIT:
        # Traffic split: distribute requests based on weights
        configs = []
        weights = []
        for model_config in primary_models:
            gateway_endpoint_config = _build_endpoint_config(
                endpoint_name=endpoint_config.endpoint_name,
                model_config=model_config,
                endpoint_type=endpoint_type,
            )
            configs.append(gateway_endpoint_config)
            weights.append(int(model_config.weight * 100))  # Convert to percentage
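            # e.g., model weights [0.7, 0.3] become traffic splits [70, 30]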

        primary_provider = TrafficRouteProvider(
            configs=configs,
            traffic_splits=weights,
            routing_strategy="TRAFFIC_SPLIT",
            enable_tracing=enable_tracing,
        )
    else:
        # Default: use the first PRIMARY model
        model_config = primary_models[0]
        gateway_endpoint_config = _build_endpoint_config(
            endpoint_config.endpoint_name, model_config, endpoint_type
        )
        provider_class = get_provider(model_config.provider)
        primary_provider = provider_class(gateway_endpoint_config, enable_tracing=enable_tracing)

    # Wrap with FallbackProvider if fallback configuration exists
    if endpoint_config.fallback_config:
        fallback_models = [
            model
            for model in endpoint_config.models
            if model.linkage_type == GatewayModelLinkageType.FALLBACK
        ]

        if not fallback_models:
            _logger.debug(
                f"Endpoint '{endpoint_config.endpoint_name}' has fallback_config "
                "but no FALLBACK models configured"
            )
            return primary_provider

        # Sort fallback models by fallback_order
        fallback_models.sort(
            key=lambda m: m.fallback_order if m.fallback_order is not None else float("inf")
        )

        fallback_providers = [
            get_provider(model_config.provider)(
                _build_endpoint_config(
                    endpoint_name=endpoint_config.endpoint_name,
                    model_config=model_config,
                    endpoint_type=endpoint_type,
                ),
                enable_tracing=enable_tracing,
            )
            for model_config in fallback_models
        ]

        max_attempts = endpoint_config.fallback_config.max_attempts or len(fallback_models)

        # FallbackProvider expects all providers (primary + fallback)
        all_providers = [primary_provider] + fallback_providers

        return FallbackProvider(
            providers=all_providers,
            max_attempts=max_attempts + 1,  # +1 to include primary
            strategy=endpoint_config.fallback_config.strategy,
            enable_tracing=enable_tracing,
        )

    return primary_provider


def _create_provider_from_endpoint_name(
    store: SqlAlchemyStore,
    endpoint_name: str,
    endpoint_type: EndpointType,
    enable_tracing: bool = True,
) -> tuple[BaseProvider, GatewayEndpointConfig]:
    """
    Create a provider from an endpoint name.

    Args:
        store: The SQLAlchemy store instance.
        endpoint_name: The endpoint name.
        endpoint_type: Endpoint type (chat or embeddings).
        enable_tracing: If True, enables MLflow tracing for provider calls.

    Returns:
        Tuple of (provider instance, endpoint config)
    """
    endpoint_config = get_endpoint_config(endpoint_name=endpoint_name, store=store)
    return _create_provider(
        endpoint_config, endpoint_type, enable_tracing=enable_tracing
    ), endpoint_config


def _validate_store(store: AbstractStore) -> None:
    """Raise if the store is not a SqlAlchemyStore; gateway config lives in the database."""
    if not isinstance(store, SqlAlchemyStore):
        raise HTTPException(
            status_code=500,
            detail="Gateway endpoints are only available with SqlAlchemyStore, "
            f"got {type(store).__name__}.",
        )


def _extract_endpoint_name_from_model(body: dict[str, Any]) -> str:
    """
    Extract and validate the endpoint name from the 'model' parameter in the request body.

    Args:
        body: The request body dictionary

    Returns:
        The endpoint name extracted from the 'model' parameter

    Raises:
        HTTPException: If the 'model' parameter is missing
    """
    endpoint_name = body.get("model")
    if not endpoint_name:
        raise HTTPException(
            status_code=400,
            detail="Missing required 'model' parameter in request body",
        )
    return endpoint_name


def _get_guardrails_and_auth(
    store: SqlAlchemyStore, endpoint_config: GatewayEndpointConfig, request: Request
) -> tuple[list[JudgeGuardrail], dict[str, str]]:
    """Load guardrails and extract auth headers, skipping guardrails for internal bypass calls."""
    headers = dict(request.headers)
    bypass = headers.get(_SANITIZE_BYPASS_HEADER) == "1"
    guardrails = [] if bypass else load_guardrails(store, endpoint_config, request)
    return guardrails, extract_auth_headers(headers)


@gateway_router.post("/{endpoint_name}/mlflow/invocations", response_model=None)
@translate_http_exception
@_record_gateway_invocation(GatewayInvocationType.MLFLOW_INVOCATIONS)
async def invocations(endpoint_name: str, request: Request):
    """
    Unified invocations endpoint handler that supports both chat and embeddings.

    The handler automatically detects the request type based on the payload structure:
    - If payload has "messages" field -> chat endpoint
    - If payload has "input" field -> embeddings endpoint
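
    Example (chat, with a hypothetical endpoint name):
        POST /gateway/my-endpoint-name/mlflow/invocations
        {
            "messages": [{"role": "user", "content": "Hello"}]
        }

    Example (embeddings, with a hypothetical endpoint name):
        POST /gateway/my-endpoint-name/mlflow/invocations
        {
            "input": "The food was delicious and the waiter..."
        }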
    """
    body = await _get_request_body(request)
    user_metadata = _get_user_metadata(request)
    headers = dict(request.headers)

    store = _get_store()
    workspace = get_request_workspace()

    _validate_store(store)
    endpoint_config = get_endpoint_config(endpoint_name=endpoint_name, store=store)
    _set_gateway_telemetry_state(request, endpoint_config)
    check_budget_limit(store, endpoint_config, workspace=workspace)
    guardrails, auth_headers = _get_guardrails_and_auth(store, endpoint_config, request)

    # Detect request type based on payload structure
    if "messages" in body:
        # Chat request
        endpoint_type = EndpointType.LLM_V1_CHAT
        try:
            payload = chat.RequestPayload(**body)
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Invalid chat payload: {e!s}")

        provider, endpoint_config = _create_provider_from_endpoint_name(
            store, endpoint_name, endpoint_type
        )

        if payload.stream:
            # Post-LLM guardrails are not applied to streaming responses.
            # Pre-LLM guardrails run inside the trace as child spans; violations
            # are surfaced as SSE error chunks via safe_stream.
            async def _guarded_stream(
                payload: chat.RequestPayload,
            ):
                request_dict = await run_pre_llm_guardrails(
                    guardrails,
                    payload.model_dump(),
                    auth_headers=auth_headers,
                    usage_tracking=endpoint_config.usage_tracking,
                )
                async for chunk in provider.chat_stream(chat.RequestPayload(**request_dict)):
                    yield chunk

            stream = maybe_traced_gateway_call(
                _guarded_stream,
                endpoint_config,
                user_metadata,
                output_reducer=aggregate_chat_stream_chunks,
                request_headers=headers,
                request_type=GatewayRequestType.UNIFIED_CHAT,
                on_complete=make_budget_on_complete(store, workspace),
            )(payload)
            return StreamingResponse(
                safe_stream(to_sse_chunk(chunk.model_dump_json()) async for chunk in stream),
                media_type="text/event-stream",
            )
        else:

            async def _guarded_chat(
                payload: chat.RequestPayload,
            ) -> chat.ResponsePayload:
                request_dict = await run_pre_llm_guardrails(
                    guardrails,
                    payload.model_dump(),
                    auth_headers=auth_headers,
                    usage_tracking=endpoint_config.usage_tracking,
                )
                modified_payload = chat.RequestPayload(**request_dict)
                response = await provider.chat(modified_payload)
                return await run_post_llm_guardrails(
                    guardrails,
                    request_dict,
                    response,
                    auth_headers=auth_headers,
                    usage_tracking=endpoint_config.usage_tracking,
                )

            try:
                return await maybe_traced_gateway_call(
                    _guarded_chat,
                    endpoint_config,
                    user_metadata,
                    request_headers=headers,
                    request_type=GatewayRequestType.UNIFIED_CHAT,
                    on_complete=make_budget_on_complete(store, workspace),
                )(payload)
            except GuardrailViolation as e:
                raise HTTPException(status_code=400, detail=str(e))

    elif "input" in body:
        # Embeddings request
        endpoint_type = EndpointType.LLM_V1_EMBEDDINGS
        try:
            payload = embeddings.RequestPayload(**body)
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Invalid embeddings payload: {e!s}")

        provider, endpoint_config = _create_provider_from_endpoint_name(
            store, endpoint_name, endpoint_type
        )

        return await maybe_traced_gateway_call(
            provider.embeddings,
            endpoint_config,
            user_metadata,
            request_headers=headers,
            request_type=GatewayRequestType.UNIFIED_EMBEDDINGS,
            on_complete=make_budget_on_complete(store, workspace),
        )(payload)

    else:
        raise HTTPException(
            status_code=400,
            detail="Invalid request: payload format must be either chat or embeddings",
        )


@gateway_router.post("/mlflow/v1/chat/completions", response_model=None)
@translate_http_exception
@_record_gateway_invocation(GatewayInvocationType.MLFLOW_CHAT_COMPLETIONS)
async def chat_completions(request: Request):
    """
    OpenAI-compatible chat completions endpoint.

    This endpoint follows the OpenAI API format where the endpoint name is specified
    via the "model" parameter in the request body, allowing clients to use the
    standard OpenAI SDK.

    Example:
        POST /gateway/mlflow/v1/chat/completions
        {
            "model": "my-endpoint-name",
            "messages": [{"role": "user", "content": "Hello"}]
        }
    """
    body = await _get_request_body(request)
    user_metadata = _get_user_metadata(request)
    headers = dict(request.headers)

    # Extract endpoint name from "model" parameter
    endpoint_name = _extract_endpoint_name_from_model(body)
    body.pop("model")

    store = _get_store()
    workspace = get_request_workspace()

    _validate_store(store)
    provider, endpoint_config = _create_provider_from_endpoint_name(
        store, endpoint_name, EndpointType.LLM_V1_CHAT
    )
    _set_gateway_telemetry_state(request, endpoint_config)
    check_budget_limit(store, endpoint_config, workspace=workspace)
    guardrails, auth_headers = _get_guardrails_and_auth(store, endpoint_config, request)

    try:
        payload = chat.RequestPayload(**body)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid chat payload: {e!s}")

    if payload.stream:
        # Post-LLM guardrails are not applied to streaming responses.
        # Pre-LLM guardrails run inside the trace as child spans; violations
        # are surfaced as SSE error chunks via safe_stream.
        async def _guarded_stream(
            payload: chat.RequestPayload,
        ):
            request_dict = await run_pre_llm_guardrails(
                guardrails,
                payload.model_dump(),
                auth_headers=auth_headers,
                usage_tracking=endpoint_config.usage_tracking,
            )
            async for chunk in provider.chat_stream(chat.RequestPayload(**request_dict)):
                yield chunk

        stream = maybe_traced_gateway_call(
            _guarded_stream,
            endpoint_config,
            user_metadata,
            output_reducer=aggregate_chat_stream_chunks,
            request_headers=headers,
            request_type=GatewayRequestType.UNIFIED_CHAT,
            on_complete=make_budget_on_complete(store, workspace),
        )(payload)
        return StreamingResponse(
            safe_stream(to_sse_chunk(chunk.model_dump_json()) async for chunk in stream),
            media_type="text/event-stream",
        )
    else:

        async def _guarded_chat(
            payload: chat.RequestPayload,
        ) -> chat.ResponsePayload:
            request_dict = await run_pre_llm_guardrails(
                guardrails,
                payload.model_dump(),
                auth_headers=auth_headers,
                usage_tracking=endpoint_config.usage_tracking,
            )
            modified_payload = chat.RequestPayload(**request_dict)
            response = await provider.chat(modified_payload)
            return await run_post_llm_guardrails(
                guardrails,
                request_dict,
                response,
                auth_headers=auth_headers,
                usage_tracking=endpoint_config.usage_tracking,
            )

        try:
            return await maybe_traced_gateway_call(
                _guarded_chat,
                endpoint_config,
                user_metadata,
                request_headers=headers,
                request_type=GatewayRequestType.UNIFIED_CHAT,
                on_complete=make_budget_on_complete(store, workspace),
            )(payload)
        except GuardrailViolation as e:
            raise HTTPException(status_code=400, detail=str(e))


@gateway_router.post(PASSTHROUGH_ROUTES[PassthroughAction.OPENAI_CHAT], response_model=None)
@translate_http_exception
@_record_gateway_invocation(GatewayInvocationType.OPENAI_PASSTHROUGH_CHAT)
async def openai_passthrough_chat(request: Request):
    """
    OpenAI passthrough endpoint for chat completions.

    This endpoint accepts raw OpenAI API format and passes it through to the
    OpenAI provider with the configured API key and model. The 'model' parameter
    in the request specifies which MLflow endpoint to use.

    Supports streaming responses when the 'stream' parameter is set to true.

    Example:
        POST /gateway/openai/v1/chat/completions
        {
            "model": "my-openai-endpoint",
            "messages": [{"role": "user", "content": "Hello"}],
            "temperature": 0.7,
            "stream": true
        }
    """
    body = await _get_request_body(request)
    user_metadata = _get_user_metadata(request)

    endpoint_name = _extract_endpoint_name_from_model(body)
    body.pop("model")
    store = _get_store()
    workspace = get_request_workspace()
    _validate_store(store)
    headers = dict(request.headers)
    provider, endpoint_config = _create_provider_from_endpoint_name(
        store, endpoint_name, EndpointType.LLM_V1_CHAT
    )
    _set_gateway_telemetry_state(request, endpoint_config)
    check_budget_limit(store, endpoint_config, workspace=workspace)

    if body.get("stream", False):
        stream = await provider.passthrough(
            action=PassthroughAction.OPENAI_CHAT, payload=body, headers=headers
        )

        # Wrap stream iteration in an async generator so @mlflow.trace properly captures chunks
        async def yield_stream(body: dict[str, Any]):
            async for chunk in stream:
                yield chunk

        traced_stream = maybe_traced_gateway_call(
            yield_stream,
            endpoint_config,
            user_metadata,
            request_headers=headers,
            request_type=GatewayRequestType.PASSTHROUGH_MODEL_OPENAI_CHAT,
            on_complete=make_budget_on_complete(store, workspace),
        )
        return StreamingResponse(
            safe_stream(traced_stream(body), as_bytes=True), media_type="text/event-stream"
        )

    traced_passthrough = maybe_traced_gateway_call(
        provider.passthrough,
        endpoint_config,
        user_metadata,
        request_headers=headers,
        request_type=GatewayRequestType.PASSTHROUGH_MODEL_OPENAI_CHAT,
        on_complete=make_budget_on_complete(store, workspace),
    )
    return await traced_passthrough(
        action=PassthroughAction.OPENAI_CHAT, payload=body, headers=headers
    )


@gateway_router.post(PASSTHROUGH_ROUTES[PassthroughAction.OPENAI_EMBEDDINGS], response_model=None)
@translate_http_exception
@_record_gateway_invocation(GatewayInvocationType.OPENAI_PASSTHROUGH_EMBEDDINGS)
async def openai_passthrough_embeddings(request: Request):
    """
    OpenAI passthrough endpoint for embeddings.

    This endpoint accepts raw OpenAI API format and passes it through to the
    OpenAI provider with the configured API key and model. The 'model' parameter
    in the request specifies which MLflow endpoint to use.

    Example:
        POST /gateway/openai/v1/embeddings
        {
            "model": "my-openai-endpoint",
            "input": "The food was delicious and the waiter..."
        }
    """
    body = await _get_request_body(request)
    user_metadata = _get_user_metadata(request)

    endpoint_name = _extract_endpoint_name_from_model(body)
    body.pop("model")
    store = _get_store()
    workspace = get_request_workspace()
    _validate_store(store)
    headers = dict(request.headers)
    provider, endpoint_config = _create_provider_from_endpoint_name(
        store, endpoint_name, EndpointType.LLM_V1_EMBEDDINGS
    )
    _set_gateway_telemetry_state(request, endpoint_config)
    check_budget_limit(store, endpoint_config, workspace=workspace)

    traced_passthrough = maybe_traced_gateway_call(
        provider.passthrough,
        endpoint_config,
        user_metadata,
        request_headers=headers,
        request_type=GatewayRequestType.PASSTHROUGH_MODEL_OPENAI_EMBEDDINGS,
        on_complete=make_budget_on_complete(store, workspace),
    )
    return await traced_passthrough(
        action=PassthroughAction.OPENAI_EMBEDDINGS, payload=body, headers=headers
    )


@gateway_router.post(PASSTHROUGH_ROUTES[PassthroughAction.OPENAI_RESPONSES], response_model=None)
@translate_http_exception
@_record_gateway_invocation(GatewayInvocationType.OPENAI_PASSTHROUGH_RESPONSES)
async def openai_passthrough_responses(request: Request):
    """
    OpenAI passthrough endpoint for the Responses API.

    This endpoint accepts raw OpenAI Responses API format and passes it through to the
    OpenAI provider with the configured API key and model. The 'model' parameter
    in the request specifies which MLflow endpoint to use.

    Supports streaming responses when the 'stream' parameter is set to true.

    Example:
        POST /gateway/openai/v1/responses
        {
            "model": "my-openai-endpoint",
            "input": [{"type": "text", "text": "Hello"}],
            "instructions": "You are a helpful assistant",
            "stream": true
        }
    """
    body = await _get_request_body(request)
    user_metadata = _get_user_metadata(request)

    endpoint_name = _extract_endpoint_name_from_model(body)
    body.pop("model")
    store = _get_store()
    workspace = get_request_workspace()
    _validate_store(store)
    headers = dict(request.headers)
    provider, endpoint_config = _create_provider_from_endpoint_name(
        store, endpoint_name, EndpointType.LLM_V1_CHAT
    )
    _set_gateway_telemetry_state(request, endpoint_config)
    check_budget_limit(store, endpoint_config, workspace=workspace)

    if body.get("stream", False):
        stream = await provider.passthrough(
            action=PassthroughAction.OPENAI_RESPONSES, payload=body, headers=headers
        )

        # Wrap stream iteration in an async generator so @mlflow.trace properly captures chunks
        async def yield_stream(body: dict[str, Any]):
            async for chunk in stream:
                yield chunk

        traced_stream = maybe_traced_gateway_call(
            yield_stream,
            endpoint_config,
            user_metadata,
            output_reducer=aggregate_openai_responses_stream_chunks,
            request_headers=headers,
            request_type=GatewayRequestType.PASSTHROUGH_MODEL_OPENAI_RESPONSES,
            on_complete=make_budget_on_complete(store, workspace),
        )
        return StreamingResponse(
            safe_stream(traced_stream(body), as_bytes=True), media_type="text/event-stream"
        )

    traced_passthrough = maybe_traced_gateway_call(
        provider.passthrough,
        endpoint_config,
        user_metadata,
        request_headers=headers,
        request_type=GatewayRequestType.PASSTHROUGH_MODEL_OPENAI_RESPONSES,
        on_complete=make_budget_on_complete(store, workspace),
    )
    return await traced_passthrough(
        action=PassthroughAction.OPENAI_RESPONSES, payload=body, headers=headers
    )


@gateway_router.post(PASSTHROUGH_ROUTES[PassthroughAction.ANTHROPIC_MESSAGES], response_model=None)
@translate_http_exception
@_record_gateway_invocation(GatewayInvocationType.ANTHROPIC_PASSTHROUGH_MESSAGES)
async def anthropic_passthrough_messages(request: Request):
    """
    Anthropic passthrough endpoint for the Messages API.

    This endpoint accepts raw Anthropic API format and passes it through to the
    Anthropic provider with the configured API key and model. The 'model' parameter
    in the request specifies which MLflow endpoint to use.

    Supports streaming responses when the 'stream' parameter is set to true.

    Example:
        POST /gateway/anthropic/v1/messages
        {
            "model": "my-anthropic-endpoint",
            "messages": [{"role": "user", "content": "Hello"}],
            "max_tokens": 1024,
            "stream": true
        }
    """
    body = await _get_request_body(request)
    user_metadata = _get_user_metadata(request)

    endpoint_name = _extract_endpoint_name_from_model(body)
    body.pop("model")
    store = _get_store()
    workspace = get_request_workspace()
    _validate_store(store)
    headers = dict(request.headers)
    provider, endpoint_config = _create_provider_from_endpoint_name(
        store, endpoint_name, EndpointType.LLM_V1_CHAT
    )
    _set_gateway_telemetry_state(request, endpoint_config)
    check_budget_limit(store, endpoint_config, workspace=workspace)

    if body.get("stream", False):
        stream = await provider.passthrough(
            action=PassthroughAction.ANTHROPIC_MESSAGES, payload=body, headers=headers
        )

        # Wrap stream iteration in an async generator so @mlflow.trace properly captures chunks
        async def yield_stream(body: dict[str, Any]):
            async for chunk in stream:
                yield chunk

        traced_stream = maybe_traced_gateway_call(
            yield_stream,
            endpoint_config,
            user_metadata,
            output_reducer=aggregate_anthropic_messages_stream_chunks,
            request_headers=headers,
            request_type=GatewayRequestType.PASSTHROUGH_MODEL_ANTHROPIC_MESSAGES,
            on_complete=make_budget_on_complete(store, workspace),
        )
        return StreamingResponse(
            safe_stream(traced_stream(body), as_bytes=True), media_type="text/event-stream"
        )

    traced_passthrough = maybe_traced_gateway_call(
        provider.passthrough,
        endpoint_config,
        user_metadata,
        request_headers=headers,
        request_type=GatewayRequestType.PASSTHROUGH_MODEL_ANTHROPIC_MESSAGES,
        on_complete=make_budget_on_complete(store, workspace),
    )
    return await traced_passthrough(
        action=PassthroughAction.ANTHROPIC_MESSAGES, payload=body, headers=headers
    )


@gateway_router.post(
    PASSTHROUGH_ROUTES[PassthroughAction.GEMINI_GENERATE_CONTENT], response_model=None
)
@translate_http_exception
@_record_gateway_invocation(GatewayInvocationType.GEMINI_PASSTHROUGH_GENERATE_CONTENT)
async def gemini_passthrough_generate_content(endpoint_name: str, request: Request):
    """
    Gemini passthrough endpoint for generateContent API (non-streaming).

    This endpoint accepts raw Gemini API format and passes it through to the
    Gemini provider with the configured API key. The endpoint_name in the URL path
    specifies which MLflow endpoint to use.

    Example:
        POST /gateway/gemini/v1beta/models/my-gemini-endpoint:generateContent
        {
            "contents": [
                {
                    "role": "user",
                    "parts": [{"text": "Hello"}]
                }
            ]
        }
    """
    body = await _get_request_body(request)
    user_metadata = _get_user_metadata(request)

    store = _get_store()
    workspace = get_request_workspace()
    _validate_store(store)
    headers = dict(request.headers)
    provider, endpoint_config = _create_provider_from_endpoint_name(
        store, endpoint_name, EndpointType.LLM_V1_CHAT
    )
    _set_gateway_telemetry_state(request, endpoint_config)
    check_budget_limit(store, endpoint_config, workspace=workspace)

    traced_passthrough = maybe_traced_gateway_call(
        provider.passthrough,
        endpoint_config,
        user_metadata,
        request_headers=headers,
        request_type=GatewayRequestType.PASSTHROUGH_MODEL_GEMINI_GENERATE_CONTENT,
        on_complete=make_budget_on_complete(store, workspace),
    )
    return await traced_passthrough(
        action=PassthroughAction.GEMINI_GENERATE_CONTENT, payload=body, headers=headers
    )


@gateway_router.post(
    PASSTHROUGH_ROUTES[PassthroughAction.GEMINI_STREAM_GENERATE_CONTENT], response_model=None
)
@translate_http_exception
@_record_gateway_invocation(GatewayInvocationType.GEMINI_PASSTHROUGH_STREAM_GENERATE_CONTENT)
async def gemini_passthrough_stream_generate_content(endpoint_name: str, request: Request):
    """
    Gemini passthrough endpoint for streamGenerateContent API (streaming).

    This endpoint accepts raw Gemini API format and passes it through to the
    Gemini provider with the configured API key. The endpoint_name in the URL path
    specifies which MLflow endpoint to use.

    Example:
        POST /gateway/gemini/v1beta/models/my-gemini-endpoint:streamGenerateContent
        {
            "contents": [
                {
                    "role": "user",
                    "parts": [{"text": "Hello"}]
                }
            ]
        }
    """
    body = await _get_request_body(request)
    user_metadata = _get_user_metadata(request)

    store = _get_store()
    workspace = get_request_workspace()
    _validate_store(store)
    headers = dict(request.headers)
    provider, endpoint_config = _create_provider_from_endpoint_name(
        store, endpoint_name, EndpointType.LLM_V1_CHAT
    )
    _set_gateway_telemetry_state(request, endpoint_config)
    check_budget_limit(store, endpoint_config, workspace=workspace)

    stream = await provider.passthrough(
        action=PassthroughAction.GEMINI_STREAM_GENERATE_CONTENT, payload=body, headers=headers
    )

    # Wrap stream iteration in an async generator so @mlflow.trace properly captures chunks
    async def yield_stream(body: dict[str, Any]):
        async for chunk in stream:
            yield chunk

    traced_stream = maybe_traced_gateway_call(
        yield_stream,
        endpoint_config,
        user_metadata,
        output_reducer=aggregate_gemini_stream_generate_content_chunks,
        request_headers=headers,
        request_type=GatewayRequestType.PASSTHROUGH_MODEL_GEMINI_GENERATE_CONTENT,
        on_complete=make_budget_on_complete(store, workspace),
    )
    return StreamingResponse(
        safe_stream(traced_stream(body), as_bytes=True), media_type="text/event-stream"
    )