gateway_api.py
"""
Database-backed Gateway API endpoints for MLflow Server.

This module provides dynamic gateway endpoints that are configured from the database
rather than from a static YAML configuration file. It integrates the AI Gateway
functionality directly into the MLflow tracking server.
"""

import functools
import logging
import sys
import time
from collections.abc import Callable
from typing import Any

from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import StreamingResponse

from mlflow.entities.gateway_endpoint import GatewayModelLinkageType
from mlflow.exceptions import MlflowException
from mlflow.gateway.budget import check_budget_limit, make_budget_on_complete
from mlflow.gateway.config import (
    AmazonBedrockConfig,
    AnthropicConfig,
    EndpointConfig,
    EndpointType,
    GatewayRequestType,
    GeminiConfig,
    LiteLLMConfig,
    MistralConfig,
    OpenAIAPIType,
    OpenAIConfig,
    Provider,
    VertexAIConfig,
    _AuthConfigKey,
    _OpenAICompatibleConfig,
)
from mlflow.gateway.constants import MLFLOW_GATEWAY_CALLER_HEADER, GatewayCaller
from mlflow.gateway.guardrail_utils import (
    extract_auth_headers,
    load_guardrails,
    run_post_llm_guardrails,
    run_pre_llm_guardrails,
)
from mlflow.gateway.guardrails import (
    _SANITIZE_BYPASS_HEADER,
    GuardrailViolation,
    JudgeGuardrail,
)
from mlflow.gateway.providers import get_provider
from mlflow.gateway.providers.base import (
    PASSTHROUGH_ROUTES,
    BaseProvider,
    FallbackProvider,
    PassthroughAction,
    TrafficRouteProvider,
)
from mlflow.gateway.providers.utils import provider_call_duration_ms
from mlflow.gateway.schemas import chat, embeddings
from mlflow.gateway.tracing_utils import (
    aggregate_anthropic_messages_stream_chunks,
    aggregate_chat_stream_chunks,
    aggregate_gemini_stream_generate_content_chunks,
    aggregate_openai_responses_stream_chunks,
    maybe_traced_gateway_call,
)
from mlflow.gateway.utils import safe_stream, to_sse_chunk, translate_http_exception
from mlflow.protos.databricks_pb2 import RESOURCE_DOES_NOT_EXIST
from mlflow.store.tracking.abstract_store import AbstractStore
from mlflow.store.tracking.gateway.config_resolver import get_endpoint_config
from mlflow.store.tracking.gateway.entities import (
    GatewayEndpointConfig,
    GatewayModelConfig,
    RoutingStrategy,
)
from mlflow.store.tracking.sqlalchemy_store import SqlAlchemyStore
from mlflow.telemetry.events import GatewayInvocationEvent, GatewayInvocationType
from mlflow.telemetry.track import _record_event
from mlflow.tracing.constant import TraceMetadataKey
from mlflow.tracking._tracking_service.utils import _get_store
from mlflow.utils.provider_filter import is_provider_allowed, normalize_provider_name
from mlflow.utils.workspace_context import get_request_workspace

_logger = logging.getLogger(__name__)

gateway_router = APIRouter(prefix="/gateway", tags=["gateway"])


async def _get_request_body(request: Request) -> dict[str, Any]:
    """
    Get request body, using cached version if available.

    The auth middleware may have already parsed the request body for permission
    validation. Since Starlette request body can only be read once, we cache
    the parsed body in request.state.cached_body for reuse by route handlers.

    Args:
        request: The FastAPI Request object.

    Returns:
        Parsed JSON body as a dictionary.

    Raises:
        HTTPException: If the request body is not valid JSON.
    """
    # Check if body was already parsed by auth middleware
    cached_body = getattr(request.state, "cached_body", None)
    if isinstance(cached_body, dict):
        return cached_body

    # Otherwise parse it now
    try:
        return await request.json()
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON payload: {e!s}")


def _get_user_metadata(request: Request) -> dict[str, Any]:
    """
    Extract user metadata from request state for tracing.

    The auth middleware stores the authenticated user's info in request.state.

    Args:
        request: The FastAPI Request object.

    Returns:
        Dictionary with user metadata (username and user_id if available).
    """
    metadata = {}
    if username := getattr(request.state, "username", None):
        metadata[TraceMetadataKey.AUTH_USERNAME] = username
    if user_id := getattr(request.state, "user_id", None):
        metadata[TraceMetadataKey.AUTH_USER_ID] = str(user_id)
    return metadata


def _record_gateway_invocation(invocation_type: GatewayInvocationType) -> Callable[..., Any]:
    """
    Decorator for gateway invocation endpoints that records telemetry:
    success/failure status, duration, streaming mode, and caller.

    As a side effect, relays provider call duration to the gateway timing middleware by
    writing `request.state.gateway_provider_duration_ms`. This is required because
    Starlette's call_next() copies the ContextVar context for the handler task, so
    mutations to provider_call_duration_ms don't propagate back to the middleware.

    Timing headers (X-MLflow-Gateway-Duration-Ms, X-MLflow-Gateway-Overhead-Duration-Ms)
    are injected by gateway_timing_middleware in fastapi_app.py.

    Args:
        invocation_type: The type of invocation endpoint.
    """

    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            start_time = time.perf_counter()
            success = True
            result = None

            # Extract caller header from the Request object if present,
            # only accepting known caller values to avoid logging arbitrary input.
            caller = None
            request = next((a for a in (*args, *kwargs.values()) if isinstance(a, Request)), None)
            if request is not None:
                raw_caller = request.headers.get(MLFLOW_GATEWAY_CALLER_HEADER)
                if raw_caller in {e.value for e in GatewayCaller}:
                    caller = raw_caller

            try:
                result = await func(*args, **kwargs)
            except Exception:
                success = False
                raise
            finally:
                duration_ms = int((time.perf_counter() - start_time) * 1000)
                provider_duration = int(provider_call_duration_ms.get())
                is_streaming = isinstance(result, StreamingResponse)
                params = {
                    "is_streaming": is_streaming,
                    "invocation_type": invocation_type,
                }
                # provider_call_duration_ms is only updated by send_request()
                # (non-streaming); send_stream_request() never sets it, so
                # timing fields would always be 0 for streaming responses.
                if not is_streaming:
                    params["provider_duration_ms"] = provider_duration
                    params["gateway_overhead_ms"] = max(0, duration_ms - provider_duration)
                if caller:
                    params["caller"] = caller
                if request is not None:
                    params["has_traceparent"] = request.headers.get("traceparent") is not None
                    auth_mod = sys.modules.get("mlflow.server.auth")
                    params["auth_enabled"] = auth_mod.is_auth_enabled() if auth_mod else False
                    if endpoint_id := getattr(request.state, "endpoint_id", None):
                        params["endpoint_id"] = endpoint_id
                    # Prefer the actual provider from the response (set by
                    # BaseProvider after the call) over the endpoint config
                    # estimate, which may not reflect traffic-split/fallback.
                    actual_provider = getattr(result, "provider", None)
                    if provider := (actual_provider or getattr(request.state, "provider", None)):
                        params["provider"] = provider
                _record_event(
                    GatewayInvocationEvent,
                    params=params,
                    success=success,
                    duration_ms=duration_ms,
                )
                # Relay provider timing to the middleware via request.state.
                # ContextVar values set in the handler task don't propagate back
                # to the middleware task (Starlette copies the context for call_next).
                if request is not None:
                    request.state.gateway_provider_duration_ms = int(
                        provider_call_duration_ms.get()
                    )

            return result

        return wrapper

    return decorator


def _set_gateway_telemetry_state(request: Request, endpoint_config) -> None:
    """Set endpoint_id and provider on request.state for telemetry attribution."""
    request.state.endpoint_id = endpoint_config.endpoint_id
    if endpoint_config.models:
        primary_model = next(
            (
                m
                for m in endpoint_config.models
                if m.linkage_type == GatewayModelLinkageType.PRIMARY
            ),
            endpoint_config.models[0],
        )
        request.state.provider = str(primary_model.provider)


def _build_openai_compatible_config(model_config: "GatewayModelConfig"):
    """Build an _OpenAICompatibleConfig for providers that use the OpenAI API format."""
    auth_config = model_config.auth_config or {}
    return _OpenAICompatibleConfig(
        api_key=model_config.secret_value.get(_AuthConfigKey.API_KEY),
        api_base=auth_config.get(_AuthConfigKey.API_BASE),
    )


def _build_endpoint_config(
    endpoint_name: str,
    model_config: GatewayModelConfig,
    endpoint_type: EndpointType,
) -> EndpointConfig:
    """
    Build an EndpointConfig from model configuration.

    This function combines provider config building and endpoint config building
    into a single operation.

    Args:
        endpoint_name: The endpoint name.
        model_config: The model configuration object with decrypted secrets.
        endpoint_type: Endpoint type (chat or embeddings).

    Returns:
        EndpointConfig instance ready for provider instantiation.

    Raises:
        MlflowException: If provider configuration is invalid.
    """
    provider_name = model_config.provider
    if not is_provider_allowed(provider_name):
        _logger.debug(
            "Provider '%s' blocked by MLFLOW_GATEWAY_ALLOWED_PROVIDERS",
            provider_name,
        )
        raise MlflowException.invalid_parameter_value(
            f"Provider '{provider_name}' is not allowed by the current gateway provider policy."
        )

    provider_config = None

    if model_config.provider == Provider.OPENAI:
        auth_config = model_config.auth_config or {}
        openai_config = {
            "openai_api_key": model_config.secret_value.get(_AuthConfigKey.API_KEY),
        }

        # Check if this is Azure OpenAI (requires api_type, deployment_name, api_base, api_version)
        if "api_type" in auth_config and auth_config["api_type"] in ("azure", "azuread"):
            openai_config["openai_api_type"] = auth_config["api_type"]
            openai_config["openai_api_base"] = auth_config.get(_AuthConfigKey.API_BASE)
            openai_config["openai_deployment_name"] = auth_config.get("deployment_name")
            openai_config["openai_api_version"] = auth_config.get("api_version")
        else:
            # Standard OpenAI
            if _AuthConfigKey.API_BASE in auth_config:
                openai_config["openai_api_base"] = auth_config[_AuthConfigKey.API_BASE]
            if "organization" in auth_config:
                openai_config["openai_organization"] = auth_config["organization"]

        provider_config = OpenAIConfig(**openai_config)
    elif model_config.provider == Provider.AZURE:
        auth_config = model_config.auth_config or {}
        model_config.provider = Provider.OPENAI
        provider_config = OpenAIConfig(
            openai_api_type=OpenAIAPIType.AZURE,
            openai_api_key=model_config.secret_value.get(_AuthConfigKey.API_KEY),
            openai_api_base=auth_config.get(_AuthConfigKey.API_BASE),
            openai_deployment_name=model_config.model_name,
            openai_api_version=auth_config.get("api_version"),
        )
    elif model_config.provider == Provider.ANTHROPIC:
        anthropic_config = {
            "anthropic_api_key": model_config.secret_value.get(_AuthConfigKey.API_KEY),
        }
        if model_config.auth_config and "version" in model_config.auth_config:
            anthropic_config["anthropic_version"] = model_config.auth_config["version"]
        provider_config = AnthropicConfig(**anthropic_config)
    elif model_config.provider == Provider.MISTRAL:
        provider_config = MistralConfig(
            mistral_api_key=model_config.secret_value.get(_AuthConfigKey.API_KEY),
        )
    elif model_config.provider == Provider.GEMINI:
        provider_config = GeminiConfig(
            gemini_api_key=model_config.secret_value.get(_AuthConfigKey.API_KEY),
        )
    elif model_config.provider in {
        Provider.GROQ,
        Provider.DEEPSEEK,
        Provider.XAI,
        Provider.OPENROUTER,
        Provider.OLLAMA,
        Provider.PORTKEY,
    }:
        provider_config = _build_openai_compatible_config(model_config)
    elif normalize_provider_name(model_config.provider) == Provider.DATABRICKS:
        from mlflow.gateway.providers.databricks import DatabricksConfig

        auth_config = model_config.auth_config or {}
        auth_mode = auth_config.get(_AuthConfigKey.AUTH_MODE, "pat_token")
        config_kwargs = {}
        if api_base := auth_config.get(_AuthConfigKey.API_BASE):
            config_kwargs["host"] = api_base
        if auth_mode == "oauth_m2m":
            config_kwargs["client_id"] = auth_config.get("client_id")
            config_kwargs["client_secret"] = model_config.secret_value.get("client_secret")
        else:
            config_kwargs["token"] = model_config.secret_value.get(_AuthConfigKey.API_KEY)
        provider_config = DatabricksConfig(**config_kwargs)
        model_config.provider = Provider.DATABRICKS
    elif normalize_provider_name(model_config.provider) == Provider.BEDROCK:
        auth_config = model_config.auth_config or {}
        auth_mode = auth_config.get(_AuthConfigKey.AUTH_MODE, "api_key")
        if auth_mode == "api_key":
            # Bearer token auth — bypasses boto3 SigV4
            provider_config = AmazonBedrockConfig(
                aws_config={
                    "aws_bearer_token": model_config.secret_value.get(_AuthConfigKey.API_KEY),
                    "aws_region": auth_config.get("aws_region_name"),
                }
            )
        elif auth_mode == "access_keys":
            provider_config = AmazonBedrockConfig(
                aws_config={
                    "aws_access_key_id": model_config.secret_value.get("aws_access_key_id"),
                    "aws_secret_access_key": model_config.secret_value.get("aws_secret_access_key"),
                    "aws_region": auth_config.get("aws_region_name"),
                }
            )
        elif auth_mode == "iam_role":
            provider_config = AmazonBedrockConfig(
                aws_config={
                    "aws_role_arn": auth_config.get("aws_role_name"),
                    "aws_region": auth_config.get("aws_region_name"),
                }
            )
        else:
            # default_chain — boto3 resolves credentials from the
            # environment (env vars, ~/.aws/credentials, instance profile, etc.)
            aws_config = {"aws_region": auth_config.get("aws_region_name")}
            if role_arn := auth_config.get("aws_role_name"):
                aws_config["aws_role_arn"] = role_arn
            provider_config = AmazonBedrockConfig(aws_config=aws_config)
        model_config.provider = Provider.BEDROCK
    elif model_config.provider == Provider.VERTEX_AI:
        auth_config = model_config.auth_config or {}
        provider_config = VertexAIConfig(
            vertex_project=auth_config.get("vertex_project"),
            vertex_location=auth_config.get("vertex_location"),
            vertex_credentials=model_config.secret_value.get("vertex_credentials"),
        )
    else:
        # Use LiteLLM as fallback for unsupported providers
        # Store the original provider name for LiteLLM's provider/model format
        original_provider = model_config.provider
        auth_config = model_config.auth_config or {}
        # Merge auth_config with secret_value (secret_value contains api_key and other secrets)
        litellm_config = {
            "litellm_provider": original_provider,
            "litellm_auth_config": auth_config | model_config.secret_value,
        }
        provider_config = LiteLLMConfig(**litellm_config)
        model_config.provider = Provider.LITELLM

    # Build and return EndpointConfig
    return EndpointConfig(
        name=endpoint_name,
        endpoint_type=endpoint_type,
        model={
            "name": model_config.model_name,
            "provider": model_config.provider,
            "config": provider_config.model_dump(),
        },
    )


def _create_provider(
    endpoint_config: GatewayEndpointConfig,
    endpoint_type: EndpointType,
    enable_tracing: bool = False,
) -> BaseProvider:
    """
    Create a provider instance based on endpoint routing strategy.

    Fallback is independent of routing strategy - if fallback_config is present,
    the provider is wrapped with FallbackProvider.

    Args:
        endpoint_config: The endpoint configuration with model details and routing config.
        endpoint_type: Endpoint type (chat or embeddings).
        enable_tracing: If True, enables MLflow tracing for provider calls.

    Returns:
        Provider instance (standard provider, TrafficRouteProvider, or FallbackProvider).

    Raises:
        MlflowException: If endpoint configuration is invalid or has no models.
    """
    # Get PRIMARY models
    primary_models = [
        model
        for model in endpoint_config.models
        if model.linkage_type == GatewayModelLinkageType.PRIMARY
    ]

    if not primary_models:
        raise MlflowException(
            f"Endpoint '{endpoint_config.endpoint_name}' has no PRIMARY models configured",
            error_code=RESOURCE_DOES_NOT_EXIST,
        )

    # Create base provider based on routing strategy
    if endpoint_config.routing_strategy == RoutingStrategy.REQUEST_BASED_TRAFFIC_SPLIT:
        # Traffic split: distribute requests based on weights
        configs = []
        weights = []
        for model_config in primary_models:
            gateway_endpoint_config = _build_endpoint_config(
                endpoint_name=endpoint_config.endpoint_name,
                model_config=model_config,
                endpoint_type=endpoint_type,
            )
            configs.append(gateway_endpoint_config)
            weights.append(int(model_config.weight * 100))  # Convert to percentage

        primary_provider = TrafficRouteProvider(
            configs=configs,
            traffic_splits=weights,
            routing_strategy="TRAFFIC_SPLIT",
            enable_tracing=enable_tracing,
        )
    else:
        # Default: use the first PRIMARY model
        model_config = primary_models[0]
        gateway_endpoint_config = _build_endpoint_config(
            endpoint_config.endpoint_name, model_config, endpoint_type
        )
        provider_class = get_provider(model_config.provider)
        primary_provider = provider_class(gateway_endpoint_config, enable_tracing=enable_tracing)

    # Wrap with FallbackProvider if fallback configuration exists
    if endpoint_config.fallback_config:
        fallback_models = [
            model
            for model in endpoint_config.models
            if model.linkage_type == GatewayModelLinkageType.FALLBACK
        ]

        if not fallback_models:
            _logger.debug(
                f"Endpoint '{endpoint_config.endpoint_name}' has fallback_config "
                "but no FALLBACK models configured"
            )
            return primary_provider

        # Sort fallback models by fallback_order
        fallback_models.sort(
            key=lambda m: m.fallback_order if m.fallback_order is not None else float("inf")
        )

        fallback_providers = [
            get_provider(model_config.provider)(
                _build_endpoint_config(
                    endpoint_name=endpoint_config.endpoint_name,
                    model_config=model_config,
                    endpoint_type=endpoint_type,
                ),
                enable_tracing=enable_tracing,
            )
            for model_config in fallback_models
        ]

        max_attempts = endpoint_config.fallback_config.max_attempts or len(fallback_models)

        # FallbackProvider expects all providers (primary + fallback)
        all_providers = [primary_provider] + fallback_providers

        return FallbackProvider(
            providers=all_providers,
            max_attempts=max_attempts + 1,  # +1 to include primary
            strategy=endpoint_config.fallback_config.strategy,
            enable_tracing=enable_tracing,
        )

    return primary_provider


def _create_provider_from_endpoint_name(
    store: SqlAlchemyStore,
    endpoint_name: str,
    endpoint_type: EndpointType,
    enable_tracing: bool = True,
) -> tuple[BaseProvider, GatewayEndpointConfig]:
    """
    Create a provider from an endpoint name.

    Args:
        store: The SQLAlchemy store instance.
        endpoint_name: The endpoint name.
        endpoint_type: Endpoint type (chat or embeddings).
        enable_tracing: If True, enables MLflow tracing for provider calls.

    Returns:
        Tuple of (provider instance, endpoint config)
    """
    endpoint_config = get_endpoint_config(endpoint_name=endpoint_name, store=store)
    return _create_provider(
        endpoint_config, endpoint_type, enable_tracing=enable_tracing
    ), endpoint_config


def _validate_store(store: AbstractStore) -> None:
    if not isinstance(store, SqlAlchemyStore):
        raise HTTPException(
            status_code=500,
            detail="Gateway endpoints are only available with SqlAlchemyStore, "
            f"got {type(store).__name__}.",
        )


def _extract_endpoint_name_from_model(body: dict[str, Any]) -> str:
    """
    Extract and validate the endpoint name from the 'model' parameter in the request body.

    Args:
        body: The request body dictionary

    Returns:
        The endpoint name extracted from the 'model' parameter

    Raises:
        HTTPException: If the 'model' parameter is missing
    """
    endpoint_name = body.get("model")
    if not endpoint_name:
        raise HTTPException(
            status_code=400,
            detail="Missing required 'model' parameter in request body",
        )
    return endpoint_name


def _get_guardrails_and_auth(
    store, endpoint_config, request: Request
) -> tuple[list[JudgeGuardrail], dict[str, str]]:
    """Load guardrails and extract auth headers, skipping guardrails for internal bypass calls."""
    headers = dict(request.headers)
    bypass = headers.get(_SANITIZE_BYPASS_HEADER) == "1"
    guardrails = [] if bypass else load_guardrails(store, endpoint_config, request)
    return guardrails, extract_auth_headers(headers)


@gateway_router.post("/{endpoint_name}/mlflow/invocations", response_model=None)
@translate_http_exception
@_record_gateway_invocation(GatewayInvocationType.MLFLOW_INVOCATIONS)
async def invocations(endpoint_name: str, request: Request):
    """
    Unified invocations endpoint handler that supports both chat and embeddings.

    The handler automatically detects the request type based on the payload structure:
    - If payload has "messages" field -> chat endpoint
    - If payload has "input" field -> embeddings endpoint
    """
    body = await _get_request_body(request)
    user_metadata = _get_user_metadata(request)
    headers = dict(request.headers)

    store = _get_store()
    workspace = get_request_workspace()

    _validate_store(store)
    endpoint_config = get_endpoint_config(endpoint_name=endpoint_name, store=store)
    _set_gateway_telemetry_state(request, endpoint_config)
    check_budget_limit(store, endpoint_config, workspace=workspace)
    guardrails, auth_headers = _get_guardrails_and_auth(store, endpoint_config, request)

    # Detect request type based on payload structure
    if "messages" in body:
        # Chat request
        endpoint_type = EndpointType.LLM_V1_CHAT
        try:
            payload = chat.RequestPayload(**body)
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Invalid chat payload: {e!s}")

        provider, endpoint_config = _create_provider_from_endpoint_name(
            store, endpoint_name, endpoint_type
        )

        if payload.stream:
            # Post-LLM guardrails are not applied to streaming responses.
            # Pre-LLM guardrails run inside the trace as child spans; violations
            # are surfaced as SSE error chunks via safe_stream.
            async def _guarded_stream(
                payload: chat.RequestPayload,
            ):
                request_dict = await run_pre_llm_guardrails(
                    guardrails,
                    payload.model_dump(),
                    auth_headers=auth_headers,
                    usage_tracking=endpoint_config.usage_tracking,
                )
                async for chunk in provider.chat_stream(chat.RequestPayload(**request_dict)):
                    yield chunk

            stream = maybe_traced_gateway_call(
                _guarded_stream,
                endpoint_config,
                user_metadata,
                output_reducer=aggregate_chat_stream_chunks,
                request_headers=headers,
                request_type=GatewayRequestType.UNIFIED_CHAT,
                on_complete=make_budget_on_complete(store, workspace),
            )(payload)
            return StreamingResponse(
                safe_stream(to_sse_chunk(chunk.model_dump_json()) async for chunk in stream),
                media_type="text/event-stream",
            )
        else:

            async def _guarded_chat(
                payload: chat.RequestPayload,
            ) -> chat.ResponsePayload:
                request_dict = await run_pre_llm_guardrails(
                    guardrails,
                    payload.model_dump(),
                    auth_headers=auth_headers,
                    usage_tracking=endpoint_config.usage_tracking,
                )
                modified_payload = chat.RequestPayload(**request_dict)
                response = await provider.chat(modified_payload)
                return await run_post_llm_guardrails(
                    guardrails,
                    request_dict,
                    response,
                    auth_headers=auth_headers,
                    usage_tracking=endpoint_config.usage_tracking,
                )

            try:
                return await maybe_traced_gateway_call(
                    _guarded_chat,
                    endpoint_config,
                    user_metadata,
                    request_headers=headers,
                    request_type=GatewayRequestType.UNIFIED_CHAT,
                    on_complete=make_budget_on_complete(store, workspace),
                )(payload)
            except GuardrailViolation as e:
                raise HTTPException(status_code=400, detail=str(e))

    elif "input" in body:
        # Embeddings request
        endpoint_type = EndpointType.LLM_V1_EMBEDDINGS
        try:
            payload = embeddings.RequestPayload(**body)
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Invalid embeddings payload: {e!s}")

        provider, endpoint_config = _create_provider_from_endpoint_name(
            store, endpoint_name, endpoint_type
        )

        return await maybe_traced_gateway_call(
            provider.embeddings,
            endpoint_config,
            user_metadata,
            request_headers=headers,
            request_type=GatewayRequestType.UNIFIED_EMBEDDINGS,
            on_complete=make_budget_on_complete(store, workspace),
        )(payload)

    else:
        raise HTTPException(
            status_code=400,
            detail="Invalid request: payload format must be either chat or embeddings",
        )


@gateway_router.post("/mlflow/v1/chat/completions", response_model=None)
@translate_http_exception
@_record_gateway_invocation(GatewayInvocationType.MLFLOW_CHAT_COMPLETIONS)
async def chat_completions(request: Request):
    """
    OpenAI-compatible chat completions endpoint.

    This endpoint follows the OpenAI API format where the endpoint name is specified
    via the "model" parameter in the request body, allowing clients to use the
    standard OpenAI SDK.

    Example:
        POST /gateway/mlflow/v1/chat/completions
        {
            "model": "my-endpoint-name",
            "messages": [{"role": "user", "content": "Hello"}]
        }
    """
    body = await _get_request_body(request)
    user_metadata = _get_user_metadata(request)
    headers = dict(request.headers)

    # Extract endpoint name from "model" parameter
    endpoint_name = _extract_endpoint_name_from_model(body)
    body.pop("model")

    store = _get_store()
    workspace = get_request_workspace()

    _validate_store(store)
    provider, endpoint_config = _create_provider_from_endpoint_name(
        store, endpoint_name, EndpointType.LLM_V1_CHAT
    )
    _set_gateway_telemetry_state(request, endpoint_config)
    check_budget_limit(store, endpoint_config, workspace=workspace)
    guardrails, auth_headers = _get_guardrails_and_auth(store, endpoint_config, request)

    try:
        payload = chat.RequestPayload(**body)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid chat payload: {e!s}")

    if payload.stream:
        # Post-LLM guardrails are not applied to streaming responses.
        # Pre-LLM guardrails run inside the trace as child spans; violations
        # are surfaced as SSE error chunks via safe_stream.
        async def _guarded_stream(
            payload: chat.RequestPayload,
        ):
            request_dict = await run_pre_llm_guardrails(
                guardrails,
                payload.model_dump(),
                auth_headers=auth_headers,
                usage_tracking=endpoint_config.usage_tracking,
            )
            async for chunk in provider.chat_stream(chat.RequestPayload(**request_dict)):
                yield chunk

        stream = maybe_traced_gateway_call(
            _guarded_stream,
            endpoint_config,
            user_metadata,
            output_reducer=aggregate_chat_stream_chunks,
            request_headers=headers,
            request_type=GatewayRequestType.UNIFIED_CHAT,
            on_complete=make_budget_on_complete(store, workspace),
        )(payload)
        return StreamingResponse(
            safe_stream(to_sse_chunk(chunk.model_dump_json()) async for chunk in stream),
            media_type="text/event-stream",
        )
    else:

        async def _guarded_chat(
            payload: chat.RequestPayload,
        ) -> chat.ResponsePayload:
            request_dict = await run_pre_llm_guardrails(
                guardrails,
                payload.model_dump(),
                auth_headers=auth_headers,
                usage_tracking=endpoint_config.usage_tracking,
            )
            modified_payload = chat.RequestPayload(**request_dict)
            response = await provider.chat(modified_payload)
            return await run_post_llm_guardrails(
                guardrails,
                request_dict,
                response,
                auth_headers=auth_headers,
                usage_tracking=endpoint_config.usage_tracking,
            )

        try:
            return await maybe_traced_gateway_call(
                _guarded_chat,
                endpoint_config,
                user_metadata,
                request_headers=headers,
                request_type=GatewayRequestType.UNIFIED_CHAT,
                on_complete=make_budget_on_complete(store, workspace),
            )(payload)
        except GuardrailViolation as e:
            raise HTTPException(status_code=400, detail=str(e))


@gateway_router.post(PASSTHROUGH_ROUTES[PassthroughAction.OPENAI_CHAT], response_model=None)
@translate_http_exception
@_record_gateway_invocation(GatewayInvocationType.OPENAI_PASSTHROUGH_CHAT)
async def openai_passthrough_chat(request: Request):
    """
    OpenAI passthrough endpoint for chat completions.

    This endpoint accepts raw OpenAI API format and passes it through to the
    OpenAI provider with the configured API key and model. The 'model' parameter
    in the request specifies which MLflow endpoint to use.

    Supports streaming responses when the 'stream' parameter is set to true.

    Example:
        POST /gateway/openai/v1/chat/completions
        {
            "model": "my-openai-endpoint",
            "messages": [{"role": "user", "content": "Hello"}],
            "temperature": 0.7,
            "stream": true
        }
    """
    body = await _get_request_body(request)
    user_metadata = _get_user_metadata(request)

    endpoint_name = _extract_endpoint_name_from_model(body)
    body.pop("model")
    store = _get_store()
    workspace = get_request_workspace()
    _validate_store(store)
    headers = dict(request.headers)
    provider, endpoint_config = _create_provider_from_endpoint_name(
        store, endpoint_name, EndpointType.LLM_V1_CHAT
    )
    _set_gateway_telemetry_state(request, endpoint_config)
    check_budget_limit(store, endpoint_config, workspace=workspace)

    if body.get("stream", False):
        stream = await provider.passthrough(
            action=PassthroughAction.OPENAI_CHAT, payload=body, headers=headers
        )

        # Wrap stream iteration in an async generator so @mlflow.trace properly captures chunks
        async def yield_stream(body: dict[str, Any]):
            async for chunk in stream:
                yield chunk

        traced_stream = maybe_traced_gateway_call(
            yield_stream,
            endpoint_config,
            user_metadata,
            request_headers=headers,
            request_type=GatewayRequestType.PASSTHROUGH_MODEL_OPENAI_CHAT,
            on_complete=make_budget_on_complete(store, workspace),
        )
        return StreamingResponse(
            safe_stream(traced_stream(body), as_bytes=True), media_type="text/event-stream"
        )

    traced_passthrough = maybe_traced_gateway_call(
        provider.passthrough,
        endpoint_config,
        user_metadata,
        request_headers=headers,
        request_type=GatewayRequestType.PASSTHROUGH_MODEL_OPENAI_CHAT,
        on_complete=make_budget_on_complete(store, workspace),
    )
    return await traced_passthrough(
        action=PassthroughAction.OPENAI_CHAT, payload=body, headers=headers
    )


@gateway_router.post(PASSTHROUGH_ROUTES[PassthroughAction.OPENAI_EMBEDDINGS], response_model=None)
@translate_http_exception
@_record_gateway_invocation(GatewayInvocationType.OPENAI_PASSTHROUGH_EMBEDDINGS)
async def openai_passthrough_embeddings(request: Request):
    """
    OpenAI passthrough endpoint for embeddings.

    This endpoint accepts raw OpenAI API format and passes it through to the
    OpenAI provider with the configured API key and model. The 'model' parameter
    in the request specifies which MLflow endpoint to use.

    Example:
        POST /gateway/openai/v1/embeddings
        {
            "model": "my-openai-endpoint",
            "input": "The food was delicious and the waiter..."
        }
    """
    body = await _get_request_body(request)
    user_metadata = _get_user_metadata(request)

    endpoint_name = _extract_endpoint_name_from_model(body)
    body.pop("model")
    store = _get_store()
    workspace = get_request_workspace()
    _validate_store(store)
    headers = dict(request.headers)
    provider, endpoint_config = _create_provider_from_endpoint_name(
        store, endpoint_name, EndpointType.LLM_V1_EMBEDDINGS
    )
    _set_gateway_telemetry_state(request, endpoint_config)
    check_budget_limit(store, endpoint_config, workspace=workspace)

    traced_passthrough = maybe_traced_gateway_call(
        provider.passthrough,
        endpoint_config,
        user_metadata,
        request_headers=headers,
        request_type=GatewayRequestType.PASSTHROUGH_MODEL_OPENAI_EMBEDDINGS,
        on_complete=make_budget_on_complete(store, workspace),
    )
    return await traced_passthrough(
        action=PassthroughAction.OPENAI_EMBEDDINGS, payload=body, headers=headers
    )


@gateway_router.post(PASSTHROUGH_ROUTES[PassthroughAction.OPENAI_RESPONSES], response_model=None)
@translate_http_exception
@_record_gateway_invocation(GatewayInvocationType.OPENAI_PASSTHROUGH_RESPONSES)
async def openai_passthrough_responses(request: Request):
    """
    OpenAI passthrough endpoint for the Responses API.

    This endpoint accepts raw OpenAI Responses API format and passes it through to the
    OpenAI provider with the configured API key and model. The 'model' parameter
    in the request specifies which MLflow endpoint to use.

    Supports streaming responses when the 'stream' parameter is set to true.

    Example:
        POST /gateway/openai/v1/responses
        {
            "model": "my-openai-endpoint",
            "input": [{"type": "text", "text": "Hello"}],
            "instructions": "You are a helpful assistant",
            "stream": true
        }
    """
    body = await _get_request_body(request)
    user_metadata = _get_user_metadata(request)

    endpoint_name = _extract_endpoint_name_from_model(body)
    body.pop("model")
    store = _get_store()
    workspace = get_request_workspace()
    _validate_store(store)
    headers = dict(request.headers)
    provider, endpoint_config = _create_provider_from_endpoint_name(
        store, endpoint_name, EndpointType.LLM_V1_CHAT
    )
    _set_gateway_telemetry_state(request, endpoint_config)
    check_budget_limit(store, endpoint_config, workspace=workspace)

    if body.get("stream", False):
        stream = await provider.passthrough(
            action=PassthroughAction.OPENAI_RESPONSES, payload=body, headers=headers
        )

        # Wrap stream iteration in an async generator so @mlflow.trace properly captures chunks
        async def yield_stream(body: dict[str, Any]):
            async for chunk in stream:
                yield chunk

        traced_stream = maybe_traced_gateway_call(
            yield_stream,
            endpoint_config,
            user_metadata,
            output_reducer=aggregate_openai_responses_stream_chunks,
            request_headers=headers,
            request_type=GatewayRequestType.PASSTHROUGH_MODEL_OPENAI_RESPONSES,
            on_complete=make_budget_on_complete(store, workspace),
        )
        return StreamingResponse(
            safe_stream(traced_stream(body), as_bytes=True), media_type="text/event-stream"
        )

    traced_passthrough = maybe_traced_gateway_call(
        provider.passthrough,
        endpoint_config,
        user_metadata,
        request_headers=headers,
        request_type=GatewayRequestType.PASSTHROUGH_MODEL_OPENAI_RESPONSES,
        on_complete=make_budget_on_complete(store, workspace),
    )
    return await traced_passthrough(
        action=PassthroughAction.OPENAI_RESPONSES, payload=body, headers=headers
    )


@gateway_router.post(PASSTHROUGH_ROUTES[PassthroughAction.ANTHROPIC_MESSAGES], response_model=None)
@translate_http_exception
@_record_gateway_invocation(GatewayInvocationType.ANTHROPIC_PASSTHROUGH_MESSAGES)
async def anthropic_passthrough_messages(request: Request):
    """
    Anthropic passthrough endpoint for the Messages API.

    This endpoint accepts raw Anthropic API format and passes it through to the
    Anthropic provider with the configured API key and model. The 'model' parameter
    in the request specifies which MLflow endpoint to use.

    Supports streaming responses when the 'stream' parameter is set to true.

    Example:
        POST /gateway/anthropic/v1/messages
        {
            "model": "my-anthropic-endpoint",
            "messages": [{"role": "user", "content": "Hello"}],
            "max_tokens": 1024,
            "stream": true
        }
    """
    body = await _get_request_body(request)
    user_metadata = _get_user_metadata(request)

    endpoint_name = _extract_endpoint_name_from_model(body)
    body.pop("model")
    store = _get_store()
    workspace = get_request_workspace()
    _validate_store(store)
    headers = dict(request.headers)
    provider, endpoint_config = _create_provider_from_endpoint_name(
        store, endpoint_name, EndpointType.LLM_V1_CHAT
    )
    _set_gateway_telemetry_state(request, endpoint_config)
    check_budget_limit(store, endpoint_config, workspace=workspace)

    if body.get("stream", False):
        stream = await provider.passthrough(
            action=PassthroughAction.ANTHROPIC_MESSAGES, payload=body, headers=headers
        )

        # Wrap stream iteration in an async generator so @mlflow.trace properly captures chunks
        async def yield_stream(body: dict[str, Any]):
            async for chunk in stream:
                yield chunk

        traced_stream = maybe_traced_gateway_call(
            yield_stream,
            endpoint_config,
            user_metadata,
            output_reducer=aggregate_anthropic_messages_stream_chunks,
            request_headers=headers,
            request_type=GatewayRequestType.PASSTHROUGH_MODEL_ANTHROPIC_MESSAGES,
            on_complete=make_budget_on_complete(store, workspace),
        )
        return StreamingResponse(
            safe_stream(traced_stream(body), as_bytes=True), media_type="text/event-stream"
        )

    traced_passthrough = maybe_traced_gateway_call(
        provider.passthrough,
        endpoint_config,
        user_metadata,
        request_headers=headers,
        request_type=GatewayRequestType.PASSTHROUGH_MODEL_ANTHROPIC_MESSAGES,
        on_complete=make_budget_on_complete(store, workspace),
    )
    return await traced_passthrough(
        action=PassthroughAction.ANTHROPIC_MESSAGES, payload=body, headers=headers
    )


@gateway_router.post(
    PASSTHROUGH_ROUTES[PassthroughAction.GEMINI_GENERATE_CONTENT], response_model=None
)
@translate_http_exception
@_record_gateway_invocation(GatewayInvocationType.GEMINI_PASSTHROUGH_GENERATE_CONTENT)
async def gemini_passthrough_generate_content(endpoint_name: str, request: Request):
    """
    Gemini passthrough endpoint for generateContent API (non-streaming).

    This endpoint accepts raw Gemini API format and passes it through to the
    Gemini provider with the configured API key. The endpoint_name in the URL path
    specifies which MLflow endpoint to use.

    Example:
        POST /gateway/gemini/v1beta/models/my-gemini-endpoint:generateContent
        {
            "contents": [
                {
                    "role": "user",
                    "parts": [{"text": "Hello"}]
                }
            ]
        }
    """
    body = await _get_request_body(request)
    user_metadata = _get_user_metadata(request)

    store = _get_store()
    workspace = get_request_workspace()
    _validate_store(store)
    headers = dict(request.headers)
    provider, endpoint_config = _create_provider_from_endpoint_name(
        store, endpoint_name, EndpointType.LLM_V1_CHAT
    )
    _set_gateway_telemetry_state(request, endpoint_config)
    check_budget_limit(store, endpoint_config, workspace=workspace)

    traced_passthrough = maybe_traced_gateway_call(
        provider.passthrough,
        endpoint_config,
        user_metadata,
        request_headers=headers,
        request_type=GatewayRequestType.PASSTHROUGH_MODEL_GEMINI_GENERATE_CONTENT,
        on_complete=make_budget_on_complete(store, workspace),
    )
    return await traced_passthrough(
        action=PassthroughAction.GEMINI_GENERATE_CONTENT, payload=body, headers=headers
    )


@gateway_router.post(
    PASSTHROUGH_ROUTES[PassthroughAction.GEMINI_STREAM_GENERATE_CONTENT], response_model=None
)
@translate_http_exception
@_record_gateway_invocation(GatewayInvocationType.GEMINI_PASSTHROUGH_STREAM_GENERATE_CONTENT)
async def gemini_passthrough_stream_generate_content(endpoint_name: str, request: Request):
    """
    Gemini passthrough endpoint for streamGenerateContent API (streaming).

    This endpoint accepts raw Gemini API format and passes it through to the
    Gemini provider with the configured API key. The endpoint_name in the URL path
    specifies which MLflow endpoint to use.

    Example:
        POST /gateway/gemini/v1beta/models/my-gemini-endpoint:streamGenerateContent
        {
            "contents": [
                {
                    "role": "user",
                    "parts": [{"text": "Hello"}]
                }
            ]
        }
    """
    body = await _get_request_body(request)
    user_metadata = _get_user_metadata(request)

    store = _get_store()
    workspace = get_request_workspace()
    _validate_store(store)
    headers = dict(request.headers)
    provider, endpoint_config = _create_provider_from_endpoint_name(
        store, endpoint_name, EndpointType.LLM_V1_CHAT
    )
    _set_gateway_telemetry_state(request, endpoint_config)
    check_budget_limit(store, endpoint_config, workspace=workspace)

    stream = await provider.passthrough(
        action=PassthroughAction.GEMINI_STREAM_GENERATE_CONTENT, payload=body, headers=headers
    )

    # Wrap stream iteration in an async generator so @mlflow.trace properly captures chunks
    async def yield_stream(body: dict[str, Any]):
        async for chunk in stream:
            yield chunk

    traced_stream = maybe_traced_gateway_call(
        yield_stream,
        endpoint_config,
        user_metadata,
        output_reducer=aggregate_gemini_stream_generate_content_chunks,
        request_headers=headers,
        request_type=GatewayRequestType.PASSTHROUGH_MODEL_GEMINI_GENERATE_CONTENT,
        on_complete=make_budget_on_complete(store, workspace),
    )
    return StreamingResponse(
        safe_stream(traced_stream(body), as_bytes=True), media_type="text/event-stream"
    )
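

# ---------------------------------------------------------------------------
# Client-side usage sketch (comments only; nothing here is executed by the
# server). The /gateway/mlflow/v1/chat/completions route above accepts the
# OpenAI request format, so a standard OpenAI SDK client can target it
# directly. The base URL, API key, and endpoint name below are illustrative
# assumptions; substitute the actual tracking-server URL, credentials, and a
# gateway endpoint configured in the database.
#
#     from openai import OpenAI
#
#     client = OpenAI(
#         base_url="http://localhost:5000/gateway/mlflow/v1",  # assumed server URL + router prefix
#         api_key="not-used-unless-auth-is-enabled",
#     )
#     response = client.chat.completions.create(
#         model="my-endpoint-name",  # gateway endpoint name, not a provider model name
#         messages=[{"role": "user", "content": "Hello"}],
#     )
# ---------------------------------------------------------------------------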