# Define all the service endpoint handlers here.
import io
import json
import logging
import os
import pathlib
import posixpath
import re
import tempfile
import threading
import time
import urllib
from functools import partial, wraps
from typing import Any, Callable

import requests
from cachetools import TTLCache
from flask import Request, Response, current_app, jsonify, request, send_file
from google.protobuf import descriptor
from google.protobuf.json_format import ParseError

import mlflow
from mlflow.entities import (
    Assessment,
    DatasetInput,
    Expectation,
    ExperimentTag,
    FallbackConfig,
    FallbackStrategy,
    Feedback,
    FileInfo,
    GatewayEndpointModelConfig,
    GatewayEndpointTag,
    GatewayResourceType,
    InputTag,
    IssueSeverity,
    IssueStatus,
    Metric,
    Param,
    RunStatus,
    RunTag,
    ViewType,
    Workspace,
    WorkspaceDeletionMode,
)
from mlflow.entities import (
    RoutingStrategy as RoutingStrategyEntity,
)
from mlflow.entities.gateway_budget_policy import (
    BudgetAction,
    BudgetDuration,
    BudgetDurationUnit,
    BudgetTargetScope,
    BudgetUnit,
)
from mlflow.entities.logged_model import LoggedModel
from mlflow.entities.logged_model_input import LoggedModelInput
from mlflow.entities.logged_model_output import LoggedModelOutput
from mlflow.entities.logged_model_parameter import LoggedModelParameter
from mlflow.entities.logged_model_status import LoggedModelStatus
from mlflow.entities.logged_model_tag import LoggedModelTag
from mlflow.entities.model_registry import ModelVersionTag, RegisteredModelTag
from mlflow.entities.model_registry.prompt_version import IS_PROMPT_TAG_KEY
from mlflow.entities.multipart_upload import MultipartUploadPart
from mlflow.entities.trace_info import TraceInfo
from mlflow.entities.trace_info_v2 import TraceInfoV2
from mlflow.entities.trace_metrics import MetricAggregation, MetricViewType
from mlflow.entities.trace_status import TraceStatus
from mlflow.entities.webhook import WebhookAction, WebhookEntity, WebhookEvent, WebhookStatus
from mlflow.environment_variables import (
    MLFLOW_CREATE_MODEL_VERSION_SOURCE_VALIDATION_REGEX,
    MLFLOW_DEPLOYMENTS_TARGET,
    MLFLOW_ENABLE_WORKSPACES,
    MLFLOW_PRESIGNED_DOWNLOAD_URL_TTL_SECONDS,
)
from mlflow.exceptions import (
    MlflowException,
    MlflowNotImplementedException,
    MlflowTracingException,
    _UnsupportedMultipartDownloadException,
    _UnsupportedMultipartUploadException,
    _UnsupportedPresignedUploadException,
)
from mlflow.gateway.budget import maybe_refresh_budget_policies
from mlflow.gateway.budget_tracker import get_budget_tracker
from mlflow.gateway.utils import is_valid_endpoint_name
from mlflow.genai.scorers.scorer_utils import DECORATOR_SCORER_REGISTRATION_NOT_SUPPORTED_ERROR
from mlflow.models import Model
from mlflow.prompt.constants import PROMPT_TEXT_TAG_KEY, PROMPT_TYPE_TAG_KEY
from mlflow.protos import databricks_pb2
from mlflow.protos.databricks_pb2 import (
    BAD_REQUEST,
    FEATURE_DISABLED,
    INTERNAL_ERROR,
    INVALID_PARAMETER_VALUE,
    INVALID_STATE,
    RESOURCE_DOES_NOT_EXIST,
)
from mlflow.protos.issues_pb2 import (
    CreateIssue,
    GetIssue,
    SearchIssues,
    UpdateIssue,
)
from mlflow.protos.jobs_pb2 import JobStatus
from mlflow.protos.mlflow_artifacts_pb2 import (
    AbortMultipartUpload,
    CompleteMultipartUpload,
    CreateMultipartUpload,
    DeleteArtifact,
    DownloadArtifact,
    GetPresignedDownloadUrl,
    MlflowArtifactsService,
    UploadArtifact,
)
from mlflow.protos.mlflow_artifacts_pb2 import (
    ListArtifacts as ListArtifactsMlflowArtifacts,
)
from mlflow.protos.model_registry_pb2 import (
    CreateModelVersion,
    CreateRegisteredModel,
    DeleteModelVersion,
    DeleteModelVersionTag,
    DeleteRegisteredModel,
    DeleteRegisteredModelAlias,
    DeleteRegisteredModelTag,
    GetLatestVersions,
    GetModelVersion,
    GetModelVersionByAlias,
    GetModelVersionDownloadUri,
    GetRegisteredModel,
    ModelRegistryService,
    RenameRegisteredModel,
    SearchModelVersions,
    SearchRegisteredModels,
    SetModelVersionTag,
    SetRegisteredModelAlias,
    SetRegisteredModelTag,
    TransitionModelVersionStage,
    UpdateModelVersion,
    UpdateRegisteredModel,
)
from mlflow.protos.prompt_optimization_pb2 import (
    PromptOptimizationJob as PromptOptimizationJobProto,
)
from mlflow.protos.service_pb2 import (
    AddDatasetToExperiments,
    AddGuardrailToEndpoint,
    AttachModelToGatewayEndpoint,
    BatchGetTraceInfos,
    BatchGetTraces,
    CalculateTraceFilterCorrelation,
    CancelPromptOptimizationJob,
    CreateAssessment,
    CreateDataset,
    CreateExperiment,
    CreateGatewayBudgetPolicy,
    CreateGatewayEndpoint,
    CreateGatewayEndpointBinding,
    CreateGatewayGuardrail,
    CreateGatewayModelDefinition,
    CreateGatewaySecret,
    CreateLoggedModel,
    CreatePresignedUploadUrl,
    CreatePromptOptimizationJob,
    CreateRun,
    CreateWorkspace,
    DeleteAssessment,
    DeleteDataset,
    DeleteDatasetRecords,
    DeleteDatasetTag,
    DeleteExperiment,
    DeleteExperimentTag,
    DeleteGatewayBudgetPolicy,
    DeleteGatewayEndpoint,
    DeleteGatewayEndpointBinding,
    DeleteGatewayEndpointTag,
    DeleteGatewayGuardrail,
    DeleteGatewayModelDefinition,
    DeleteGatewaySecret,
    DeleteLoggedModel,
    DeleteLoggedModelTag,
    DeletePromptOptimizationJob,
    DeleteRun,
    DeleteScorer,
    DeleteTag,
    DeleteTraces,
    DeleteTracesV3,
    DeleteTraceTag,
    DeleteTraceTagV3,
    DeleteWorkspace,
    DetachModelFromGatewayEndpoint,
    EndTrace,
    FinalizeLoggedModel,
    GetAssessmentRequest,
    GetDataset,
    GetDatasetExperimentIds,
    GetDatasetRecords,
    GetExperiment,
    GetExperimentByName,
    GetGatewayBudgetPolicy,
    GetGatewayEndpoint,
    GetGatewayGuardrail,
    GetGatewayModelDefinition,
    GetGatewaySecretInfo,
    GetLoggedModel,
    GetMetricHistory,
    GetMetricHistoryBulkInterval,
    GetPromptOptimizationJob,
    GetRun,
    GetScorer,
    GetTrace,
    GetTraceInfo,
    GetTraceInfoV3,
    GetWorkspace,
    LinkPromptsToTrace,
    LinkTracesToRun,
    ListArtifacts,
    ListEndpointGuardrailConfigs,
    ListGatewayBudgetPolicies,
    ListGatewayBudgetWindows,
    ListGatewayEndpointBindings,
    ListGatewayEndpoints,
    ListGatewayGuardrails,
    ListGatewayModelDefinitions,
    ListGatewaySecretInfos,
    ListLoggedModelArtifacts,
    ListScorers,
    ListScorerVersions,
    ListWorkspaces,
    LogBatch,
    LogInputs,
    LogLoggedModelParamsRequest,
    LogMetric,
    LogModel,
    LogOutputs,
    LogParam,
    MlflowService,
    QueryTraceMetrics,
    RegisterScorer,
    RemoveDatasetFromExperiments,
    RemoveGuardrailFromEndpoint,
    RestoreExperiment,
    RestoreRun,
    SearchDatasets,
    SearchEvaluationDatasets,
    SearchExperiments,
    SearchLoggedModels,
    SearchPromptOptimizationJobs,
    SearchRuns,
    SearchTraces,
    SearchTracesV3,
    SetDatasetTags,
    SetExperimentTag,
    SetGatewayEndpointTag,
    SetLoggedModelTags,
    SetTag,
    SetTraceTag,
    SetTraceTagV3,
    StartTrace,
    StartTraceV3,
    UpdateAssessment,
    UpdateEndpointGuardrailConfig,
    UpdateExperiment,
    UpdateGatewayBudgetPolicy,
    UpdateGatewayEndpoint,
    UpdateGatewayModelDefinition,
    UpdateGatewaySecret,
    UpdateRun,
    UpdateWorkspace,
    UpsertDatasetRecords,
)
from mlflow.protos.service_pb2 import Trace as ProtoTrace
from mlflow.protos.webhooks_pb2 import (
    CreateWebhook,
    DeleteWebhook,
    GetWebhook,
    ListWebhooks,
    TestWebhook,
    UpdateWebhook,
    WebhookService,
)
from mlflow.server.validation import _validate_content_type
from mlflow.server.workspace_helpers import (
    _get_workspace_store,
)
from mlflow.store.artifact.artifact_repo import (
    MultipartDownloadMixin,
    MultipartUploadMixin,
    PresignedUploadMixin,
)
from mlflow.store.artifact.artifact_repository_registry import get_artifact_repository
from mlflow.store.db.db_types import DATABASE_ENGINES
from mlflow.store.jobs.abstract_store import AbstractJobStore
from mlflow.store.model_registry.abstract_store import AbstractStore as AbstractModelRegistryStore
from mlflow.store.model_registry.rest_store import RestStore as ModelRegistryRestStore
from mlflow.store.tracking import MAX_RESULTS_QUERY_TRACE_METRICS, SEARCH_MAX_RESULTS_DEFAULT
from mlflow.store.tracking.abstract_store import AbstractStore as AbstractTrackingStore
from mlflow.store.tracking.databricks_rest_store import DatabricksTracingRestStore
from mlflow.store.workspace.abstract_store import WorkspaceNameValidator
from mlflow.telemetry import get_telemetry_client
from mlflow.telemetry.installation_id import get_or_create_installation_id
from mlflow.telemetry.schemas import Record, Status
from mlflow.telemetry.utils import (
    FALLBACK_UI_CONFIG,
    fetch_ui_telemetry_config,
    is_telemetry_disabled,
)
from mlflow.tracing.utils.artifact_utils import (
    TRACE_DATA_FILE_NAME,
    get_artifact_uri_for_trace,
)
from mlflow.tracking._model_registry import utils as registry_utils
from mlflow.tracking._model_registry.registry import ModelRegistryStoreRegistry
from mlflow.tracking._tracking_service import utils
from mlflow.tracking._tracking_service.registry import TrackingStoreRegistry
from mlflow.tracking.context.default_context import _get_user
from mlflow.tracking.registry import UnsupportedModelRegistryStoreURIException
from mlflow.utils import workspace_context
from mlflow.utils.crypto import CRYPTO_KEK_PASSPHRASE_ENV_VAR
from mlflow.utils.databricks_utils import get_databricks_host_creds
from mlflow.utils.file_utils import local_file_uri_to_path
from mlflow.utils.mime_type_utils import _guess_mime_type
from mlflow.utils.mlflow_tags import (
    MLFLOW_ISSUE_DETECTION_JOB_ID,
    MLFLOW_RUN_TYPE,
    MLFLOW_RUN_TYPE_ISSUE_DETECTION,
)
from mlflow.utils.promptlab_utils import _create_promptlab_run_impl
from mlflow.utils.proto_json_utils import message_to_json, parse_dict
from mlflow.utils.providers import (
    get_all_providers,
    get_models,
    get_provider_config_response,
)
from mlflow.utils.string_utils import is_string_type
from mlflow.utils.time import get_current_time_millis
from mlflow.utils.uri import is_local_uri, validate_path_is_safe, validate_query_string
from mlflow.utils.validation import (
    _validate_batch_log_api_req,
    _validate_experiment_artifact_location,
    _validate_experiment_artifact_location_length,
    invalid_value,
    missing_value,
)
from mlflow.utils.workspace_utils import DEFAULT_WORKSPACE_NAME
from mlflow.webhooks.delivery import deliver_webhook, test_webhook
from mlflow.webhooks.types import (
    ModelVersionAliasCreatedPayload,
    ModelVersionAliasDeletedPayload,
    ModelVersionCreatedPayload,
    ModelVersionTagDeletedPayload,
    ModelVersionTagSetPayload,
    PromptAliasCreatedPayload,
    PromptAliasDeletedPayload,
    PromptCreatedPayload,
    PromptTagDeletedPayload,
    PromptTagSetPayload,
    PromptVersionCreatedPayload,
    PromptVersionTagDeletedPayload,
    PromptVersionTagSetPayload,
    RegisteredModelCreatedPayload,
)

_logger = logging.getLogger(__name__)

# Lazily-initialized singleton stores; populated on first access by the
# corresponding `_get_*_store` accessors below.
_tracking_store = None
_model_registry_store = None
_job_store = None
_artifact_repo = None

STATIC_PREFIX_ENV_VAR = "_MLFLOW_STATIC_PREFIX"
MAX_RUNS_GET_METRIC_HISTORY_BULK = 100
MAX_RESULTS_PER_RUN = 2500
# Chunk size for streaming artifact uploads and downloads (1 MB)
ARTIFACT_STREAM_CHUNK_SIZE = 1024 * 1024


class TrackingStoreRegistryWrapper(TrackingStoreRegistry):
    """Tracking store registry preconfigured with the builtin store builders."""

    def __init__(self):
        super().__init__()
        self.register("", self._get_file_store)
        self.register("file", self._get_file_store)
        for scheme in DATABASE_ENGINES:
            self.register(scheme, self._get_sqlalchemy_store)
        # Add support for Databricks tracking store
        self.register("databricks", self._get_databricks_rest_store)
        self.register_entrypoints()

    @classmethod
    def _get_file_store(cls, store_uri, artifact_uri):
        from mlflow.store.tracking.file_store import FileStore

        return FileStore(store_uri, artifact_uri)

    @classmethod
    def _get_sqlalchemy_store(cls, store_uri, artifact_uri):
        from mlflow.store.tracking.sqlalchemy_store import SqlAlchemyStore
        from mlflow.store.tracking.sqlalchemy_workspace_store import (
            WorkspaceAwareSqlAlchemyStore,
        )

        # Workspace-aware variant is substituted when the server runs with
        # --enable-workspaces.
        store_cls = (
            WorkspaceAwareSqlAlchemyStore if MLFLOW_ENABLE_WORKSPACES.get() else SqlAlchemyStore
        )
        return store_cls(store_uri, artifact_uri)

    @classmethod
    def _get_databricks_rest_store(cls, store_uri, artifact_uri):
        return DatabricksTracingRestStore(partial(get_databricks_host_creds, store_uri))


class ModelRegistryStoreRegistryWrapper(ModelRegistryStoreRegistry):
    """Model registry store registry preconfigured with the builtin store builders."""

    def __init__(self):
        super().__init__()
        self.register("", self._get_file_store)
        self.register("file", self._get_file_store)
        for scheme in DATABASE_ENGINES:
            self.register(scheme, self._get_sqlalchemy_store)
        # Add support for Databricks registries
        self.register("databricks", self._get_databricks_rest_store)
        self.register("databricks-uc", self._get_databricks_uc_rest_store)
        self.register_entrypoints()

    @classmethod
    def _get_file_store(cls, store_uri):
        from mlflow.store.model_registry.file_store import FileStore

        return FileStore(store_uri)

    @classmethod
    def _get_sqlalchemy_store(cls, store_uri):
        from mlflow.store.model_registry.sqlalchemy_store import SqlAlchemyStore
        from mlflow.store.model_registry.sqlalchemy_workspace_store import (
            WorkspaceAwareSqlAlchemyStore,
        )

        store_cls = (
            WorkspaceAwareSqlAlchemyStore if MLFLOW_ENABLE_WORKSPACES.get() else SqlAlchemyStore
        )
        return store_cls(store_uri)

    @classmethod
    def _get_databricks_rest_store(cls, store_uri):
        return ModelRegistryRestStore(partial(get_databricks_host_creds, store_uri))

    @classmethod
    def _get_databricks_uc_rest_store(cls, store_uri):
        from mlflow.environment_variables import MLFLOW_TRACKING_URI
        from mlflow.store._unity_catalog.registry.rest_store import UcModelRegistryStore

        # Get tracking URI from environment or use "databricks-uc" as default
        tracking_uri = MLFLOW_TRACKING_URI.get() or "databricks-uc"
        return UcModelRegistryStore(store_uri, tracking_uri)


_tracking_store_registry = TrackingStoreRegistryWrapper()
_model_registry_store_registry = ModelRegistryStoreRegistryWrapper()


def _get_artifact_repo_mlflow_artifacts():
    """
    Get an artifact repository specified by ``--artifacts-destination`` option for ``mlflow server``
    command.
    """
    from mlflow.server import ARTIFACTS_DESTINATION_ENV_VAR

    global _artifact_repo
    if _artifact_repo is None:
        _artifact_repo = get_artifact_repository(os.environ[ARTIFACTS_DESTINATION_ENV_VAR])
    return _artifact_repo


def _get_trace_artifact_repo(trace_info: TraceInfo):
    """
    Resolve the artifact repository for fetching data for the given trace.

    Args:
        trace_info: The trace info object containing metadata about the trace.
    """
    artifact_uri = get_artifact_uri_for_trace(trace_info)

    if _is_servable_proxied_run_artifact_root(artifact_uri):
        # If the artifact location is a proxied run artifact root (e.g. mlflow-artifacts://...),
        # we need to resolve it to the actual artifact location.
        from mlflow.server import ARTIFACTS_DESTINATION_ENV_VAR

        path = _get_proxied_run_artifact_destination_path(artifact_uri)
        if not path:
            raise MlflowException(
                f"Failed to resolve the proxied run artifact URI: {artifact_uri}. "
                "Trace artifact URI must contain subpath to the trace data directory.",
                error_code=BAD_REQUEST,
            )
        root = os.environ[ARTIFACTS_DESTINATION_ENV_VAR]
        artifact_uri = posixpath.join(root, path)

        # We don't set it to global var unlike run artifact, because the artifact repo has
        # to be created with full trace artifact URI including request_id.
        # e.g. s3://<experiment_id>/traces/<request_id>
        artifact_repo = get_artifact_repository(artifact_uri)
    else:
        artifact_repo = get_artifact_repository(artifact_uri)
    return artifact_repo


def _is_serving_proxied_artifacts():
    """
    Returns:
        True if the MLflow server is serving proxied artifacts (i.e. acting as a proxy for
        artifact upload / download / list operations), as would be enabled by specifying the
        --serve-artifacts configuration option. False otherwise.
    """
    from mlflow.server import SERVE_ARTIFACTS_ENV_VAR

    return os.environ.get(SERVE_ARTIFACTS_ENV_VAR, "false") == "true"


def _is_servable_proxied_run_artifact_root(run_artifact_root):
    """
    Determines whether or not the following are true:

    - The specified Run artifact root is a proxied artifact root (i.e. an artifact root with scheme
      ``http``, ``https``, or ``mlflow-artifacts``).

    - The MLflow server is capable of resolving and accessing the underlying storage location
      corresponding to the proxied artifact root, allowing it to fulfill artifact list and
      download requests by using this storage location directly.

    Args:
        run_artifact_root: The Run artifact root location (URI).

    Returns:
        True if the specified Run artifact root refers to proxied artifacts that can be
        served by this MLflow server (i.e. the server has access to the destination and
        can respond to list and download requests for the artifact). False otherwise.
    """
    parsed_run_artifact_root = urllib.parse.urlparse(run_artifact_root)
    # NB: If the run artifact root is a proxied artifact root (has scheme `http`, `https`, or
    # `mlflow-artifacts`) *and* the MLflow server is configured to serve artifacts, the MLflow
    # server always assumes that it has access to the underlying storage location for the proxied
    # artifacts. This may not always be accurate. For example:
    #
    # An organization may initially use the MLflow server to serve Tracking API requests and proxy
    # access to artifacts stored in Location A (via `mlflow server --serve-artifacts`). Then, for
    # scalability and / or security purposes, the organization may decide to store artifacts in a
    # new location B and set up a separate server (e.g. `mlflow server --artifacts-only`) to proxy
    # access to artifacts stored in Location B.
    #
    # In this scenario, requests for artifacts stored in Location B that are sent to the original
    # MLflow server will fail if the original MLflow server does not have access to Location B
    # because it will assume that it can serve all proxied artifacts regardless of the underlying
    # location. Such failures can be remediated by granting the original MLflow server access to
    # Location B.
    return (
        parsed_run_artifact_root.scheme in ["http", "https", "mlflow-artifacts"]
        and _is_serving_proxied_artifacts()
    )


def _get_proxied_run_artifact_destination_path(proxied_artifact_root, relative_path=None):
    """
    Resolves the specified proxied artifact location within a Run to a concrete storage location.

    Args:
        proxied_artifact_root: The Run artifact root location (URI) with scheme ``http``,
            ``https``, or `mlflow-artifacts` that can be resolved by the MLflow server to a
            concrete storage location.
        relative_path: The relative path of the destination within the specified
            ``proxied_artifact_root``. If ``None``, the destination is assumed to be
            the resolved ``proxied_artifact_root``.

    Returns:
        The storage location of the specified artifact.
    """
    parsed_proxied_artifact_root = urllib.parse.urlparse(proxied_artifact_root)
    assert parsed_proxied_artifact_root.scheme in ["http", "https", "mlflow-artifacts"]

    if parsed_proxied_artifact_root.scheme == "mlflow-artifacts":
        # If the proxied artifact root is an `mlflow-artifacts` URI, the run artifact root path is
        # simply the path component of the URI, since the fully-qualified format of an
        # `mlflow-artifacts` URI is `mlflow-artifacts://<netloc>/path/to/artifact`
        proxied_run_artifact_root_path = parsed_proxied_artifact_root.path.lstrip("/")
    else:
        # In this case, the proxied artifact root is an HTTP(S) URL referring to an mlflow-artifacts
        # API route that can be used to download the artifact. These routes are always anchored at
        # `/api/2.0/mlflow-artifacts/artifacts`. Accordingly, we split the path on this route anchor
        # and interpret the rest of the path (everything after the route anchor) as the run artifact
        # root path
        mlflow_artifacts_http_route_anchor = "/api/2.0/mlflow-artifacts/artifacts/"
        assert mlflow_artifacts_http_route_anchor in parsed_proxied_artifact_root.path

        proxied_run_artifact_root_path = parsed_proxied_artifact_root.path.split(
            mlflow_artifacts_http_route_anchor
        )[1].lstrip("/")

    return (
        posixpath.join(proxied_run_artifact_root_path, relative_path)
        if relative_path is not None
        else proxied_run_artifact_root_path
    )


def _get_tracking_store(
    backend_store_uri: str | None = None,
    default_artifact_root: str | None = None,
) -> AbstractTrackingStore:
    """Return (creating on first call) the singleton tracking store."""
    from mlflow.server import ARTIFACT_ROOT_ENV_VAR, BACKEND_STORE_URI_ENV_VAR

    global _tracking_store
    if _tracking_store is None:
        store_uri = backend_store_uri or os.environ.get(BACKEND_STORE_URI_ENV_VAR, None)
        artifact_root = default_artifact_root or os.environ.get(ARTIFACT_ROOT_ENV_VAR, None)
        _tracking_store = _tracking_store_registry.get_store(store_uri, artifact_root)
        utils.set_tracking_uri(store_uri)
    return _tracking_store


def _get_model_registry_store(registry_store_uri: str | None = None) -> AbstractModelRegistryStore:
    """Return (creating on first call) the singleton model registry store."""
    from mlflow.server import BACKEND_STORE_URI_ENV_VAR, REGISTRY_STORE_URI_ENV_VAR

    global _model_registry_store
    if _model_registry_store is None:
        store_uri = (
            registry_store_uri
            or os.environ.get(REGISTRY_STORE_URI_ENV_VAR, None)
            or os.environ.get(BACKEND_STORE_URI_ENV_VAR, None)
        )
        _model_registry_store = _model_registry_store_registry.get_store(store_uri)
        registry_utils.set_registry_uri(store_uri)
    return _model_registry_store


def _get_job_store(backend_store_uri: str | None = None) -> AbstractJobStore:
    """
    Get a job store instance based on the backend store URI.

    Args:
        backend_store_uri: Optional backend store URI. If not provided,
            uses environment variable.

    Returns:
        An instance of AbstractJobStore
    """
    from mlflow.server import BACKEND_STORE_URI_ENV_VAR
    from mlflow.store.jobs.sqlalchemy_store import SqlAlchemyJobStore
    from mlflow.store.jobs.sqlalchemy_workspace_store import WorkspaceAwareSqlAlchemyJobStore
    from mlflow.utils.uri import extract_db_type_from_uri

    global _job_store
    if _job_store is None:
        store_uri = backend_store_uri or os.environ.get(BACKEND_STORE_URI_ENV_VAR, None)
        if not store_uri:
            raise MlflowException.invalid_parameter_value("Job store requires a backend store URI")
        try:
            extract_db_type_from_uri(store_uri)
        except (MlflowException, ValueError):
            # Require a database backend URI for the job store
            # Raise MlflowException so the CLI/REST layer returns a structured 400
            # instead of surfacing a generic 500 from ValueError
            raise MlflowException.invalid_parameter_value("Job store requires a backend store URI")

        store_cls = (
            WorkspaceAwareSqlAlchemyJobStore
            if MLFLOW_ENABLE_WORKSPACES.get()
            else SqlAlchemyJobStore
        )
        _job_store = store_cls(store_uri)

        if MLFLOW_ENABLE_WORKSPACES.get():
            _verify_job_store_workspace_support(_job_store)

    return _job_store


def initialize_backend_stores(
    backend_store_uri: str | None = None,
    registry_store_uri: str | None = None,
    default_artifact_root: str | None = None,
    workspace_store_uri: str | None = None,
) -> None:
    """Eagerly construct the tracking / registry (and workspace) stores at server startup."""
    tracking_store = _get_tracking_store(backend_store_uri, default_artifact_root)
    registry_store = None
    try:
        registry_store = _get_model_registry_store(registry_store_uri)
    except UnsupportedModelRegistryStoreURIException:
        pass

    if MLFLOW_ENABLE_WORKSPACES.get():
        # Initialize the workspace store to verify it's correctly configured
        _get_workspace_store(
            workspace_uri=workspace_store_uri,
            tracking_uri=backend_store_uri,
        )
        _verify_tracking_store_workspace_support(tracking_store)
        _verify_model_registry_store_workspace_support(registry_store)


def _store_supports_workspaces(
    store: AbstractTrackingStore | AbstractModelRegistryStore | AbstractJobStore,
) -> bool:
    """Return whether the provided store reports workspace support."""
    return bool(getattr(store, "supports_workspaces", False))


def _verify_tracking_store_workspace_support(tracking_store: AbstractTrackingStore) -> None:
    if not _store_supports_workspaces(tracking_store):
        raise MlflowException(
            "The configured tracking store does not support workspace-aware operations. "
            "Remove the --enable-workspaces flag or configure a workspace-capable backend store.",
            error_code=INVALID_STATE,
        )


def _verify_model_registry_store_workspace_support(
    registry_store: AbstractModelRegistryStore,
) -> None:
    if registry_store is None:
        return

    if not _store_supports_workspaces(registry_store):
        raise MlflowException(
            "The configured model registry store does not support workspace-aware operations. "
            "Remove the --enable-workspaces flag or configure a workspace-capable backend store.",
            error_code=INVALID_STATE,
        )


def _verify_job_store_workspace_support(job_store: AbstractJobStore) -> None:
    if not _store_supports_workspaces(job_store):
        raise MlflowException(
            "The configured job store does not support workspace-aware operations. "
            "Remove the --enable-workspaces flag or configure a workspace-capable backend store.",
            error_code=INVALID_STATE,
        )


# The `_assert_*` helpers below are schema validators used by
# `_validate_param_against_schema`; each raises AssertionError on failure.


def _assert_string(x):
    assert isinstance(x, str)


def _assert_intlike(x):
    try:
        x = int(x)
    except ValueError:
        pass

    assert isinstance(x, int)


def _assert_bool(x):
    assert isinstance(x, bool)


def _assert_floatlike(x):
    try:
        x = float(x)
    except ValueError:
        pass

    assert isinstance(x, float)


def _assert_array(x):
    assert isinstance(x, list)


def _assert_map_key_present(x):
    _assert_array(x)
    for entry in x:
        _assert_required(entry.get("key"))


def _assert_required(x, path=None):
    if path is None:
        assert x is not None
        # When parsing JSON payloads via proto, absent string fields
        # are expressed as empty strings
        assert x != ""
    else:
        assert x is not None, missing_value(path)
        assert x != "", missing_value(path)


def _assert_less_than_or_equal(x, max_value, message=None):
    if x > max_value:
        raise AssertionError(message) if message else AssertionError()


def _assert_intlike_within_range(x, min_value, max_value, message=None):
    if not min_value <= x <= max_value:
        raise AssertionError(message) if message else AssertionError()


def _assert_item_type_string(x):
    assert all(isinstance(item, str) for item in x)


def _assert_secret_value(x):
    """Validate secret_value is present. Does not print values in errors."""
    if x is None:
        raise MlflowException(
            message="Missing value for required parameter 'secret_value'.",
            error_code=INVALID_PARAMETER_VALUE,
        )


# Validators in this set only check types; they are skipped when protobuf
# parsing already succeeded (see `_validate_param_against_schema`).
_TYPE_VALIDATORS = {
    _assert_intlike,
    _assert_string,
    _assert_bool,
    _assert_floatlike,
    _assert_array,
    _assert_item_type_string,
}


def _validate_param_against_schema(schema, param, value, proto_parsing_succeeded=False):
    """
    Attempts to validate a single parameter against a specified schema. Examples of the elements of
    the schema are type assertions and checks for required parameters. Returns None on validation
    success. Otherwise, raises an MLFlowException if an assertion fails. This method is intended
    to be called for side effects.

    Args:
        schema: A list of functions to validate the parameter against.
        param: The string name of the parameter being validated.
        value: The corresponding value of the `param` being validated.
        proto_parsing_succeeded: A boolean value indicating whether proto parsing succeeded.
            If the proto was successfully parsed, we assume all of the types of the parameters in
            the request body were correctly specified, and thus we skip validating types. If proto
            parsing failed, then we validate types in addition to the rest of the schema. For
            details, see https://github.com/mlflow/mlflow/pull/5458#issuecomment-1080880870.
    """

    for f in schema:
        if f in _TYPE_VALIDATORS and proto_parsing_succeeded:
            continue

        try:
            f(value)
        except AssertionError as e:
            if e.args:
                message = e.args[0]
            elif f == _assert_required:
                message = f"Missing value for required parameter '{param}'."
            else:
                message = invalid_value(
                    param, value, f" Hint: Value was of type '{type(value).__name__}'."
                )
            raise MlflowException(
                message=(
                    message + " See the API docs for more information about request parameters."
                ),
                error_code=INVALID_PARAMETER_VALUE,
            )

    return None


def _get_request_json(flask_request=request):
    """Return the request body parsed as JSON after validating the content type."""
    _validate_content_type(flask_request, ["application/json"])
    return flask_request.get_json(force=True, silent=True)


def _get_normalized_request_json(flask_request: Request = request) -> dict[str, Any]:
    """
    Get request JSON with normalization for legacy clients.

    Handles double-encoded JSON strings from older clients and empty request bodies.

    Args:
        flask_request: The Flask request object.

    Returns:
        The request data as a dictionary (empty dict if no body).
    """
    request_json = _get_request_json(flask_request)

    # Older clients may post their JSON double-encoded as strings, so the get_json
    # above actually converts it to a string. Therefore, we check this condition
    # (which we can tell for sure because any proper request should be a dictionary),
    # and decode it a second time.
    if is_string_type(request_json):
        request_json = json.loads(request_json)

    # If request doesn't have json body then assume it's empty.
    if request_json is None:
        request_json = {}

    return request_json


def _validate_request_json_with_schema(
    request_json: dict[str, Any],
    schema: dict[str, list[Callable[..., Any]]] | None,
    proto_parsing_succeeded: bool | None,
) -> None:
    """
    Validate request JSON against a schema without requiring protobuf messages.

    Args:
        request_json: The request data as a dictionary.
        schema: Dictionary mapping parameter names to lists of validation functions.
        proto_parsing_succeeded: Whether protobuf parsing succeeded. None indicates the
            request was not parsed from protobuf.
    """
    schema = schema or {}
    for schema_key, schema_validation_fns in schema.items():
        if schema_key in request_json or _assert_required in schema_validation_fns:
            value = request_json.get(schema_key)
            # Legacy clients may send `run_uuid` instead of `run_id`.
            if schema_key == "run_id" and value is None and "run_uuid" in request_json:
                value = request_json.get("run_uuid")
            _validate_param_against_schema(
                schema=schema_validation_fns,
                param=schema_key,
                value=value,
                proto_parsing_succeeded=proto_parsing_succeeded,
            )


def _get_request_message(request_message, flask_request=request, schema=None):
    """Populate `request_message` from the Flask request (query args or JSON body) and
    validate the resulting payload against `schema`."""
    if flask_request.method == "GET" and flask_request.args:
        # Convert atomic values of repeated fields to lists before calling protobuf
        # deserialization.
        # Context: We parse the parameter string into a dictionary outside of protobuf since
        # protobuf does not know how to read the query parameters directly. The query parser above
        # has no type information and hence any parameter that occurs exactly once is parsed as an
        # atomic value. Since protobuf requires that the values of repeated fields are lists,
        # deserialization will fail unless we do the fix below.
        request_json = {}
        for field in request_message.DESCRIPTOR.fields:
            if field.name not in flask_request.args:
                continue

            # Use is_repeated property (preferred) with fallback to deprecated label
            try:
                is_repeated = field.is_repeated
            except AttributeError:
                is_repeated = field.label == descriptor.FieldDescriptor.LABEL_REPEATED

            if is_repeated:
                request_json[field.name] = flask_request.args.getlist(field.name)
            else:
                value = flask_request.args.get(field.name)
                if field.type == descriptor.FieldDescriptor.TYPE_BOOL and isinstance(value, str):
                    if value.lower() not in ["true", "false"]:
                        raise MlflowException.invalid_parameter_value(
                            f"Invalid boolean value: {value}, must be 'true' or 'false'.",
                        )
                    value = value.lower() == "true"
                request_json[field.name] = value
    else:
        request_json = _get_normalized_request_json(flask_request)

    proto_parsing_succeeded = True
    try:
        parse_dict(request_json, request_message)
    except ParseError:
        proto_parsing_succeeded = False

    _validate_request_json_with_schema(request_json, schema, proto_parsing_succeeded)

    return request_message


def _get_validated_flask_request_json(
    flask_request: Request = request,
    schema: dict[str, list[Callable[..., Any]]] | None = None,
) -> dict[str, Any]:
    """
    Get and validate request data without protobuf parsing.

    This is an alternative to _get_request_message for endpoints that don't
    use protobuf message definitions. Supports both GET and POST/PUT requests.

    Args:
        flask_request: The Flask request object.
        schema: Dictionary mapping parameter names to lists of validation functions.

    Returns:
        The validated request data as a dictionary.

    Raises:
        MlflowException: If validation fails.
    """
    if flask_request.method == "GET" and flask_request.args:
        # Extract query parameters for GET requests
        request_json = {}
        schema = schema or {}
        for key in flask_request.args:
            # Get all values for this key (supports repeated parameters)
            values = flask_request.args.getlist(key)
            # Check if this field is a list type by looking for _assert_array validator
            is_list_type = _assert_array in schema.get(key, [])
            # If list type, always keep as list; otherwise use scalar if only one value
            request_json[key] = (
                values if is_list_type else (values[0] if len(values) == 1 else values)
            )
    else:
        request_json = _get_normalized_request_json(flask_request)

    _validate_request_json_with_schema(request_json, schema, proto_parsing_succeeded=None)

    return request_json


def _response_with_file_attachment_headers(file_path, response):
    """Set attachment headers on `response` so browsers download rather than render
    the artifact at `file_path` (mitigates XSS via served artifacts)."""
    mime_type = _guess_mime_type(file_path)
    filename = pathlib.Path(file_path).name
    response.mimetype = mime_type
    content_disposition_header_name = "Content-Disposition"
    if content_disposition_header_name not in response.headers:
        response.headers[content_disposition_header_name] = f"attachment; filename={filename}"
    response.headers["X-Content-Type-Options"] = "nosniff"
    response.headers["Content-Type"] = mime_type
    return response


def _send_artifact(artifact_repository, path):
    file_path = os.path.abspath(artifact_repository.download_artifacts(path))
    # Always send artifacts as attachments to prevent the browser from displaying them on our web
    # server's domain, which might enable XSS.
1021 mime_type = _guess_mime_type(file_path) 1022 file_sender_response = send_file(file_path, mimetype=mime_type, as_attachment=True) 1023 return _response_with_file_attachment_headers(file_path, file_sender_response) 1024 1025 1026 def catch_mlflow_exception(func): 1027 @wraps(func) 1028 def wrapper(*args, **kwargs): 1029 try: 1030 return func(*args, **kwargs) 1031 except MlflowException as e: 1032 response = Response(mimetype="application/json") 1033 response.set_data(e.serialize_as_json()) 1034 response.status_code = e.get_http_status_code() 1035 if response.status_code >= 500: 1036 is_debug = _logger.isEnabledFor(logging.DEBUG) 1037 msg = f"Error in {func.__name__}: {e}" 1038 if not is_debug: 1039 msg += ". Set MLFLOW_LOGGING_LEVEL=DEBUG for traceback." 1040 _logger.error(msg, exc_info=is_debug) 1041 return response 1042 1043 return wrapper 1044 1045 1046 def _disable_unless_serve_artifacts(func): 1047 @wraps(func) 1048 def wrapper(*args, **kwargs): 1049 if not _is_serving_proxied_artifacts(): 1050 return Response( 1051 ( 1052 f"Endpoint: {request.url_rule} disabled due to the mlflow server running " 1053 "with `--no-serve-artifacts`. To enable artifacts server functionality, " 1054 "run `mlflow server` with `--serve-artifacts`" 1055 ), 1056 503, 1057 ) 1058 return func(*args, **kwargs) 1059 1060 return wrapper 1061 1062 1063 def _disable_if_artifacts_only(func): 1064 @wraps(func) 1065 def wrapper(*args, **kwargs): 1066 from mlflow.server import ARTIFACTS_ONLY_ENV_VAR 1067 1068 if os.environ.get(ARTIFACTS_ONLY_ENV_VAR): 1069 return Response( 1070 ( 1071 f"Endpoint: {request.url_rule} disabled due to the mlflow server running " 1072 "in `--artifacts-only` mode. 
To enable tracking server functionality, run " 1073 "`mlflow server` without `--artifacts-only`" 1074 ), 1075 503, 1076 ) 1077 return func(*args, **kwargs) 1078 1079 return wrapper 1080 1081 1082 def _disable_if_workspaces_disabled(func): 1083 @wraps(func) 1084 def wrapper(*args, **kwargs): 1085 if not MLFLOW_ENABLE_WORKSPACES.get(): 1086 return Response( 1087 ( 1088 f"Endpoint: {request.url_rule} disabled because the server is running " 1089 "without workspaces support. To enable workspace, run " 1090 "`mlflow server` with `--enable-workspaces`" 1091 ), 1092 503, 1093 ) 1094 return func(*args, **kwargs) 1095 1096 return wrapper 1097 1098 1099 def _workspace_not_supported(message: str) -> MlflowException: 1100 return MlflowException(message, FEATURE_DISABLED) 1101 1102 1103 def _validate_artifact_root_uri(value: str, field_name: str) -> str: 1104 parsed = urllib.parse.urlparse(value) 1105 if parsed.fragment or parsed.params: 1106 raise MlflowException.invalid_parameter_value( 1107 f"'{field_name}' URL can't include fragments or params." 1108 ) 1109 1110 validate_query_string(parsed.query) 1111 _validate_experiment_artifact_location(value) 1112 _validate_experiment_artifact_location_length(value) 1113 return value 1114 1115 1116 def _validate_workspace_default_artifact_root(value: str | None) -> str | None: 1117 if value is None: 1118 return None 1119 1120 trimmed = value.strip() 1121 if not trimmed: 1122 return "" 1123 1124 return _validate_artifact_root_uri(trimmed, "default_artifact_root") 1125 1126 1127 def _ensure_artifact_root_available(workspace_artifact_root: str | None) -> None: 1128 """Ensure an artifact root is available either at workspace or server level. 1129 1130 Args: 1131 workspace_artifact_root: The workspace's default_artifact_root value. 
1132 - None means "not specified" (fallback to server default) 1133 - "" means "clear/unset" (fallback to server default) 1134 - non-empty string means "use this workspace-specific root" 1135 1136 Raises: 1137 MlflowException: If neither workspace nor server has an artifact root configured. 1138 """ 1139 # If workspace has a non-empty artifact root, it's valid 1140 if workspace_artifact_root: 1141 return 1142 1143 # Otherwise, check if server has a default artifact root 1144 server_artifact_root = _get_tracking_store().artifact_root_uri 1145 if not server_artifact_root: 1146 raise MlflowException.invalid_parameter_value( 1147 "Cannot create or update workspace without an artifact root. Either specify " 1148 "'default_artifact_root' for this workspace or start the server with " 1149 "'--default-artifact-root'." 1150 ) 1151 1152 1153 @catch_mlflow_exception 1154 @_disable_if_workspaces_disabled 1155 def _list_workspaces_handler(): 1156 _get_request_message(ListWorkspaces()) 1157 workspaces = _get_workspace_store().list_workspaces() 1158 response_message = ListWorkspaces.Response() 1159 response_message.workspaces.extend([ws.to_proto() for ws in workspaces]) 1160 return _wrap_response(response_message) 1161 1162 1163 @catch_mlflow_exception 1164 @_disable_if_workspaces_disabled 1165 def _create_workspace_handler(): 1166 request_message = _get_request_message( 1167 CreateWorkspace(), 1168 schema={ 1169 "name": [_assert_required, _assert_string], 1170 "description": [_assert_string], 1171 "default_artifact_root": [_assert_string], 1172 }, 1173 ) 1174 1175 if request_message.name == DEFAULT_WORKSPACE_NAME: 1176 raise MlflowException.invalid_parameter_value( 1177 f"The '{DEFAULT_WORKSPACE_NAME}' workspace is reserved and cannot be created" 1178 ) 1179 WorkspaceNameValidator.validate(request_message.name) 1180 description = request_message.description if request_message.HasField("description") else None 1181 default_artifact_root = ( 1182 
request_message.default_artifact_root 1183 if request_message.HasField("default_artifact_root") 1184 else None 1185 ) 1186 default_artifact_root = _validate_workspace_default_artifact_root(default_artifact_root) 1187 _ensure_artifact_root_available(default_artifact_root) 1188 store = _get_workspace_store() 1189 try: 1190 workspace = store.create_workspace( 1191 Workspace( 1192 name=request_message.name, 1193 description=description, 1194 default_artifact_root=default_artifact_root, 1195 ) 1196 ) 1197 except NotImplementedError: 1198 raise _workspace_not_supported("Workspace creation is not supported by this provider") 1199 1200 response_message = CreateWorkspace.Response() 1201 response_message.workspace.MergeFrom(workspace.to_proto()) 1202 response = _wrap_response(response_message) 1203 response.status_code = 201 1204 return response 1205 1206 1207 @catch_mlflow_exception 1208 @_disable_if_workspaces_disabled 1209 def _get_workspace_handler(workspace_name: str): 1210 if workspace_name != DEFAULT_WORKSPACE_NAME: 1211 WorkspaceNameValidator.validate(workspace_name) 1212 workspace = _get_workspace_store().get_workspace(workspace_name) 1213 response_message = GetWorkspace.Response() 1214 response_message.workspace.MergeFrom(workspace.to_proto()) 1215 return _wrap_response(response_message) 1216 1217 1218 @catch_mlflow_exception 1219 @_disable_if_workspaces_disabled 1220 def _update_workspace_handler(workspace_name: str): 1221 if workspace_name != DEFAULT_WORKSPACE_NAME: 1222 WorkspaceNameValidator.validate(workspace_name) 1223 request_message = _get_request_message( 1224 UpdateWorkspace(), 1225 schema={ 1226 "description": [_assert_string], 1227 "default_artifact_root": [_assert_string], 1228 }, 1229 ) 1230 1231 has_description = request_message.HasField("description") 1232 has_artifact_root = request_message.HasField("default_artifact_root") 1233 1234 if not has_description and not has_artifact_root: 1235 raise MlflowException.invalid_parameter_value("Workspace 
update must have at least one key") 1236 1237 description = request_message.description if has_description else None 1238 default_artifact_root = request_message.default_artifact_root if has_artifact_root else None 1239 default_artifact_root = _validate_workspace_default_artifact_root(default_artifact_root) 1240 1241 # If the user is clearing the workspace artifact root (empty string), ensure the server 1242 # has a default artifact root configured 1243 if default_artifact_root == "": 1244 _ensure_artifact_root_available(default_artifact_root) 1245 1246 store = _get_workspace_store() 1247 try: 1248 workspace = store.update_workspace( 1249 Workspace( 1250 name=workspace_name, 1251 description=description, 1252 default_artifact_root=default_artifact_root, 1253 ) 1254 ) 1255 except NotImplementedError: 1256 raise _workspace_not_supported("Workspace updates are not supported by this provider") 1257 1258 response_message = UpdateWorkspace.Response() 1259 response_message.workspace.MergeFrom(workspace.to_proto()) 1260 return _wrap_response(response_message) 1261 1262 1263 @catch_mlflow_exception 1264 @_disable_if_workspaces_disabled 1265 def _delete_workspace_handler(workspace_name: str): 1266 if workspace_name == DEFAULT_WORKSPACE_NAME: 1267 raise MlflowException.invalid_parameter_value( 1268 f"The '{DEFAULT_WORKSPACE_NAME}' workspace is reserved and cannot be deleted" 1269 ) 1270 WorkspaceNameValidator.validate(workspace_name) 1271 mode_str = request.args.get("mode", WorkspaceDeletionMode.RESTRICT.value) 1272 try: 1273 mode = WorkspaceDeletionMode(mode_str) 1274 except ValueError: 1275 raise MlflowException.invalid_parameter_value( 1276 f"Invalid deletion mode '{mode_str}'. 
" 1277 f"Must be one of: {', '.join(m.value for m in WorkspaceDeletionMode)}" 1278 ) 1279 store = _get_workspace_store() 1280 try: 1281 store.delete_workspace(workspace_name, mode=mode) 1282 except NotImplementedError: 1283 raise _workspace_not_supported("Workspace deletion is not supported by this provider") 1284 return Response(status=204) 1285 1286 1287 @catch_mlflow_exception 1288 def get_artifact_handler(): 1289 run_id = request.args.get("run_id") or request.args.get("run_uuid") 1290 path = request.args["path"] 1291 path = validate_path_is_safe(path) 1292 run = _get_tracking_store().get_run(run_id) 1293 1294 if _is_servable_proxied_run_artifact_root(run.info.artifact_uri): 1295 artifact_repo = _get_artifact_repo_mlflow_artifacts() 1296 artifact_path = _get_proxied_run_artifact_destination_path( 1297 proxied_artifact_root=run.info.artifact_uri, 1298 relative_path=path, 1299 ) 1300 artifact_path = _get_workspace_scoped_repo_path_if_enabled(artifact_path) 1301 else: 1302 artifact_repo = _get_artifact_repo(run) 1303 artifact_path = path 1304 1305 return _send_artifact(artifact_repo, artifact_path) 1306 1307 1308 def _not_implemented(): 1309 response = Response() 1310 response.status_code = 404 1311 return response 1312 1313 1314 # Tracking Server APIs 1315 1316 1317 @catch_mlflow_exception 1318 @_disable_if_artifacts_only 1319 def _create_experiment(): 1320 request_message = _get_request_message( 1321 CreateExperiment(), 1322 schema={ 1323 "name": [_assert_required, _assert_string], 1324 "artifact_location": [_assert_string], 1325 "tags": [_assert_array], 1326 }, 1327 ) 1328 1329 tags = [ExperimentTag(tag.key, tag.value) for tag in request_message.tags] 1330 1331 if request_message.artifact_location: 1332 _validate_artifact_root_uri(request_message.artifact_location, "artifact_location") 1333 experiment_id = _get_tracking_store().create_experiment( 1334 request_message.name, request_message.artifact_location, tags 1335 ) 1336 response_message = 
CreateExperiment.Response() 1337 response_message.experiment_id = experiment_id 1338 response = Response(mimetype="application/json") 1339 response.set_data(message_to_json(response_message)) 1340 return response 1341 1342 1343 @catch_mlflow_exception 1344 @_disable_if_artifacts_only 1345 def _get_experiment(): 1346 request_message = _get_request_message( 1347 GetExperiment(), schema={"experiment_id": [_assert_required, _assert_string]} 1348 ) 1349 response_message = get_experiment_impl(request_message) 1350 response = Response(mimetype="application/json") 1351 response.set_data(message_to_json(response_message)) 1352 return response 1353 1354 1355 def get_experiment_impl(request_message): 1356 response_message = GetExperiment.Response() 1357 experiment = _get_tracking_store().get_experiment(request_message.experiment_id).to_proto() 1358 response_message.experiment.MergeFrom(experiment) 1359 return response_message 1360 1361 1362 @catch_mlflow_exception 1363 @_disable_if_artifacts_only 1364 def _get_experiment_by_name(): 1365 request_message = _get_request_message( 1366 GetExperimentByName(), 1367 schema={"experiment_name": [_assert_required, _assert_string]}, 1368 ) 1369 response_message = GetExperimentByName.Response() 1370 store_exp = _get_tracking_store().get_experiment_by_name(request_message.experiment_name) 1371 if store_exp is None: 1372 raise MlflowException( 1373 f"Could not find experiment with name '{request_message.experiment_name}'", 1374 error_code=RESOURCE_DOES_NOT_EXIST, 1375 ) 1376 experiment = store_exp.to_proto() 1377 response_message.experiment.MergeFrom(experiment) 1378 response = Response(mimetype="application/json") 1379 response.set_data(message_to_json(response_message)) 1380 return response 1381 1382 1383 @catch_mlflow_exception 1384 @_disable_if_artifacts_only 1385 def _delete_experiment(): 1386 request_message = _get_request_message( 1387 DeleteExperiment(), schema={"experiment_id": [_assert_required, _assert_string]} 1388 ) 1389 
_get_tracking_store().delete_experiment(request_message.experiment_id) 1390 response_message = DeleteExperiment.Response() 1391 response = Response(mimetype="application/json") 1392 response.set_data(message_to_json(response_message)) 1393 return response 1394 1395 1396 @catch_mlflow_exception 1397 @_disable_if_artifacts_only 1398 def _restore_experiment(): 1399 request_message = _get_request_message( 1400 RestoreExperiment(), 1401 schema={"experiment_id": [_assert_required, _assert_string]}, 1402 ) 1403 _get_tracking_store().restore_experiment(request_message.experiment_id) 1404 response_message = RestoreExperiment.Response() 1405 response = Response(mimetype="application/json") 1406 response.set_data(message_to_json(response_message)) 1407 return response 1408 1409 1410 @catch_mlflow_exception 1411 @_disable_if_artifacts_only 1412 def _update_experiment(): 1413 request_message = _get_request_message( 1414 UpdateExperiment(), 1415 schema={ 1416 "experiment_id": [_assert_required, _assert_string], 1417 "new_name": [_assert_string, _assert_required], 1418 }, 1419 ) 1420 if request_message.new_name: 1421 _get_tracking_store().rename_experiment( 1422 request_message.experiment_id, request_message.new_name 1423 ) 1424 response_message = UpdateExperiment.Response() 1425 response = Response(mimetype="application/json") 1426 response.set_data(message_to_json(response_message)) 1427 return response 1428 1429 1430 @catch_mlflow_exception 1431 @_disable_if_artifacts_only 1432 def _create_run(): 1433 request_message = _get_request_message( 1434 CreateRun(), 1435 schema={ 1436 "experiment_id": [_assert_string], 1437 "start_time": [_assert_intlike], 1438 "run_name": [_assert_string], 1439 }, 1440 ) 1441 1442 tags = [RunTag(tag.key, tag.value) for tag in request_message.tags] 1443 run = _get_tracking_store().create_run( 1444 experiment_id=request_message.experiment_id, 1445 user_id=request_message.user_id, 1446 start_time=request_message.start_time, 1447 tags=tags, 1448 
run_name=request_message.run_name, 1449 ) 1450 1451 response_message = CreateRun.Response() 1452 response_message.run.MergeFrom(run.to_proto()) 1453 response = Response(mimetype="application/json") 1454 response.set_data(message_to_json(response_message)) 1455 return response 1456 1457 1458 @catch_mlflow_exception 1459 @_disable_if_artifacts_only 1460 def _update_run(): 1461 request_message = _get_request_message( 1462 UpdateRun(), 1463 schema={ 1464 "run_id": [_assert_required, _assert_string], 1465 "end_time": [_assert_intlike], 1466 "status": [_assert_string], 1467 "run_name": [_assert_string], 1468 }, 1469 ) 1470 run_id = request_message.run_id or request_message.run_uuid 1471 run_name = request_message.run_name if request_message.HasField("run_name") else None 1472 end_time = request_message.end_time if request_message.HasField("end_time") else None 1473 status = request_message.status if request_message.HasField("status") else None 1474 updated_info = _get_tracking_store().update_run_info(run_id, status, end_time, run_name) 1475 response_message = UpdateRun.Response(run_info=updated_info.to_proto()) 1476 response = Response(mimetype="application/json") 1477 response.set_data(message_to_json(response_message)) 1478 return response 1479 1480 1481 @catch_mlflow_exception 1482 @_disable_if_artifacts_only 1483 def _delete_run(): 1484 request_message = _get_request_message( 1485 DeleteRun(), schema={"run_id": [_assert_required, _assert_string]} 1486 ) 1487 _get_tracking_store().delete_run(request_message.run_id) 1488 response_message = DeleteRun.Response() 1489 response = Response(mimetype="application/json") 1490 response.set_data(message_to_json(response_message)) 1491 return response 1492 1493 1494 @catch_mlflow_exception 1495 @_disable_if_artifacts_only 1496 def _restore_run(): 1497 request_message = _get_request_message( 1498 RestoreRun(), schema={"run_id": [_assert_required, _assert_string]} 1499 ) 1500 
_get_tracking_store().restore_run(request_message.run_id) 1501 response_message = RestoreRun.Response() 1502 response = Response(mimetype="application/json") 1503 response.set_data(message_to_json(response_message)) 1504 return response 1505 1506 1507 @catch_mlflow_exception 1508 @_disable_if_artifacts_only 1509 def _log_metric(): 1510 request_message = _get_request_message( 1511 LogMetric(), 1512 schema={ 1513 "run_id": [_assert_required, _assert_string], 1514 "key": [_assert_required, _assert_string], 1515 "value": [_assert_required, _assert_floatlike], 1516 "timestamp": [_assert_intlike, _assert_required], 1517 "step": [_assert_intlike], 1518 "model_id": [_assert_string], 1519 "dataset_name": [_assert_string], 1520 "dataset_digest": [_assert_string], 1521 }, 1522 ) 1523 metric = Metric( 1524 request_message.key, 1525 request_message.value, 1526 request_message.timestamp, 1527 request_message.step, 1528 request_message.model_id or None, 1529 request_message.dataset_name or None, 1530 request_message.dataset_digest or None, 1531 request_message.run_id or None, 1532 ) 1533 run_id = request_message.run_id or request_message.run_uuid 1534 _get_tracking_store().log_metric(run_id, metric) 1535 response_message = LogMetric.Response() 1536 response = Response(mimetype="application/json") 1537 response.set_data(message_to_json(response_message)) 1538 return response 1539 1540 1541 @catch_mlflow_exception 1542 @_disable_if_artifacts_only 1543 def _log_param(): 1544 request_message = _get_request_message( 1545 LogParam(), 1546 schema={ 1547 "run_id": [_assert_required, _assert_string], 1548 "key": [_assert_required, _assert_string], 1549 "value": [_assert_string], 1550 }, 1551 ) 1552 param = Param(request_message.key, request_message.value) 1553 run_id = request_message.run_id or request_message.run_uuid 1554 _get_tracking_store().log_param(run_id, param) 1555 response_message = LogParam.Response() 1556 response = Response(mimetype="application/json") 1557 
response.set_data(message_to_json(response_message)) 1558 return response 1559 1560 1561 @catch_mlflow_exception 1562 @_disable_if_artifacts_only 1563 def _log_inputs(): 1564 request_message = _get_request_message( 1565 LogInputs(), 1566 schema={ 1567 "run_id": [_assert_required, _assert_string], 1568 "datasets": [_assert_array], 1569 "models": [_assert_array], 1570 }, 1571 ) 1572 run_id = request_message.run_id 1573 datasets = [ 1574 DatasetInput.from_proto(proto_dataset_input) 1575 for proto_dataset_input in request_message.datasets 1576 ] 1577 models = ( 1578 [ 1579 LoggedModelInput.from_proto(proto_logged_model_input) 1580 for proto_logged_model_input in request_message.models 1581 ] 1582 if request_message.models 1583 else None 1584 ) 1585 1586 _get_tracking_store().log_inputs(run_id, datasets=datasets, models=models) 1587 response_message = LogInputs.Response() 1588 response = Response(mimetype="application/json") 1589 response.set_data(message_to_json(response_message)) 1590 return response 1591 1592 1593 @catch_mlflow_exception 1594 @_disable_if_artifacts_only 1595 def _log_outputs(): 1596 request_message = _get_request_message( 1597 LogOutputs(), 1598 schema={ 1599 "run_id": [_assert_required, _assert_string], 1600 "models": [_assert_required, _assert_array], 1601 }, 1602 ) 1603 models = [LoggedModelOutput.from_proto(p) for p in request_message.models] 1604 _get_tracking_store().log_outputs(run_id=request_message.run_id, models=models) 1605 response_message = LogOutputs.Response() 1606 return _wrap_response(response_message) 1607 1608 1609 @catch_mlflow_exception 1610 @_disable_if_artifacts_only 1611 def _set_experiment_tag(): 1612 request_message = _get_request_message( 1613 SetExperimentTag(), 1614 schema={ 1615 "experiment_id": [_assert_required, _assert_string], 1616 "key": [_assert_required, _assert_string], 1617 "value": [_assert_string], 1618 }, 1619 ) 1620 tag = ExperimentTag(request_message.key, request_message.value) 1621 
_get_tracking_store().set_experiment_tag(request_message.experiment_id, tag) 1622 response_message = SetExperimentTag.Response() 1623 response = Response(mimetype="application/json") 1624 response.set_data(message_to_json(response_message)) 1625 return response 1626 1627 1628 @catch_mlflow_exception 1629 @_disable_if_artifacts_only 1630 def _delete_experiment_tag(): 1631 request_message = _get_request_message( 1632 DeleteExperimentTag(), 1633 schema={ 1634 "experiment_id": [_assert_required, _assert_string], 1635 "key": [_assert_required, _assert_string], 1636 }, 1637 ) 1638 _get_tracking_store().delete_experiment_tag(request_message.experiment_id, request_message.key) 1639 response_message = DeleteExperimentTag.Response() 1640 response = Response(mimetype="application/json") 1641 response.set_data(message_to_json(response_message)) 1642 return response 1643 1644 1645 @catch_mlflow_exception 1646 @_disable_if_artifacts_only 1647 def _set_tag(): 1648 request_message = _get_request_message( 1649 SetTag(), 1650 schema={ 1651 "run_id": [_assert_required, _assert_string], 1652 "key": [_assert_required, _assert_string], 1653 "value": [_assert_string], 1654 }, 1655 ) 1656 tag = RunTag(request_message.key, request_message.value) 1657 run_id = request_message.run_id or request_message.run_uuid 1658 _get_tracking_store().set_tag(run_id, tag) 1659 response_message = SetTag.Response() 1660 response = Response(mimetype="application/json") 1661 response.set_data(message_to_json(response_message)) 1662 return response 1663 1664 1665 @catch_mlflow_exception 1666 @_disable_if_artifacts_only 1667 def _delete_tag(): 1668 request_message = _get_request_message( 1669 DeleteTag(), 1670 schema={ 1671 "run_id": [_assert_required, _assert_string], 1672 "key": [_assert_required, _assert_string], 1673 }, 1674 ) 1675 _get_tracking_store().delete_tag(request_message.run_id, request_message.key) 1676 response_message = DeleteTag.Response() 1677 response = Response(mimetype="application/json") 
1678 response.set_data(message_to_json(response_message)) 1679 return response 1680 1681 1682 @catch_mlflow_exception 1683 @_disable_if_artifacts_only 1684 def _get_run(): 1685 request_message = _get_request_message( 1686 GetRun(), schema={"run_id": [_assert_required, _assert_string]} 1687 ) 1688 response_message = get_run_impl(request_message) 1689 response = Response(mimetype="application/json") 1690 response.set_data(message_to_json(response_message)) 1691 return response 1692 1693 1694 def get_run_impl(request_message): 1695 response_message = GetRun.Response() 1696 run_id = request_message.run_id or request_message.run_uuid 1697 response_message.run.MergeFrom(_get_tracking_store().get_run(run_id).to_proto()) 1698 return response_message 1699 1700 1701 @catch_mlflow_exception 1702 @_disable_if_artifacts_only 1703 def _search_runs(): 1704 request_message = _get_request_message( 1705 SearchRuns(), 1706 schema={ 1707 "experiment_ids": [_assert_array], 1708 "filter": [_assert_string], 1709 "max_results": [ 1710 _assert_intlike, 1711 lambda x: _assert_less_than_or_equal(int(x), 50000), 1712 ], 1713 "order_by": [_assert_array, _assert_item_type_string], 1714 }, 1715 ) 1716 response_message = search_runs_impl(request_message) 1717 response = Response(mimetype="application/json") 1718 response.set_data(message_to_json(response_message)) 1719 return response 1720 1721 1722 def search_runs_impl(request_message): 1723 response_message = SearchRuns.Response() 1724 run_view_type = ViewType.ACTIVE_ONLY 1725 if request_message.HasField("run_view_type"): 1726 run_view_type = ViewType.from_proto(request_message.run_view_type) 1727 filter_string = request_message.filter 1728 max_results = request_message.max_results 1729 experiment_ids = list(request_message.experiment_ids) 1730 1731 # NB: Local import to avoid circular dependency (auth imports from handlers) 1732 try: 1733 from mlflow.server import auth 1734 1735 if auth.auth_config: 1736 experiment_ids = 
        # Tail of _search_runs (definition begins above this chunk): restrict the
        # search to experiments the caller is allowed to read, when auth is present.
        auth.filter_experiment_ids(experiment_ids)
    except ImportError:
        # Auth module not available (Flask-WTF not installed), skip filtering
        pass

    order_by = request_message.order_by
    run_entities = _get_tracking_store().search_runs(
        experiment_ids=experiment_ids,
        filter_string=filter_string,
        run_view_type=run_view_type,
        max_results=max_results,
        order_by=order_by,
        page_token=request_message.page_token or None,
    )
    response_message.runs.extend([r.to_proto() for r in run_entities])
    # Propagate the pagination token only when the store indicates more pages.
    if run_entities.token:
        response_message.next_page_token = run_entities.token
    return response_message


@catch_mlflow_exception
@_disable_if_artifacts_only
def _list_artifacts():
    """Handle the ``ListArtifacts`` endpoint: validate the request and return JSON."""
    request_message = _get_request_message(
        ListArtifacts(),
        schema={
            "run_id": [_assert_string, _assert_required],
            "path": [_assert_string],
            "page_token": [_assert_string],
        },
    )
    response_message = list_artifacts_impl(request_message)
    response = Response(mimetype="application/json")
    response.set_data(message_to_json(response_message))
    return response


def list_artifacts_impl(request_message):
    """List a run's artifacts, dispatching between proxied and direct artifact roots.

    Args:
        request_message: Parsed ``ListArtifacts`` proto request.

    Returns:
        A ``ListArtifacts.Response`` proto with ``files`` and ``root_uri`` populated.
    """
    response_message = ListArtifacts.Response()
    if request_message.HasField("path"):
        path = request_message.path
        # Reject path traversal / absolute paths before touching storage.
        path = validate_path_is_safe(path)
    else:
        path = None
    # run_uuid is the legacy field name; prefer run_id when both are present.
    run_id = request_message.run_id or request_message.run_uuid
    run = _get_tracking_store().get_run(run_id)

    if _is_servable_proxied_run_artifact_root(run.info.artifact_uri):
        # Artifact root is proxied through this server (http/https/mlflow-artifacts).
        artifact_entities = _list_artifacts_for_proxied_run_artifact_root(
            proxied_artifact_root=run.info.artifact_uri,
            relative_path=path,
        )
    else:
        artifact_entities = _get_artifact_repo(run).list_artifacts(path)

    response_message.files.extend([a.to_proto() for a in artifact_entities])
    response_message.root_uri = run.info.artifact_uri
    return response_message


def _list_artifacts_for_proxied_run_artifact_root(proxied_artifact_root, relative_path=None):
    """
    Lists artifacts from the specified ``relative_path`` within the specified proxied Run artifact
    root (i.e. a Run artifact root with scheme ``http``, ``https``, or ``mlflow-artifacts``).

    Args:
        proxied_artifact_root: The Run artifact root location (URI) with scheme ``http``,
                               ``https``, or ``mlflow-artifacts`` that can be resolved by the
                               MLflow server to a concrete storage location.
        relative_path: The relative path within the specified ``proxied_artifact_root`` under
                       which to list artifact contents. If ``None``, artifacts are listed from
                       the ``proxied_artifact_root`` directory.
    """
    parsed_proxied_artifact_root = urllib.parse.urlparse(proxied_artifact_root)
    assert parsed_proxied_artifact_root.scheme in ["http", "https", "mlflow-artifacts"]

    artifact_destination_repo = _get_artifact_repo_mlflow_artifacts()
    artifact_destination_path = _get_proxied_run_artifact_destination_path(
        proxied_artifact_root=proxied_artifact_root,
        relative_path=relative_path,
    )
    artifact_destination_path = _get_workspace_scoped_repo_path_if_enabled(
        artifact_destination_path
    )

    # Re-key the storage-level listing back onto run-relative artifact paths.
    artifact_entities = []
    for file_info in artifact_destination_repo.list_artifacts(artifact_destination_path):
        basename = posixpath.basename(file_info.path)
        run_relative_artifact_path = (
            posixpath.join(relative_path, basename) if relative_path else basename
        )
        artifact_entities.append(
            FileInfo(run_relative_artifact_path, file_info.is_dir, file_info.file_size)
        )

    return artifact_entities


@catch_mlflow_exception
@_disable_if_artifacts_only
def _get_metric_history():
    """Handle ``GetMetricHistory``: return the (paged) history of one metric for a run."""
    request_message = _get_request_message(
        GetMetricHistory(),
        schema={
            "run_id": [_assert_string, _assert_required],
            "metric_key": [_assert_string, _assert_required],
            "page_token": [_assert_string],
        },
    )
    response_message = GetMetricHistory.Response()
    # run_uuid is the legacy field name; prefer run_id when both are present.
    run_id = request_message.run_id or request_message.run_uuid

    # NOTE(review): a proto scalar field is never None, so the condition below always
    # takes the first branch — confirm whether HasField-based handling was intended.
    max_results = request_message.max_results if request_message.max_results is not None else None

    metric_entities = _get_tracking_store().get_metric_history(
        run_id,
        request_message.metric_key,
        max_results=max_results,
        page_token=request_message.page_token or None,
    )
    response_message.metrics.extend([m.to_proto() for m in metric_entities])

    # Set next_page_token if available
    if next_page_token := metric_entities.token:
        response_message.next_page_token = next_page_token

    response = Response(mimetype="application/json")
    response.set_data(message_to_json(response_message))
    return response


@catch_mlflow_exception
@_disable_if_artifacts_only
def get_metric_history_bulk_handler():
    """Handle ``GetMetricHistoryBulk``: one metric's history across up to 100 runs.

    Uses the store's native ``get_metric_history_bulk`` when available; otherwise
    falls back to one ``get_metric_history`` call per run.
    """
    MAX_HISTORY_RESULTS = 25000
    MAX_RUN_IDS_PER_REQUEST = 100
    # flat=False keeps repeated ?run_id=... query params as a list.
    run_ids = request.args.to_dict(flat=False).get("run_id", [])
    if not run_ids:
        raise MlflowException(
            message="GetMetricHistoryBulk request must specify at least one run_id.",
            error_code=INVALID_PARAMETER_VALUE,
        )
    if len(run_ids) > MAX_RUN_IDS_PER_REQUEST:
        raise MlflowException(
            message=(
                f"GetMetricHistoryBulk request cannot specify more than {MAX_RUN_IDS_PER_REQUEST}"
                f" run_ids. Received {len(run_ids)} run_ids."
            ),
            error_code=INVALID_PARAMETER_VALUE,
        )

    metric_key = request.args.get("metric_key")
    if metric_key is None:
        raise MlflowException(
            message="GetMetricHistoryBulk request must specify a metric_key.",
            error_code=INVALID_PARAMETER_VALUE,
        )

    # NOTE(review): int(...) raises ValueError on a non-numeric max_results query
    # param, surfacing as a 500 rather than INVALID_PARAMETER_VALUE — confirm intent.
    max_results = int(request.args.get("max_results", MAX_HISTORY_RESULTS))
    max_results = min(max_results, MAX_HISTORY_RESULTS)

    store = _get_tracking_store()

    def _default_history_bulk_impl():
        # Fallback: fetch per run (in sorted run_id order for determinism) and
        # sort each run's metrics by (timestamp, step, value).
        metrics_with_run_ids = []
        for run_id in sorted(run_ids):
            metrics_for_run = sorted(
                store.get_metric_history(
                    run_id=run_id,
                    metric_key=metric_key,
                    max_results=max_results,
                ),
                key=lambda metric: (metric.timestamp, metric.step, metric.value),
            )
            metrics_with_run_ids.extend([
                {
                    "key": metric.key,
                    "value": metric.value,
                    "timestamp": metric.timestamp,
                    "step": metric.step,
                    "run_id": run_id,
                }
                for metric in metrics_for_run
            ])
        return metrics_with_run_ids

    if hasattr(store, "get_metric_history_bulk"):
        metrics_with_run_ids = [
            metric.to_dict()
            for metric in store.get_metric_history_bulk(
                run_ids=run_ids,
                metric_key=metric_key,
                max_results=max_results,
            )
        ]
    else:
        metrics_with_run_ids = _default_history_bulk_impl()

    # Cap the combined result at max_results even after per-run fetching.
    return {
        "metrics": metrics_with_run_ids[:max_results],
    }


@catch_mlflow_exception
@_disable_if_artifacts_only
def get_metric_history_bulk_interval_handler():
    """Handle ``GetMetricHistoryBulkInterval``: validate the request, then delegate."""
    request_message = _get_request_message(
        GetMetricHistoryBulkInterval(),
        schema={
            "run_ids": [
                _assert_required,
                _assert_array,
                _assert_item_type_string,
                lambda x: _assert_less_than_or_equal(
                    len(x),
                    MAX_RUNS_GET_METRIC_HISTORY_BULK,
                    message=f"GetMetricHistoryBulkInterval request must specify at most "
                    f"{MAX_RUNS_GET_METRIC_HISTORY_BULK} run_ids. Received {len(x)} run_ids.",
                ),
            ],
            "metric_key": [_assert_required, _assert_string],
            "start_step": [_assert_intlike],
            "end_step": [_assert_intlike],
            "max_results": [
                _assert_intlike,
                lambda x: _assert_intlike_within_range(
                    int(x),
                    1,
                    MAX_RESULTS_PER_RUN,
                    message=f"max_results must be between 1 and {MAX_RESULTS_PER_RUN}.",
                ),
            ],
        },
    )
    response_message = get_metric_history_bulk_interval_impl(request_message)
    response = Response(mimetype="application/json")
    response.set_data(message_to_json(response_message))
    return response


def get_metric_history_bulk_interval_impl(request_message):
    """Fetch sampled metric history for several runs over an optional step interval.

    start_step/end_step must be provided together; raises
    ``MlflowException`` (INVALID_PARAMETER_VALUE) if only one is given or if
    start_step > end_step.
    """
    args = request.args
    run_ids = request_message.run_ids
    metric_key = request_message.metric_key
    max_results = int(args.get("max_results", MAX_RESULTS_PER_RUN))
    start_step = args.get("start_step")
    end_step = args.get("end_step")
    if start_step is not None and end_step is not None:
        start_step = int(start_step)
        end_step = int(end_step)
        if start_step > end_step:
            raise MlflowException.invalid_parameter_value(
                "end_step must be greater than start_step. "
                f"Found start_step={start_step} and end_step={end_step}."
            )
    elif start_step is not None or end_step is not None:
        # Exactly one bound was supplied — reject the half-open interval.
        raise MlflowException.invalid_parameter_value(
            "If either start step or end step are specified, both must be specified."
        )

    store = _get_tracking_store()
    metrics_with_run_ids = store.get_metric_history_bulk_interval(
        run_ids=run_ids,
        metric_key=metric_key,
        max_results=max_results,
        start_step=start_step,
        end_step=end_step,
    )

    response_message = GetMetricHistoryBulkInterval.Response()
    response_message.metrics.extend([m.to_proto() for m in metrics_with_run_ids])
    return response_message


@catch_mlflow_exception
@_disable_if_artifacts_only
def _search_datasets_handler():
    """Handle ``SearchDatasets``: parse the request and return the JSON response."""
    request_message = _get_request_message(
        SearchDatasets(),
    )
    response_message = search_datasets_impl(request_message)
    response = Response(mimetype="application/json")
    response.set_data(message_to_json(response_message))
    return response


def search_datasets_impl(request_message):
    """Return dataset summaries for up to 20 experiments.

    Falls back to a 501-style response when the active tracking store does not
    implement ``_search_datasets``.
    """
    MAX_EXPERIMENT_IDS_PER_REQUEST = 20
    _validate_content_type(request, ["application/json"])
    experiment_ids = request_message.experiment_ids or []
    if not experiment_ids:
        raise MlflowException(
            message="SearchDatasets request must specify at least one experiment_id.",
            error_code=INVALID_PARAMETER_VALUE,
        )
    if len(experiment_ids) > MAX_EXPERIMENT_IDS_PER_REQUEST:
        raise MlflowException(
            message=(
                f"SearchDatasets request cannot specify more than {MAX_EXPERIMENT_IDS_PER_REQUEST}"
                f" experiment_ids. Received {len(experiment_ids)} experiment_ids."
            ),
            error_code=INVALID_PARAMETER_VALUE,
        )

    store = _get_tracking_store()

    # Optional store capability: only file/sql stores implementing _search_datasets
    # can serve this endpoint.
    if hasattr(store, "_search_datasets"):
        response_message = SearchDatasets.Response()
        response_message.dataset_summaries.extend([
            summary.to_proto() for summary in store._search_datasets(experiment_ids)
        ])
        return response_message
    else:
        return _not_implemented()


def _validate_gateway_path(method: str, gateway_path: str) -> None:
    """Allow only known gateway paths per HTTP method; raise otherwise.

    GET may only list endpoints; POST may only hit
    ``gateway/{name}/invocations``. Anything else is rejected to keep the proxy
    from forwarding arbitrary paths.
    """
    if not gateway_path:
        raise MlflowException(
            message="Deployments proxy request must specify a gateway_path.",
            error_code=INVALID_PARAMETER_VALUE,
        )
    elif method == "GET":
        if gateway_path.strip("/") != "api/2.0/endpoints":
            raise MlflowException(
                message=f"Invalid gateway_path: {gateway_path} for method: {method}",
                error_code=INVALID_PARAMETER_VALUE,
            )
    elif method == "POST":
        # For POST, gateway_path must be in the form of "gateway/{name}/invocations"
        if not re.fullmatch(r"gateway/[^/]+/invocations", gateway_path.strip("/")):
            raise MlflowException(
                message=f"Invalid gateway_path: {gateway_path} for method: {method}",
                error_code=INVALID_PARAMETER_VALUE,
            )


@catch_mlflow_exception
def gateway_proxy_handler():
    """Proxy a request to the configured MLflow Deployments target server.

    (Definition continues below this chunk boundary.)
    """
    target_uri = MLFLOW_DEPLOYMENTS_TARGET.get()
    if not target_uri:
        # Pretend an empty gateway service is running
        return {"endpoints": []}

    args = request.args if request.method == "GET" else request.json
    gateway_path = args.get("gateway_path")
    _validate_gateway_path(request.method, gateway_path)
    json_data = args.get("json_data", None)
    response = requests.request(request.method, f"{target_uri}/{gateway_path}", json=json_data)
    if response.status_code == 200:
        return response.json()
    else:
        raise MlflowException(
            message=f"Deployments proxy request failed with error code {response.status_code}. "
" 2090 f"Error message: {response.text}", 2091 error_code=response.status_code, 2092 ) 2093 2094 2095 @catch_mlflow_exception 2096 @_disable_if_artifacts_only 2097 def create_promptlab_run_handler(): 2098 def assert_arg_exists(arg_name, arg): 2099 if not arg: 2100 raise MlflowException( 2101 message=f"CreatePromptlabRun request must specify {arg_name}.", 2102 error_code=INVALID_PARAMETER_VALUE, 2103 ) 2104 2105 _validate_content_type(request, ["application/json"]) 2106 2107 args = request.json 2108 experiment_id = args.get("experiment_id") 2109 assert_arg_exists("experiment_id", experiment_id) 2110 run_name = args.get("run_name", None) 2111 tags = args.get("tags", []) 2112 prompt_template = args.get("prompt_template") 2113 assert_arg_exists("prompt_template", prompt_template) 2114 raw_prompt_parameters = args.get("prompt_parameters") 2115 assert_arg_exists("prompt_parameters", raw_prompt_parameters) 2116 prompt_parameters = [ 2117 Param(param.get("key"), param.get("value")) for param in args.get("prompt_parameters") 2118 ] 2119 model_route = args.get("model_route") 2120 assert_arg_exists("model_route", model_route) 2121 raw_model_parameters = args.get("model_parameters", []) 2122 model_parameters = [ 2123 Param(param.get("key"), param.get("value")) for param in raw_model_parameters 2124 ] 2125 model_input = args.get("model_input") 2126 assert_arg_exists("model_input", model_input) 2127 model_output = args.get("model_output", None) 2128 raw_model_output_parameters = args.get("model_output_parameters", []) 2129 model_output_parameters = [ 2130 Param(param.get("key"), param.get("value")) for param in raw_model_output_parameters 2131 ] 2132 mlflow_version = args.get("mlflow_version") 2133 assert_arg_exists("mlflow_version", mlflow_version) 2134 user_id = args.get("user_id", "unknown") 2135 2136 # use current time if not provided 2137 start_time = args.get("start_time", int(time.time() * 1000)) 2138 2139 store = _get_tracking_store() 2140 2141 run = 
_create_promptlab_run_impl( 2142 store, 2143 experiment_id=experiment_id, 2144 run_name=run_name, 2145 tags=tags, 2146 prompt_template=prompt_template, 2147 prompt_parameters=prompt_parameters, 2148 model_route=model_route, 2149 model_parameters=model_parameters, 2150 model_input=model_input, 2151 model_output=model_output, 2152 model_output_parameters=model_output_parameters, 2153 mlflow_version=mlflow_version, 2154 user_id=user_id, 2155 start_time=start_time, 2156 ) 2157 response_message = CreateRun.Response() 2158 response_message.run.MergeFrom(run.to_proto()) 2159 response = Response(mimetype="application/json") 2160 response.set_data(message_to_json(response_message)) 2161 return response 2162 2163 2164 @catch_mlflow_exception 2165 def upload_artifact_handler(): 2166 args = request.args 2167 run_uuid = args.get("run_uuid") 2168 if not run_uuid: 2169 raise MlflowException( 2170 message="Request must specify run_uuid.", 2171 error_code=INVALID_PARAMETER_VALUE, 2172 ) 2173 path = args.get("path") 2174 if not path: 2175 raise MlflowException( 2176 message="Request must specify path.", 2177 error_code=INVALID_PARAMETER_VALUE, 2178 ) 2179 path = validate_path_is_safe(path) 2180 2181 if request.content_length and request.content_length > 10 * 1024 * 1024: 2182 raise MlflowException( 2183 message="Artifact size is too large. 
Max size is 10MB.", 2184 error_code=INVALID_PARAMETER_VALUE, 2185 ) 2186 2187 data = request.data 2188 if not data: 2189 raise MlflowException( 2190 message="Request must specify data.", 2191 error_code=INVALID_PARAMETER_VALUE, 2192 ) 2193 2194 run = _get_tracking_store().get_run(run_uuid) 2195 artifact_dir = run.info.artifact_uri 2196 2197 basename = posixpath.basename(path) 2198 dirname = posixpath.dirname(path) 2199 2200 def _log_artifact_to_repo(file, run, dirname, artifact_dir): 2201 if _is_servable_proxied_run_artifact_root(run.info.artifact_uri): 2202 artifact_repo = _get_artifact_repo_mlflow_artifacts() 2203 # Use posixpath.join since these are logical artifact paths (not local filesystem paths) 2204 # that should always use forward slashes regardless of the platform. 2205 path_to_log = ( 2206 posixpath.join(run.info.experiment_id, run.info.run_id, "artifacts", dirname) 2207 if dirname 2208 else posixpath.join(run.info.experiment_id, run.info.run_id, "artifacts") 2209 ) 2210 path_to_log = _get_workspace_scoped_repo_path_if_enabled(path_to_log) 2211 else: 2212 artifact_repo = get_artifact_repository(artifact_dir) 2213 path_to_log = dirname 2214 2215 artifact_repo.log_artifact(file, path_to_log) 2216 2217 with tempfile.TemporaryDirectory() as tmpdir: 2218 dir_path = os.path.join(tmpdir, dirname) if dirname else tmpdir 2219 file_path = os.path.join(dir_path, basename) 2220 2221 os.makedirs(dir_path, exist_ok=True) 2222 2223 with open(file_path, "wb") as f: 2224 f.write(data) 2225 2226 _log_artifact_to_repo(file_path, run, dirname, artifact_dir) 2227 2228 return Response(mimetype="application/json") 2229 2230 2231 @catch_mlflow_exception 2232 @_disable_if_artifacts_only 2233 def _search_experiments(): 2234 request_message = _get_request_message( 2235 SearchExperiments(), 2236 schema={ 2237 "view_type": [_assert_intlike], 2238 "max_results": [_assert_intlike], 2239 "order_by": [_assert_array], 2240 "filter": [_assert_string], 2241 "page_token": [_assert_string], 
2242 }, 2243 ) 2244 2245 experiment_entities = _get_tracking_store().search_experiments( 2246 view_type=request_message.view_type, 2247 max_results=request_message.max_results, 2248 order_by=request_message.order_by, 2249 filter_string=request_message.filter, 2250 page_token=request_message.page_token or None, 2251 ) 2252 response_message = SearchExperiments.Response() 2253 response_message.experiments.extend([e.to_proto() for e in experiment_entities]) 2254 if experiment_entities.token: 2255 response_message.next_page_token = experiment_entities.token 2256 response = Response(mimetype="application/json") 2257 response.set_data(message_to_json(response_message)) 2258 return response 2259 2260 2261 @catch_mlflow_exception 2262 def _get_artifact_repo(run): 2263 return get_artifact_repository(run.info.artifact_uri) 2264 2265 2266 @catch_mlflow_exception 2267 @_disable_if_artifacts_only 2268 def _log_batch(): 2269 def _assert_metrics_fields_present(metrics): 2270 for idx, m in enumerate(metrics): 2271 _assert_required(m.get("key"), path=f"metrics[{idx}].key") 2272 _assert_required(m.get("value"), path=f"metrics[{idx}].value") 2273 _assert_required(m.get("timestamp"), path=f"metrics[{idx}].timestamp") 2274 2275 def _assert_params_fields_present(params): 2276 for idx, param in enumerate(params): 2277 _assert_required(param.get("key"), path=f"params[{idx}].key") 2278 2279 def _assert_tags_fields_present(tags): 2280 for idx, tag in enumerate(tags): 2281 _assert_required(tag.get("key"), path=f"tags[{idx}].key") 2282 2283 _validate_batch_log_api_req(_get_request_json()) 2284 request_message = _get_request_message( 2285 LogBatch(), 2286 schema={ 2287 "run_id": [_assert_string, _assert_required], 2288 "metrics": [_assert_array, _assert_metrics_fields_present], 2289 "params": [_assert_array, _assert_params_fields_present], 2290 "tags": [_assert_array, _assert_tags_fields_present], 2291 }, 2292 ) 2293 metrics = [Metric.from_proto(proto_metric) for proto_metric in 
request_message.metrics] 2294 params = [Param.from_proto(proto_param) for proto_param in request_message.params] 2295 tags = [RunTag.from_proto(proto_tag) for proto_tag in request_message.tags] 2296 _get_tracking_store().log_batch( 2297 run_id=request_message.run_id, metrics=metrics, params=params, tags=tags 2298 ) 2299 response_message = LogBatch.Response() 2300 response = Response(mimetype="application/json") 2301 response.set_data(message_to_json(response_message)) 2302 return response 2303 2304 2305 @catch_mlflow_exception 2306 @_disable_if_artifacts_only 2307 def _log_model(): 2308 request_message = _get_request_message( 2309 LogModel(), 2310 schema={ 2311 "run_id": [_assert_string, _assert_required], 2312 "model_json": [_assert_string, _assert_required], 2313 }, 2314 ) 2315 try: 2316 model = json.loads(request_message.model_json) 2317 except Exception: 2318 raise MlflowException( 2319 f"Malformed model info. \n {request_message.model_json} \n is not a valid JSON.", 2320 error_code=INVALID_PARAMETER_VALUE, 2321 ) 2322 2323 missing_fields = {"artifact_path", "flavors", "utc_time_created", "run_id"} - set(model.keys()) 2324 2325 if missing_fields: 2326 raise MlflowException( 2327 f"Model json is missing mandatory fields: {missing_fields}", 2328 error_code=INVALID_PARAMETER_VALUE, 2329 ) 2330 _get_tracking_store().record_logged_model( 2331 run_id=request_message.run_id, mlflow_model=Model.from_dict(model) 2332 ) 2333 response_message = LogModel.Response() 2334 response = Response(mimetype="application/json") 2335 response.set_data(message_to_json(response_message)) 2336 return response 2337 2338 2339 def _wrap_response(response_message): 2340 response = Response(mimetype="application/json") 2341 response.set_data(message_to_json(response_message)) 2342 return response 2343 2344 2345 # Model Registry APIs 2346 2347 2348 @catch_mlflow_exception 2349 @_disable_if_artifacts_only 2350 def _create_registered_model(): 2351 request_message = _get_request_message( 2352 
CreateRegisteredModel(), 2353 schema={ 2354 "name": [_assert_string, _assert_required], 2355 "tags": [_assert_array], 2356 "description": [_assert_string], 2357 }, 2358 ) 2359 store = _get_model_registry_store() 2360 registered_model = store.create_registered_model( 2361 name=request_message.name, 2362 tags=request_message.tags, 2363 description=request_message.description, 2364 ) 2365 response_message = CreateRegisteredModel.Response(registered_model=registered_model.to_proto()) 2366 2367 # Determine if this is a prompt based on the tags 2368 if _is_prompt_request(request_message): 2369 # Send prompt creation webhook 2370 deliver_webhook( 2371 event=WebhookEvent(WebhookEntity.PROMPT, WebhookAction.CREATED), 2372 payload=PromptCreatedPayload( 2373 name=request_message.name, 2374 tags={ 2375 t.key: t.value 2376 for t in request_message.tags 2377 if t.key not in {IS_PROMPT_TAG_KEY, PROMPT_TYPE_TAG_KEY} 2378 }, 2379 description=request_message.description, 2380 ), 2381 store=store, 2382 ) 2383 else: 2384 # Send regular model creation webhook 2385 deliver_webhook( 2386 event=WebhookEvent(WebhookEntity.REGISTERED_MODEL, WebhookAction.CREATED), 2387 payload=RegisteredModelCreatedPayload( 2388 name=request_message.name, 2389 tags={t.key: t.value for t in request_message.tags}, 2390 description=request_message.description, 2391 ), 2392 store=store, 2393 ) 2394 2395 return _wrap_response(response_message) 2396 2397 2398 @catch_mlflow_exception 2399 @_disable_if_artifacts_only 2400 def _get_registered_model(): 2401 request_message = _get_request_message( 2402 GetRegisteredModel(), schema={"name": [_assert_string, _assert_required]} 2403 ) 2404 registered_model = _get_model_registry_store().get_registered_model(name=request_message.name) 2405 response_message = GetRegisteredModel.Response(registered_model=registered_model.to_proto()) 2406 return _wrap_response(response_message) 2407 2408 2409 @catch_mlflow_exception 2410 @_disable_if_artifacts_only 2411 def 
_update_registered_model(): 2412 request_message = _get_request_message( 2413 UpdateRegisteredModel(), 2414 schema={ 2415 "name": [_assert_string, _assert_required], 2416 "description": [_assert_string], 2417 }, 2418 ) 2419 name = request_message.name 2420 new_description = request_message.description 2421 registered_model = _get_model_registry_store().update_registered_model( 2422 name=name, description=new_description 2423 ) 2424 response_message = UpdateRegisteredModel.Response(registered_model=registered_model.to_proto()) 2425 return _wrap_response(response_message) 2426 2427 2428 @catch_mlflow_exception 2429 @_disable_if_artifacts_only 2430 def _rename_registered_model(): 2431 request_message = _get_request_message( 2432 RenameRegisteredModel(), 2433 schema={ 2434 "name": [_assert_string, _assert_required], 2435 "new_name": [_assert_string, _assert_required], 2436 }, 2437 ) 2438 name = request_message.name 2439 new_name = request_message.new_name 2440 registered_model = _get_model_registry_store().rename_registered_model( 2441 name=name, new_name=new_name 2442 ) 2443 response_message = RenameRegisteredModel.Response(registered_model=registered_model.to_proto()) 2444 return _wrap_response(response_message) 2445 2446 2447 @catch_mlflow_exception 2448 @_disable_if_artifacts_only 2449 def _delete_registered_model(): 2450 request_message = _get_request_message( 2451 DeleteRegisteredModel(), schema={"name": [_assert_string, _assert_required]} 2452 ) 2453 _get_model_registry_store().delete_registered_model(name=request_message.name) 2454 return _wrap_response(DeleteRegisteredModel.Response()) 2455 2456 2457 @catch_mlflow_exception 2458 @_disable_if_artifacts_only 2459 def _search_registered_models(): 2460 request_message = _get_request_message( 2461 SearchRegisteredModels(), 2462 schema={ 2463 "filter": [_assert_string], 2464 "max_results": [ 2465 _assert_intlike, 2466 lambda x: _assert_less_than_or_equal(int(x), 1000), 2467 ], 2468 "order_by": [_assert_array, 
            "page_token": [_assert_string],
        },
    )
    store = _get_model_registry_store()
    registered_models = store.search_registered_models(
        filter_string=request_message.filter,
        max_results=request_message.max_results,
        order_by=request_message.order_by,
        page_token=request_message.page_token or None,
    )
    response_message = SearchRegisteredModels.Response()
    response_message.registered_models.extend([e.to_proto() for e in registered_models])
    # Propagate the pagination token only when more pages exist.
    if registered_models.token:
        response_message.next_page_token = registered_models.token
    return _wrap_response(response_message)


@catch_mlflow_exception
@_disable_if_artifacts_only
def _get_latest_versions():
    """Handle ``GetLatestVersions``: latest model versions per requested stage."""
    request_message = _get_request_message(
        GetLatestVersions(),
        schema={
            "name": [_assert_string, _assert_required],
            "stages": [_assert_array, _assert_item_type_string],
        },
    )
    latest_versions = _get_model_registry_store().get_latest_versions(
        name=request_message.name, stages=request_message.stages
    )
    response_message = GetLatestVersions.Response()
    response_message.model_versions.extend([e.to_proto() for e in latest_versions])
    return _wrap_response(response_message)


@catch_mlflow_exception
@_disable_if_artifacts_only
def _set_registered_model_tag():
    """Handle ``SetRegisteredModelTag``; fires a webhook when the model is a prompt."""
    request_message = _get_request_message(
        SetRegisteredModelTag(),
        schema={
            "name": [_assert_string, _assert_required],
            "key": [_assert_string, _assert_required],
            "value": [_assert_string],
        },
    )
    tag = RegisteredModelTag(key=request_message.key, value=request_message.value)
    store = _get_model_registry_store()
    store.set_registered_model_tag(name=request_message.name, tag=tag)

    if _is_prompt(request_message.name):
        # Send prompt tag set webhook
        deliver_webhook(
            event=WebhookEvent(WebhookEntity.PROMPT_TAG, WebhookAction.SET),
            payload=PromptTagSetPayload(
                name=request_message.name,
                key=request_message.key,
                value=request_message.value,
            ),
            store=store,
        )

    return _wrap_response(SetRegisteredModelTag.Response())


@catch_mlflow_exception
@_disable_if_artifacts_only
def _delete_registered_model_tag():
    """Handle ``DeleteRegisteredModelTag``; fires a webhook when the model is a prompt."""
    request_message = _get_request_message(
        DeleteRegisteredModelTag(),
        schema={
            "name": [_assert_string, _assert_required],
            "key": [_assert_string, _assert_required],
        },
    )
    store = _get_model_registry_store()
    store.delete_registered_model_tag(name=request_message.name, key=request_message.key)

    if _is_prompt(request_message.name):
        # Send prompt tag deleted webhook
        deliver_webhook(
            event=WebhookEvent(WebhookEntity.PROMPT_TAG, WebhookAction.DELETED),
            payload=PromptTagDeletedPayload(
                name=request_message.name,
                key=request_message.key,
            ),
            store=store,
        )

    return _wrap_response(DeleteRegisteredModelTag.Response())


def _validate_non_local_source_contains_relative_paths(source: str):
    """
    Validation check to ensure that sources that are provided that conform to the schemes:
    http, https, or mlflow-artifacts do not contain relative path designations that are intended
    to access local file system paths on the tracking server.

    Example paths that this validation function is intended to find and raise an Exception if
    passed:
    "mlflow-artifacts://host:port/../../../../"
    "http://host:port/api/2.0/mlflow-artifacts/artifacts/../../../../"
    "https://host:port/api/2.0/mlflow-artifacts/artifacts/../../../../"
    "/models/artifacts/../../../"
    "s3:/my_bucket/models/path/../../other/path"
    "file://path/to/../../../../some/where/you/should/not/be"
    "mlflow-artifacts://host:port/..%2f..%2f..%2f..%2f"
    "http://host:port/api/2.0/mlflow-artifacts/artifacts%00"
    """
    invalid_source_error_message = (
        f"Invalid model version source: '{source}'. If supplying a source as an http, https, "
        "local file path, ftp, objectstore, or mlflow-artifacts uri, an absolute path must be "
        "provided without relative path references present. "
        "Please provide an absolute path."
    )

    # Repeatedly percent-decode to defeat double/nested URL encoding of "../".
    while (unquoted := urllib.parse.unquote_plus(source)) != source:
        source = unquoted
    # Collapse duplicate slashes and drop a trailing slash before comparing.
    source_path = re.sub(r"/+", "/", urllib.parse.urlparse(source).path.rstrip("/"))
    if "\x00" in source_path or any(p == ".." for p in source.split("/")):
        raise MlflowException(invalid_source_error_message, INVALID_PARAMETER_VALUE)
    resolved_source = pathlib.Path(source_path).resolve().as_posix()
    # NB: drive split is specifically for Windows since WindowsPath.resolve() will append the
    # drive path of the pwd to a given path. We don't care about the drive here, though.
    _, resolved_path = os.path.splitdrive(resolved_source)

    # If resolution changed the path, it contained relative components — reject.
    if resolved_path != source_path:
        raise MlflowException(invalid_source_error_message, INVALID_PARAMETER_VALUE)


def _validate_source_run(source: str, run_id: str) -> None:
    """Validate a model version source against the given run.

    A local-path source is only allowed when ``run_id`` is given and the path
    lies inside that run's artifact directory; otherwise raises
    ``MlflowException`` (INVALID_PARAMETER_VALUE).
    """
    if is_local_uri(source):
        if run_id:
            store = _get_tracking_store()
            run = store.get_run(run_id)
            source = pathlib.Path(local_file_uri_to_path(source)).resolve()
            if is_local_uri(run.info.artifact_uri):
                run_artifact_dir = pathlib.Path(
                    local_file_uri_to_path(run.info.artifact_uri)
                ).resolve()
                # Accept only paths at or below the run's artifact directory.
                if run_artifact_dir in [source, *source.parents]:
                    return

        raise MlflowException(
            f"Invalid model version source: '{source}'. To use a local path as a model version "
            "source, the run_id request parameter has to be specified and the local path has to be "
            "contained within the artifact directory of the run specified by the run_id.",
            INVALID_PARAMETER_VALUE,
        )

    # Checks if relative paths are present in the source (a security threat). If any are present,
    # raises an Exception.
    _validate_non_local_source_contains_relative_paths(source)


def _validate_source_model(source: str, model_id: str) -> None:
    """Validate a model version source against the given logged model.

    Mirrors ``_validate_source_run`` but anchors local sources to the logged
    model's ``artifact_location`` instead of a run's artifact directory.
    """
    if is_local_uri(source):
        if model_id:
            store = _get_tracking_store()
            model = store.get_logged_model(model_id)
            source = pathlib.Path(local_file_uri_to_path(source)).resolve()
            if is_local_uri(model.artifact_location):
                run_artifact_dir = pathlib.Path(
                    local_file_uri_to_path(model.artifact_location)
                ).resolve()
                if run_artifact_dir in [source, *source.parents]:
                    return

        raise MlflowException(
            f"Invalid model version source: '{source}'. To use a local path as a model version "
            "source, the model_id request parameter has to be specified and the local path has to "
            "be contained within the artifact directory of the run specified by the model_id.",
            INVALID_PARAMETER_VALUE,
        )

    # Checks if relative paths are present in the source (a security threat). If any are present,
    # raises an Exception.
    _validate_non_local_source_contains_relative_paths(source)


@catch_mlflow_exception
@_disable_if_artifacts_only
def _create_model_version():
    """Handle ``CreateModelVersion``: validate the source, create the version,
    and fire the matching creation webhook (prompt-version vs model-version).
    """
    request_message = _get_request_message(
        CreateModelVersion(),
        schema={
            "name": [_assert_string, _assert_required],
            "source": [_assert_string, _assert_required],
            "run_id": [_assert_string],
            "tags": [_assert_array],
            "run_link": [_assert_string],
            "description": [_assert_string],
            "model_id": [_assert_string],
        },
    )

    # Optional operator-supplied allowlist regex for acceptable sources.
    if request_message.source and (
        regex := MLFLOW_CREATE_MODEL_VERSION_SOURCE_VALIDATION_REGEX.get()
    ):
        if not re.search(regex, request_message.source):
            raise MlflowException(
                f"Invalid model version source: '{request_message.source}'.",
                error_code=INVALID_PARAMETER_VALUE,
            )

    is_prompt = _is_prompt_request(request_message)
    if is_prompt:
        # Prompt sources must not point to local filesystem paths.
        # Block file:// URIs and absolute paths (e.g. /etc/passwd) but allow
        # the legitimate schemeless placeholder sources used internally
        # (e.g. "prompt-template", "dummy-source").
        source = request_message.source
        parsed = urllib.parse.urlparse(source)
        if parsed.scheme == "file" or (parsed.scheme == "" and source.startswith("/")):
            raise MlflowException(
                f"Invalid prompt source: '{source}'. "
                "Local source paths are not allowed for prompts.",
                INVALID_PARAMETER_VALUE,
            )
        # Only validate traversal for sources with a URL scheme (http, https, etc.)
        if parsed.scheme:
            _validate_non_local_source_contains_relative_paths(source)
    else:
        if request_message.model_id:
            _validate_source_model(request_message.source, request_message.model_id)
        else:
            _validate_source_run(request_message.source, request_message.run_id)

    store = _get_model_registry_store()
    model_version = store.create_model_version(
        name=request_message.name,
        source=request_message.source,
        run_id=request_message.run_id,
        run_link=request_message.run_link,
        tags=request_message.tags,
        description=request_message.description,
        model_id=request_message.model_id,
    )
    if not is_prompt and request_message.model_id:
        # Cross-link the registry version back onto the tracking store's logged model.
        tracking_store = _get_tracking_store()
        tracking_store.set_model_versions_tags(
            name=request_message.name,
            version=model_version.version,
            model_id=request_message.model_id,
        )
    response_message = CreateModelVersion.Response(model_version=model_version.to_proto())

    if is_prompt:
        # Convert tags to dict and extract template text efficiently
        tags_dict = {t.key: t.value for t in request_message.tags}
        template_text = tags_dict.pop(PROMPT_TEXT_TAG_KEY, None)
        # Remove internal prompt identification and type tags
        tags_dict.pop(IS_PROMPT_TAG_KEY, None)
        tags_dict.pop(PROMPT_TYPE_TAG_KEY, None)

        # Send prompt version creation webhook
        deliver_webhook(
            event=WebhookEvent(WebhookEntity.PROMPT_VERSION, WebhookAction.CREATED),
            payload=PromptVersionCreatedPayload(
                name=request_message.name,
                version=str(model_version.version),
                template=template_text,
                tags=tags_dict,
                description=request_message.description or None,
            ),
            store=store,
        )
    else:
        # Send regular model version creation webhook
        deliver_webhook(
            event=WebhookEvent(WebhookEntity.MODEL_VERSION, WebhookAction.CREATED),
            payload=ModelVersionCreatedPayload(
                name=request_message.name,
                version=str(model_version.version),
                source=request_message.source,
                run_id=request_message.run_id or None,
                tags={t.key: t.value for t in request_message.tags},
                description=request_message.description or None,
            ),
            store=store,
        )

    return _wrap_response(response_message)


def _is_prompt_request(request_message):
    """Return True when the request's tags mark the entity as a prompt."""
    return any(tag.key == IS_PROMPT_TAG_KEY for tag in request_message.tags)


def _is_prompt(name: str) -> bool:
    """Return True when the registered model with ``name`` is a prompt."""
    rm = _get_model_registry_store().get_registered_model(name=name)
    return rm._is_prompt()


@catch_mlflow_exception
@_disable_if_artifacts_only
def get_model_version_artifact_handler():
    """Serve a single artifact file belonging to a model version.

    Query params: ``name``, ``version``, and ``path`` (traversal-checked).
    Resolves the version's download URI and streams the file via
    ``_send_artifact`` from either the proxied or the direct artifact repo.
    """
    name = request.args.get("name")
    version = request.args.get("version")
    path = request.args["path"]
    path = validate_path_is_safe(path)
    artifact_uri = _get_model_registry_store().get_model_version_download_uri(name, version)
    if _is_servable_proxied_run_artifact_root(artifact_uri):
        artifact_repo = _get_artifact_repo_mlflow_artifacts()
        artifact_path = _get_proxied_run_artifact_destination_path(
            proxied_artifact_root=artifact_uri,
            relative_path=path,
        )
        artifact_path = _get_workspace_scoped_repo_path_if_enabled(artifact_path)
    else:
        artifact_repo = get_artifact_repository(artifact_uri)
        artifact_path = path

    return _send_artifact(artifact_repo, artifact_path)


@catch_mlflow_exception
@_disable_if_artifacts_only
def _get_model_version():
    """Handle ``GetModelVersion``. (Definition continues below this chunk boundary.)"""
    request_message = _get_request_message(
        GetModelVersion(),
        schema={
            "name": [_assert_string, _assert_required],
            "version": [_assert_string, _assert_required],
        },
    )
    model_version = _get_model_registry_store().get_model_version(
        name=request_message.name, version=request_message.version
    )
    response_proto = model_version.to_proto()
    response_message =
@catch_mlflow_exception
@_disable_if_artifacts_only
def _update_model_version():
    """Update the description of an existing model version."""
    req = _get_request_message(
        UpdateModelVersion(),
        schema={
            "name": [_assert_string, _assert_required],
            "version": [_assert_string, _assert_required],
            "description": [_assert_string],
        },
    )
    # Distinguish "description omitted" (None) from an explicitly-set value.
    new_description = req.description if req.HasField("description") else None
    updated = _get_model_registry_store().update_model_version(
        name=req.name,
        version=req.version,
        description=new_description,
    )
    return _wrap_response(UpdateModelVersion.Response(model_version=updated.to_proto()))


@catch_mlflow_exception
@_disable_if_artifacts_only
def _transition_stage():
    """Transition a model version to a new stage, optionally archiving existing versions."""
    req = _get_request_message(
        TransitionModelVersionStage(),
        schema={
            "name": [_assert_string, _assert_required],
            "version": [_assert_string, _assert_required],
            "stage": [_assert_string, _assert_required],
            "archive_existing_versions": [_assert_bool],
        },
    )
    transitioned = _get_model_registry_store().transition_model_version_stage(
        name=req.name,
        version=req.version,
        stage=req.stage,
        archive_existing_versions=req.archive_existing_versions,
    )
    return _wrap_response(
        TransitionModelVersionStage.Response(model_version=transitioned.to_proto())
    )


@catch_mlflow_exception
@_disable_if_artifacts_only
def _delete_model_version():
    """Delete a single model version from the registry."""
    req = _get_request_message(
        DeleteModelVersion(),
        schema={
            "name": [_assert_string, _assert_required],
            "version": [_assert_string, _assert_required],
        },
    )
    _get_model_registry_store().delete_model_version(name=req.name, version=req.version)
    return _wrap_response(DeleteModelVersion.Response())


@catch_mlflow_exception
@_disable_if_artifacts_only
def _get_model_version_download_uri():
    """Return the artifact download URI for a model version."""
    req = _get_request_message(GetModelVersionDownloadUri())
    uri = _get_model_registry_store().get_model_version_download_uri(
        name=req.name, version=req.version
    )
    return _wrap_response(GetModelVersionDownloadUri.Response(artifact_uri=uri))


@catch_mlflow_exception
@_disable_if_artifacts_only
def _search_model_versions():
    """Search model versions with filtering, ordering, and pagination."""
    req = _get_request_message(
        SearchModelVersions(),
        schema={
            "filter": [_assert_string],
            "max_results": [
                _assert_intlike,
                lambda x: _assert_less_than_or_equal(int(x), 200_000),
            ],
            "order_by": [_assert_array, _assert_item_type_string],
            "page_token": [_assert_string],
        },
    )
    return _wrap_response(search_model_versions_impl(req))


def search_model_versions_impl(request_message):
    """Run the model-version search against the registry store and build the proto response."""
    store = _get_model_registry_store()
    results = store.search_model_versions(
        filter_string=request_message.filter,
        max_results=request_message.max_results,
        order_by=request_message.order_by,
        page_token=request_message.page_token or None,
    )
    response_message = SearchModelVersions.Response()
    response_message.model_versions.extend(mv.to_proto() for mv in results)
    if results.token:
        response_message.next_page_token = results.token
    return response_message


@catch_mlflow_exception
@_disable_if_artifacts_only
def _set_model_version_tag():
    """Set a tag on a model version, then notify webhook subscribers."""
    req = _get_request_message(
        SetModelVersionTag(),
        schema={
            "name": [_assert_string, _assert_required],
            "version": [_assert_string, _assert_required],
            "key": [_assert_string, _assert_required],
            "value": [_assert_string],
        },
    )
    store = _get_model_registry_store()
    store.set_model_version_tag(
        name=req.name,
        version=req.version,
        tag=ModelVersionTag(key=req.key, value=req.value),
    )

    # Prompts and regular model versions emit distinct webhook entities so
    # subscribers can tell them apart.
    if _is_prompt(req.name):
        event = WebhookEvent(WebhookEntity.PROMPT_VERSION_TAG, WebhookAction.SET)
        payload = PromptVersionTagSetPayload(
            name=req.name,
            version=req.version,
            key=req.key,
            value=req.value,
        )
    else:
        event = WebhookEvent(WebhookEntity.MODEL_VERSION_TAG, WebhookAction.SET)
        payload = ModelVersionTagSetPayload(
            name=req.name,
            version=req.version,
            key=req.key,
            value=req.value,
        )
    deliver_webhook(event=event, payload=payload, store=store)

    return _wrap_response(SetModelVersionTag.Response())
@catch_mlflow_exception
@_disable_if_artifacts_only
def _delete_model_version_tag():
    """Delete a tag from a model version, then notify webhook subscribers."""
    request_message = _get_request_message(
        DeleteModelVersionTag(),
        schema={
            "name": [_assert_string, _assert_required],
            "version": [_assert_string, _assert_required],
            "key": [_assert_string, _assert_required],
        },
    )
    store = _get_model_registry_store()
    store.delete_model_version_tag(
        name=request_message.name,
        version=request_message.version,
        key=request_message.key,
    )

    # Prompts and regular model versions use distinct webhook entities.
    if _is_prompt(request_message.name):
        # Send prompt version tag deleted webhook
        deliver_webhook(
            event=WebhookEvent(WebhookEntity.PROMPT_VERSION_TAG, WebhookAction.DELETED),
            payload=PromptVersionTagDeletedPayload(
                name=request_message.name,
                version=request_message.version,
                key=request_message.key,
            ),
            store=store,
        )
    else:
        # Send regular model version tag deleted webhook
        deliver_webhook(
            event=WebhookEvent(WebhookEntity.MODEL_VERSION_TAG, WebhookAction.DELETED),
            payload=ModelVersionTagDeletedPayload(
                name=request_message.name,
                version=request_message.version,
                key=request_message.key,
            ),
            store=store,
        )

    return _wrap_response(DeleteModelVersionTag.Response())


@catch_mlflow_exception
@_disable_if_artifacts_only
def _set_registered_model_alias():
    """Assign an alias to a specific model version, then notify webhook subscribers."""
    request_message = _get_request_message(
        SetRegisteredModelAlias(),
        schema={
            "name": [_assert_string, _assert_required],
            "alias": [_assert_string, _assert_required],
            "version": [_assert_string, _assert_required],
        },
    )
    store = _get_model_registry_store()
    store.set_registered_model_alias(
        name=request_message.name,
        alias=request_message.alias,
        version=request_message.version,
    )

    if _is_prompt(request_message.name):
        # Send prompt alias created webhook
        deliver_webhook(
            event=WebhookEvent(WebhookEntity.PROMPT_ALIAS, WebhookAction.CREATED),
            payload=PromptAliasCreatedPayload(
                name=request_message.name,
                alias=request_message.alias,
                version=request_message.version,
            ),
            store=store,
        )
    else:
        # Send regular model version alias created webhook
        deliver_webhook(
            event=WebhookEvent(WebhookEntity.MODEL_VERSION_ALIAS, WebhookAction.CREATED),
            payload=ModelVersionAliasCreatedPayload(
                name=request_message.name,
                alias=request_message.alias,
                version=request_message.version,
            ),
            store=store,
        )

    return _wrap_response(SetRegisteredModelAlias.Response())


@catch_mlflow_exception
@_disable_if_artifacts_only
def _delete_registered_model_alias():
    """Remove an alias from a registered model, then notify webhook subscribers."""
    request_message = _get_request_message(
        DeleteRegisteredModelAlias(),
        schema={
            "name": [_assert_string, _assert_required],
            "alias": [_assert_string, _assert_required],
        },
    )
    store = _get_model_registry_store()
    store.delete_registered_model_alias(name=request_message.name, alias=request_message.alias)

    if _is_prompt(request_message.name):
        # Send prompt alias deleted webhook
        deliver_webhook(
            event=WebhookEvent(WebhookEntity.PROMPT_ALIAS, WebhookAction.DELETED),
            payload=PromptAliasDeletedPayload(
                name=request_message.name,
                alias=request_message.alias,
            ),
            store=store,
        )
    else:
        # Send regular model version alias deleted webhook
        deliver_webhook(
            event=WebhookEvent(WebhookEntity.MODEL_VERSION_ALIAS, WebhookAction.DELETED),
            payload=ModelVersionAliasDeletedPayload(
                name=request_message.name,
                alias=request_message.alias,
            ),
            store=store,
        )

    return _wrap_response(DeleteRegisteredModelAlias.Response())
@catch_mlflow_exception
@_disable_if_artifacts_only
def _get_model_version_by_alias():
    """Resolve a registered-model alias to the model version it points at."""
    request_message = _get_request_message(
        GetModelVersionByAlias(),
        schema={
            "name": [_assert_string, _assert_required],
            "alias": [_assert_string, _assert_required],
        },
    )
    model_version = _get_model_registry_store().get_model_version_by_alias(
        name=request_message.name, alias=request_message.alias
    )
    response_proto = model_version.to_proto()
    response_message = GetModelVersionByAlias.Response(model_version=response_proto)
    return _wrap_response(response_message)


# Webhook APIs
@catch_mlflow_exception
@_disable_if_artifacts_only
def _create_webhook():
    """Register a new webhook (name, target URL, subscribed events)."""
    request_message = _get_request_message(
        CreateWebhook(),
        schema={
            "name": [_assert_string, _assert_required],
            "url": [_assert_string, _assert_required],
            "events": [_assert_array, _assert_required],
            "description": [_assert_string],
            "secret": [_assert_string],
            "status": [_assert_string],
        },
    )

    # `or None` collapses proto empty-string defaults to "not provided".
    webhook = _get_model_registry_store().create_webhook(
        name=request_message.name,
        url=request_message.url,
        events=[WebhookEvent.from_proto(e) for e in request_message.events],
        description=request_message.description or None,
        secret=request_message.secret or None,
        status=WebhookStatus.from_proto(request_message.status) if request_message.status else None,
    )
    response_message = CreateWebhook.Response(webhook=webhook.to_proto())
    return _wrap_response(response_message)


@catch_mlflow_exception
@_disable_if_artifacts_only
def _list_webhooks():
    """List registered webhooks with pagination."""
    request_message = _get_request_message(
        ListWebhooks(),
        schema={
            "max_results": [_assert_intlike],
            "page_token": [_assert_string],
        },
    )
    webhooks_page = _get_model_registry_store().list_webhooks(
        max_results=request_message.max_results,
        page_token=request_message.page_token or None,
    )
    response_message = ListWebhooks.Response(
        webhooks=[w.to_proto() for w in webhooks_page],
        next_page_token=webhooks_page.token,
    )
    return _wrap_response(response_message)


@catch_mlflow_exception
@_disable_if_artifacts_only
def _get_webhook(webhook_id: str):
    """Fetch a single webhook by id."""
    webhook = _get_model_registry_store().get_webhook(webhook_id=webhook_id)
    response_message = GetWebhook.Response(webhook=webhook.to_proto())
    return _wrap_response(response_message)


@catch_mlflow_exception
@_disable_if_artifacts_only
def _update_webhook(webhook_id: str):
    """Partially update a webhook; unset fields are passed as None (no change)."""
    request_message = _get_request_message(
        UpdateWebhook(),
        schema={
            "name": [_assert_string],
            "description": [_assert_string],
            "url": [_assert_string],
            "events": [_assert_array],
            "secret": [_assert_string],
            "status": [_assert_string],
        },
    )
    webhook = _get_model_registry_store().update_webhook(
        webhook_id=webhook_id,
        name=request_message.name or None,
        description=request_message.description or None,
        url=request_message.url or None,
        events=(
            [WebhookEvent.from_proto(e) for e in request_message.events]
            if request_message.events
            else None
        ),
        secret=request_message.secret or None,
        status=WebhookStatus.from_proto(request_message.status) if request_message.status else None,
    )
    response_message = UpdateWebhook.Response(webhook=webhook.to_proto())
    return _wrap_response(response_message)
@catch_mlflow_exception
@_disable_if_artifacts_only
def _delete_webhook(webhook_id: str):
    """Delete a webhook by id."""
    _get_model_registry_store().delete_webhook(webhook_id=webhook_id)
    response_message = DeleteWebhook.Response()
    return _wrap_response(response_message)


@catch_mlflow_exception
@_disable_if_artifacts_only
def _test_webhook(webhook_id: str):
    """Fire a test delivery for a webhook, optionally for a specific event."""
    request_message = _get_request_message(TestWebhook())
    event = (
        WebhookEvent.from_proto(request_message.event)
        if request_message.HasField("event")
        else None
    )
    store = _get_model_registry_store()
    webhook = store.get_webhook(webhook_id=webhook_id)
    test_result = test_webhook(webhook=webhook, event=event)
    response_message = TestWebhook.Response(result=test_result.to_proto())
    return _wrap_response(response_message)


# MLflow Artifacts APIs


def _get_workspace_scoped_repo_path_if_enabled(artifact_path: str | None) -> str | None:
    """
    Normalize artifact paths for proxied (served) artifacts so they remain workspace-isolated.

    When ``mlflow-artifacts`` proxying is enabled and workspaces are on, every path under the HTTP
    artifact endpoint must be rooted at ``workspaces/<workspace>/...``. Direct artifact repositories
    (e.g., S3, GCS, local URIs) already encode their own isolation, so they bypass this logic by
    calling the underlying store directly. Only the proxied repos need to be rewritten/validated
    here.

    Returns:
        The workspace-scoped path. May return the original path in the following cases:
        - Workspaces are disabled (returns ``artifact_path`` unchanged).
        - Default workspace with no path (returns ``artifact_path`` unchanged to preserve legacy
        root behavior, where artifacts live at the root rather than under ``workspaces/default``).
        For non-default workspaces, always returns a string (``workspaces/<workspace>/...``).
    """
    if not MLFLOW_ENABLE_WORKSPACES.get():
        return artifact_path

    workspace = workspace_context.get_request_workspace()
    if not workspace:
        raise MlflowException.invalid_parameter_value(
            "Active workspace is required for artifact operations. "
            "Ensure X-MLFLOW-WORKSPACE is set or call mlflow.set_workspace()."
        )

    # Strip leading slashes so joins below cannot escape the workspace root.
    normalized = artifact_path.lstrip("/") if artifact_path else ""
    base = posixpath.join("workspaces", workspace)

    if not normalized:
        # For the default workspace, preserve the legacy root behavior (no prefix),
        # so root operations continue to see the existing layout.
        return base if workspace != DEFAULT_WORKSPACE_NAME else artifact_path

    if workspace == DEFAULT_WORKSPACE_NAME and not normalized.startswith("workspaces/"):
        # Legacy default-workspace artifacts never had the workspace prefix; allow them to be served
        # without rewriting as long as the path isn't trying to opt into the reserved namespace.
        return artifact_path
3249 return artifact_path 3250 3251 leading_segments = normalized.split("/", 2) 3252 if leading_segments and leading_segments[0] == "workspaces": 3253 if len(leading_segments) == 1 or not leading_segments[1]: 3254 raise MlflowException.invalid_parameter_value( 3255 "Artifact paths prefixed with 'workspaces/' must include a workspace name." 3256 ) 3257 requested_workspace = leading_segments[1] 3258 if requested_workspace != workspace: 3259 raise MlflowException.invalid_parameter_value( 3260 f"Artifact path targets workspace '{requested_workspace}' " 3261 f"but the workspace specified in the request is '{workspace}'." 3262 ) 3263 return normalized 3264 3265 return posixpath.join(base, normalized) 3266 3267 3268 @catch_mlflow_exception 3269 @_disable_unless_serve_artifacts 3270 def _download_artifact(artifact_path): 3271 """ 3272 A request handler for `GET /mlflow-artifacts/artifacts/<artifact_path>` to download an artifact 3273 from `artifact_path` (a relative path from the root artifact directory). 
3274 """ 3275 artifact_path = validate_path_is_safe(artifact_path) 3276 artifact_path = _get_workspace_scoped_repo_path_if_enabled(artifact_path) 3277 tmp_dir = tempfile.TemporaryDirectory() 3278 artifact_repo = _get_artifact_repo_mlflow_artifacts() 3279 dst = artifact_repo.download_artifacts(artifact_path, tmp_dir.name) 3280 3281 # Ref: https://stackoverflow.com/a/24613980/6943581 3282 file_handle = open(dst, "rb") # noqa: SIM115 3283 3284 def stream_and_remove_file(): 3285 while chunk := file_handle.read(ARTIFACT_STREAM_CHUNK_SIZE): 3286 yield chunk 3287 file_handle.close() 3288 tmp_dir.cleanup() 3289 3290 file_sender_response = current_app.response_class(stream_and_remove_file()) 3291 3292 return _response_with_file_attachment_headers(artifact_path, file_sender_response) 3293 3294 3295 @catch_mlflow_exception 3296 @_disable_unless_serve_artifacts 3297 def _upload_artifact(artifact_path): 3298 """ 3299 A request handler for `PUT /mlflow-artifacts/artifacts/<artifact_path>` to upload an artifact 3300 to `artifact_path` (a relative path from the root artifact directory). 3301 """ 3302 artifact_path = validate_path_is_safe(artifact_path) 3303 artifact_path = _get_workspace_scoped_repo_path_if_enabled(artifact_path) 3304 head, tail = posixpath.split(artifact_path) 3305 with tempfile.TemporaryDirectory() as tmp_dir: 3306 tmp_path = os.path.join(tmp_dir, tail) 3307 with open(tmp_path, "wb") as f: 3308 while chunk := request.stream.read(ARTIFACT_STREAM_CHUNK_SIZE): 3309 f.write(chunk) 3310 3311 artifact_repo = _get_artifact_repo_mlflow_artifacts() 3312 artifact_repo.log_artifact(tmp_path, artifact_path=head or None) 3313 3314 return _wrap_response(UploadArtifact.Response()) 3315 3316 3317 @catch_mlflow_exception 3318 @_disable_unless_serve_artifacts 3319 def _list_artifacts_mlflow_artifacts(): 3320 """ 3321 A request handler for `GET /mlflow-artifacts/artifacts?path=<value>` to list artifacts in `path` 3322 (a relative path from the root artifact directory). 
@catch_mlflow_exception
@_disable_unless_serve_artifacts
def _delete_artifact_mlflow_artifacts(artifact_path):
    """
    A request handler for `DELETE /mlflow-artifacts/artifacts?path=<value>` to delete artifacts in
    `path` (a relative path from the root artifact directory).
    """
    artifact_path = validate_path_is_safe(artifact_path)
    artifact_path = _get_workspace_scoped_repo_path_if_enabled(artifact_path)
    _get_request_message(DeleteArtifact())
    artifact_repo = _get_artifact_repo_mlflow_artifacts()
    artifact_repo.delete_artifacts(artifact_path)
    response_message = DeleteArtifact.Response()
    response = Response(mimetype="application/json")
    response.set_data(message_to_json(response_message))
    return response


def _get_graphql_auth_middleware():
    """
    Get GraphQL authorization middleware if basic-auth is enabled.

    Returns:
        A list of middleware instances if auth is enabled, empty list otherwise.
    """
    try:
        from mlflow.server.auth import get_graphql_authorization_middleware

        return get_graphql_authorization_middleware()
    except Exception:
        # Auth not configured or other error
        # NOTE(review): broad catch also swallows genuine errors inside the auth
        # module (not just "auth not installed") and silently disables the
        # middleware — consider logging here; confirm intended behavior.
        return []


@catch_mlflow_exception
def _graphql():
    """Execute a GraphQL query against the server schema after a safety check."""
    from graphql import parse

    from mlflow.server.graphql.graphql_no_batching import check_query_safety
    from mlflow.server.graphql.graphql_schema_extensions import schema

    # Extracting the query, variables, and operationName from the request
    request_json = _get_request_json()
    query = request_json.get("query")
    variables = request_json.get("variables")
    operation_name = request_json.get("operationName")

    # Reject unsafe queries (e.g. batching abuse) before execution.
    node = parse(query)
    if check_result := check_query_safety(node):
        result = check_result
    else:
        # Get auth middleware if basic-auth is enabled
        middleware = _get_graphql_auth_middleware()

        # Executing the GraphQL query using the Graphene schema
        result = schema.execute(
            query,
            variables=variables,
            operation_name=operation_name,
            middleware=middleware,
        )

    # Convert execution result into json.
    result_data = {
        "data": result.data,
        "errors": [error.message for error in result.errors] if result.errors else None,
    }

    # Return the response
    return jsonify(result_data)
def _validate_support_multipart_upload(artifact_repo):
    # Multipart upload is only available on repos implementing the mixin.
    if not isinstance(artifact_repo, MultipartUploadMixin):
        raise _UnsupportedMultipartUploadException()


def _validate_support_multipart_download(artifact_repo):
    # Multipart/presigned download is only available on repos implementing the mixin.
    if not isinstance(artifact_repo, MultipartDownloadMixin):
        raise _UnsupportedMultipartDownloadException()


def _validate_support_presigned_upload(artifact_repo):
    # Presigned upload is only available on repos implementing the mixin.
    if not isinstance(artifact_repo, PresignedUploadMixin):
        raise _UnsupportedPresignedUploadException()


@catch_mlflow_exception
@_disable_if_artifacts_only
def _create_presigned_upload_url():
    """
    Handler for POST /api/2.0/mlflow/artifacts/presigned-upload-url.
    Generates a presigned URL for uploading an artifact directly to cloud storage.

    Client reference: https://github.com/aws/sagemaker-mlflow
    """
    request_message = _get_request_message(
        CreatePresignedUploadUrl(),
        schema={
            "run_id": [_assert_required, _assert_string],
            "path": [_assert_required, _assert_string],
            "expiration": [_assert_intlike],
        },
    )
    run_id = request_message.run_id
    path = validate_path_is_safe(request_message.path)
    # Default expiration is 900 seconds when the field is unset.
    expiration = request_message.expiration if request_message.HasField("expiration") else 900

    run = _get_tracking_store().get_run(run_id)
    artifact_uri = run.info.artifact_uri
    artifact_uri_scheme = urllib.parse.urlparse(artifact_uri).scheme
    # Proxied artifact storage has no direct cloud URI to presign against.
    if artifact_uri_scheme in ("http", "https", "mlflow-artifacts"):
        raise MlflowException(
            "Presigned upload is not supported for runs with proxied artifact storage "
            f"(artifact URI scheme: {artifact_uri_scheme}). "
            "This endpoint requires a run with a direct cloud storage artifact URI.",
            error_code=INVALID_PARAMETER_VALUE,
        )
    artifact_repo = _get_artifact_repo(run)
    _validate_support_presigned_upload(artifact_repo)

    response = artifact_repo.create_presigned_upload_url(path, expiration=expiration)
    response_message = response.to_proto()
    resp = Response(mimetype="application/json")
    resp.set_data(message_to_json(response_message))
    return resp


@catch_mlflow_exception
@_disable_unless_serve_artifacts
def _create_multipart_upload_artifact(artifact_path):
    """
    A request handler for `POST /mlflow-artifacts/mpu/create` to create a multipart upload
    to `artifact_path` (a relative path from the root artifact directory).
    """
    artifact_path = validate_path_is_safe(artifact_path)
    artifact_path = _get_workspace_scoped_repo_path_if_enabled(artifact_path)

    request_message = _get_request_message(
        CreateMultipartUpload(),
        schema={
            "path": [_assert_required, _assert_string],
            "num_parts": [_assert_intlike],
        },
    )
    path = request_message.path
    num_parts = request_message.num_parts

    artifact_repo = _get_artifact_repo_mlflow_artifacts()
    _validate_support_multipart_upload(artifact_repo)

    create_response = artifact_repo.create_multipart_upload(
        path,
        num_parts,
        artifact_path,
    )
    response_message = create_response.to_proto()
    response = Response(mimetype="application/json")
    response.set_data(message_to_json(response_message))
    return response


@catch_mlflow_exception
@_disable_unless_serve_artifacts
def _complete_multipart_upload_artifact(artifact_path):
    """
    A request handler for `POST /mlflow-artifacts/mpu/complete` to complete a multipart upload
    to `artifact_path` (a relative path from the root artifact directory).
    """
    artifact_path = validate_path_is_safe(artifact_path)
    artifact_path = _get_workspace_scoped_repo_path_if_enabled(artifact_path)

    request_message = _get_request_message(
        CompleteMultipartUpload(),
        schema={
            "path": [_assert_required, _assert_string],
            "upload_id": [_assert_string],
            "parts": [_assert_required],
        },
    )
    path = request_message.path
    upload_id = request_message.upload_id
    parts = [MultipartUploadPart.from_proto(part) for part in request_message.parts]

    artifact_repo = _get_artifact_repo_mlflow_artifacts()
    _validate_support_multipart_upload(artifact_repo)

    artifact_repo.complete_multipart_upload(
        path,
        upload_id,
        parts,
        artifact_path,
    )
    return _wrap_response(CompleteMultipartUpload.Response())
@catch_mlflow_exception
@_disable_unless_serve_artifacts
def _abort_multipart_upload_artifact(artifact_path):
    """
    A request handler for `POST /mlflow-artifacts/mpu/abort` to abort a multipart upload
    to `artifact_path` (a relative path from the root artifact directory).
    """
    artifact_path = validate_path_is_safe(artifact_path)
    artifact_path = _get_workspace_scoped_repo_path_if_enabled(artifact_path)

    request_message = _get_request_message(
        AbortMultipartUpload(),
        schema={
            "path": [_assert_required, _assert_string],
            "upload_id": [_assert_string],
        },
    )
    path = request_message.path
    upload_id = request_message.upload_id

    artifact_repo = _get_artifact_repo_mlflow_artifacts()
    _validate_support_multipart_upload(artifact_repo)

    artifact_repo.abort_multipart_upload(
        path,
        upload_id,
        artifact_path,
    )
    return _wrap_response(AbortMultipartUpload.Response())


@catch_mlflow_exception
@_disable_unless_serve_artifacts
def _get_presigned_download_url(artifact_path):
    """
    A request handler for `GET /mlflow-artifacts/presigned/<artifact_path>` to get
    a presigned URL for downloading an artifact directly from cloud storage.
    """
    artifact_path = validate_path_is_safe(artifact_path)

    artifact_repo = _get_artifact_repo_mlflow_artifacts()
    _validate_support_multipart_download(artifact_repo)

    # URL lifetime is controlled by server configuration.
    expiration = MLFLOW_PRESIGNED_DOWNLOAD_URL_TTL_SECONDS.get()
    presigned_response = artifact_repo.get_download_presigned_url(
        artifact_path, expiration=expiration
    )
    response = Response(mimetype="application/json")
    response.set_data(json.dumps(presigned_response.to_dict()))
    return response


# MLflow Tracing APIs


@catch_mlflow_exception
@_disable_if_artifacts_only
def _start_trace_v3():
    """
    A request handler for `POST /mlflow/traces` to create a new TraceInfo record in tracking store.
    """
    request_message = _get_request_message(
        StartTraceV3(),
        schema={"trace": [_assert_required]},
    )
    trace_info = TraceInfo.from_proto(request_message.trace.trace_info)
    trace_info = _get_tracking_store().start_trace(trace_info)
    response_message = StartTraceV3.Response(trace=ProtoTrace(trace_info=trace_info.to_proto()))
    return _wrap_response(response_message)
@catch_mlflow_exception
@_disable_if_artifacts_only
def _get_trace_info_v3(trace_id):
    """
    A request handler for `GET /mlflow/traces/{trace_id}` to retrieve
    an existing TraceInfo record from tracking store.
    """
    info = _get_tracking_store().get_trace_info(trace_id)
    return _wrap_response(
        GetTraceInfoV3.Response(trace=ProtoTrace(trace_info=info.to_proto()))
    )


@catch_mlflow_exception
@_disable_if_artifacts_only
def _batch_get_traces() -> Response:
    """
    A request handler for `GET /mlflow/traces/batchGet` to retrieve
    a batch of complete traces with spans for given trace ids.
    """
    req = _get_request_message(
        BatchGetTraces(),
        schema={"trace_ids": [_assert_array, _assert_required, _assert_item_type_string]},
    )
    resp = BatchGetTraces.Response()
    for trace in _get_tracking_store().batch_get_traces(req.trace_ids, None):
        resp.traces.append(trace.to_proto())
    return _wrap_response(resp)


@catch_mlflow_exception
@_disable_if_artifacts_only
def _batch_get_trace_infos() -> Response:
    """Retrieve TraceInfo records for a batch of trace ids."""
    req = _get_request_message(
        BatchGetTraceInfos(),
        schema={"trace_ids": [_assert_array, _assert_required, _assert_item_type_string]},
    )
    resp = BatchGetTraceInfos.Response()
    for info in _get_tracking_store().batch_get_trace_infos(req.trace_ids):
        resp.trace_infos.append(info.to_proto())
    return _wrap_response(resp)


@catch_mlflow_exception
@_disable_if_artifacts_only
def _get_trace() -> Response:
    """
    A request handler for `GET /mlflow/traces/get` to get a trace with spans for given trace id.
    """
    req = _get_request_message(
        GetTrace(),
        schema={
            "trace_id": [_assert_string, _assert_required],
            "allow_partial": [_assert_bool],
        },
    )
    fetched = _get_tracking_store().get_trace(req.trace_id, allow_partial=req.allow_partial)
    return _wrap_response(GetTrace.Response(trace=fetched.to_proto()))


@catch_mlflow_exception
@_disable_if_artifacts_only
def _search_traces_v3():
    """
    A request handler for `GET /mlflow/traces` to search for TraceInfo records in tracking store.
    """
    req = _get_request_message(
        SearchTracesV3(),
        schema={
            "locations": [_assert_array, _assert_required],
            "filter": [_assert_string],
            "max_results": [
                _assert_intlike,
                lambda x: _assert_less_than_or_equal(int(x), 500),
            ],
            "order_by": [_assert_array, _assert_item_type_string],
            "page_token": [_assert_string],
        },
    )
    # Only experiment-backed locations are forwarded to the store.
    experiment_ids = [
        loc.mlflow_experiment.experiment_id
        for loc in req.locations
        if loc.HasField("mlflow_experiment")
    ]

    traces, token = _get_tracking_store().search_traces(
        locations=experiment_ids,
        filter_string=req.filter,
        max_results=req.max_results,
        order_by=req.order_by,
        page_token=req.page_token or None,
    )
    resp = SearchTracesV3.Response()
    resp.traces.extend(t.to_proto() for t in traces)
    if token:
        resp.next_page_token = token
    return _wrap_response(resp)
3674 """ 3675 request_message = _get_request_message( 3676 SearchTracesV3(), 3677 schema={ 3678 "locations": [_assert_array, _assert_required], 3679 "filter": [_assert_string], 3680 "max_results": [ 3681 _assert_intlike, 3682 lambda x: _assert_less_than_or_equal(int(x), 500), 3683 ], 3684 "order_by": [_assert_array, _assert_item_type_string], 3685 "page_token": [_assert_string], 3686 }, 3687 ) 3688 experiment_ids = [ 3689 location.mlflow_experiment.experiment_id 3690 for location in request_message.locations 3691 if location.HasField("mlflow_experiment") 3692 ] 3693 3694 traces, token = _get_tracking_store().search_traces( 3695 locations=experiment_ids, 3696 filter_string=request_message.filter, 3697 max_results=request_message.max_results, 3698 order_by=request_message.order_by, 3699 page_token=request_message.page_token or None, 3700 ) 3701 response_message = SearchTracesV3.Response() 3702 response_message.traces.extend([e.to_proto() for e in traces]) 3703 if token: 3704 response_message.next_page_token = token 3705 return _wrap_response(response_message) 3706 3707 3708 @catch_mlflow_exception 3709 @_disable_if_artifacts_only 3710 def _delete_traces(): 3711 """ 3712 A request handler for `POST /mlflow/traces/delete-traces` to delete TraceInfo records 3713 from tracking store. 3714 """ 3715 request_message = _get_request_message( 3716 DeleteTraces(), 3717 schema={ 3718 "experiment_id": [_assert_string, _assert_required], 3719 "max_timestamp_millis": [_assert_intlike], 3720 "max_traces": [_assert_intlike], 3721 "request_ids": [_assert_array, _assert_item_type_string], 3722 }, 3723 ) 3724 3725 # NB: Interestingly, the field accessor for the message object returns the default 3726 # value for optional field if it's not set. For example, `request_message.max_traces` 3727 # returns 0 if max_traces is not specified in the request. This is not desirable, 3728 # because null and 0 means completely opposite i.e. 
the former is 'delete nothing' 3729 # while the latter is 'delete all'. To handle this, we need to explicitly check 3730 # if the field is set or not using `HasField` method and return None if not. 3731 def _get_nullable_field(field): 3732 if request_message.HasField(field): 3733 return getattr(request_message, field) 3734 return None 3735 3736 traces_deleted = _get_tracking_store().delete_traces( 3737 experiment_id=request_message.experiment_id, 3738 max_timestamp_millis=_get_nullable_field("max_timestamp_millis"), 3739 max_traces=_get_nullable_field("max_traces"), 3740 trace_ids=request_message.request_ids, 3741 ) 3742 return _wrap_response(DeleteTraces.Response(traces_deleted=traces_deleted)) 3743 3744 3745 @catch_mlflow_exception 3746 @_disable_if_artifacts_only 3747 def _calculate_trace_filter_correlation(): 3748 """ 3749 A request handler for `POST /mlflow/traces/calculate-filter-correlation` to calculate 3750 NPMI correlation between two trace filter conditions. 3751 """ 3752 request_message = _get_request_message( 3753 CalculateTraceFilterCorrelation(), 3754 schema={ 3755 "experiment_ids": [_assert_array, _assert_required, _assert_item_type_string], 3756 "filter_string1": [_assert_string, _assert_required], 3757 "filter_string2": [_assert_string, _assert_required], 3758 "base_filter": [_assert_string], 3759 }, 3760 ) 3761 3762 result = _get_tracking_store().calculate_trace_filter_correlation( 3763 experiment_ids=request_message.experiment_ids, 3764 filter_string1=request_message.filter_string1, 3765 filter_string2=request_message.filter_string2, 3766 base_filter=request_message.base_filter 3767 if request_message.HasField("base_filter") 3768 else None, 3769 ) 3770 3771 return _wrap_response(result.to_proto()) 3772 3773 3774 @catch_mlflow_exception 3775 @_disable_if_artifacts_only 3776 def _set_trace_tag(request_id): 3777 """ 3778 A request handler for `PATCH /mlflow/traces/{request_id}/tags` to set tags on a TraceInfo record 3779 """ 3780 request_message = 
_get_request_message(
        SetTraceTag(),
        schema={
            "key": [_assert_string, _assert_required],
            "value": [_assert_string],
        },
    )
    _get_tracking_store().set_trace_tag(request_id, request_message.key, request_message.value)
    return _wrap_response(SetTraceTag.Response())


@catch_mlflow_exception
@_disable_if_artifacts_only
def _set_trace_tag_v3(trace_id):
    """
    A request handler for `PATCH /mlflow/traces/{trace_id}/tags` to set tags on a TraceInfo record.
    Identical to `_set_trace_tag`, but with request_id renamed to with trace_id.
    """
    request_message = _get_request_message(
        SetTraceTagV3(),
        schema={
            "key": [_assert_string, _assert_required],
            "value": [_assert_string],
        },
    )
    _get_tracking_store().set_trace_tag(trace_id, request_message.key, request_message.value)
    return _wrap_response(SetTraceTagV3.Response())


@catch_mlflow_exception
@_disable_if_artifacts_only
def _delete_trace_tag(request_id):
    """
    A request handler for `DELETE /mlflow/traces/{request_id}/tags` to delete tags from a TraceInfo
    record.
    """
    request_message = _get_request_message(
        DeleteTraceTag(),
        schema={
            "key": [_assert_string, _assert_required],
        },
    )
    _get_tracking_store().delete_trace_tag(request_id, request_message.key)
    return _wrap_response(DeleteTraceTag.Response())


@catch_mlflow_exception
@_disable_if_artifacts_only
def _delete_trace_tag_v3(trace_id):
    """
    A request handler for `DELETE /mlflow/traces/{trace_id}/tags` to delete tags
    from a TraceInfo record.
    Identical to `_delete_trace_tag`, but with request_id renamed to with trace_id.
    """
    request_message = _get_request_message(
        DeleteTraceTagV3(),
        schema={
            "key": [_assert_string, _assert_required],
        },
    )
    _get_tracking_store().delete_trace_tag(trace_id, request_message.key)
    return _wrap_response(DeleteTraceTagV3.Response())


@catch_mlflow_exception
@_disable_if_artifacts_only
def _link_traces_to_run():
    """
    A request handler for `POST /mlflow/traces/link-to-run` to link traces to a run.
    """
    request_message = _get_request_message(
        LinkTracesToRun(),
        schema={
            "trace_ids": [_assert_array, _assert_required, _assert_item_type_string],
            "run_id": [_assert_string, _assert_required],
        },
    )
    _get_tracking_store().link_traces_to_run(
        trace_ids=request_message.trace_ids,
        run_id=request_message.run_id,
    )
    return _wrap_response(LinkTracesToRun.Response())


@catch_mlflow_exception
@_disable_if_artifacts_only
def _link_prompts_to_trace():
    """
    A request handler for `POST /mlflow/traces/link-prompts` to link prompt versions to a trace.
    """
    # Imported lazily to avoid a module-import cycle with the model registry entities.
    from mlflow.entities.model_registry import PromptVersion

    request_message = _get_request_message(
        LinkPromptsToTrace(),
        schema={
            "trace_id": [_assert_string, _assert_required],
            "prompt_versions": [_assert_array, _assert_required],
        },
    )

    # Convert PromptVersionRef proto messages to PromptVersion objects
    # It doesn't load prompt versions since name and version are sufficient for linking
    prompt_versions = [
        PromptVersion(name=pv.name, version=int(pv.version), template="")
        for pv in request_message.prompt_versions
    ]

    _get_tracking_store().link_prompts_to_trace(
        trace_id=request_message.trace_id,
        prompt_versions=prompt_versions,
    )
    return _wrap_response(LinkPromptsToTrace.Response())


def _fetch_trace_data_from_store(
    store: AbstractTrackingStore, request_id: str
) -> dict[str, Any] | None:
    """
    Fetch the trace data (spans) for `request_id` from the tracking store.

    Returns the trace data as a dict, or None when the store cannot serve it so
    the caller can fall back to the artifact repository.
    """
    try:
        # allow partial so the frontend can render in-progress traces
        trace = store.get_trace(request_id, allow_partial=True)
        return trace.data.to_dict()
    except MlflowTracingException:
        return None
    except MlflowNotImplementedException:
        # fallback to batch_get_traces if get_trace is not implemented
        pass

    try:
        traces = store.batch_get_traces([request_id], None)
        match traces:
            case [trace]:
                return trace.data.to_dict()
            case _:
                # The store supports batch get but the trace genuinely isn't there.
                raise MlflowException(
                    f"Trace with id={request_id} not found.",
                    error_code=RESOURCE_DOES_NOT_EXIST,
                )
    # For stores that don't support batch get traces, or if trace data is not in the store,
    # return None to signal fallback to artifact repository
    except (MlflowTracingException, MlflowNotImplementedException):
        return None


@catch_mlflow_exception
@_disable_if_artifacts_only
def get_trace_artifact_handler() -> Response:
    request_id = request.args.get("request_id")
    path =
request.args.get("path")

    if not request_id:
        raise MlflowException(
            'Request must include the "request_id" query parameter.',
            error_code=BAD_REQUEST,
        )

    store = _get_tracking_store()

    # With a `path` query parameter, stream a single trace attachment;
    # otherwise stream the whole trace data JSON document.
    if path:
        # Reject path traversal attempts before touching storage.
        path = validate_path_is_safe(path)
        trace_info = store.get_trace_info(request_id)
        if trace_info is None:
            raise MlflowException(
                f"Trace with ID '{request_id}' not found.",
                error_code=RESOURCE_DOES_NOT_EXIST,
            )
        repo = _get_trace_artifact_repo(trace_info)
        try:
            content_bytes = repo.download_trace_attachment(path)
        except MlflowException:
            # Already a well-formed MLflow error; propagate untouched.
            raise
        except Exception:
            # Log the underlying cause server-side, return a generic error to the client.
            _logger.warning(
                "Failed to download attachment '%s' for trace '%s'",
                path,
                request_id,
                exc_info=True,
            )
            raise MlflowException(
                f"Failed to download attachment '{path}' for trace '{request_id}'.",
                error_code=INTERNAL_ERROR,
            )
        buf = io.BytesIO(content_bytes)
        file_sender_response = send_file(
            buf,
            mimetype="application/octet-stream",
            as_attachment=True,
            download_name=path,
        )
        return _response_with_file_attachment_headers(path, file_sender_response)

    trace_data = _fetch_trace_data_from_store(store, request_id)
    if trace_data is None:
        # Store couldn't serve the data; fall back to the trace artifact repository.
        trace_info = store.get_trace_info(request_id)
        trace_data = _get_trace_artifact_repo(trace_info).download_trace_data()

    # Write data to a BytesIO buffer instead of needing to save a temp file
    buf = io.BytesIO()
    buf.write(json.dumps(trace_data).encode())
    buf.seek(0)

    file_sender_response = send_file(
        buf,
        mimetype="application/octet-stream",
        as_attachment=True,
        download_name=TRACE_DATA_FILE_NAME,
    )
    return _response_with_file_attachment_headers(TRACE_DATA_FILE_NAME, file_sender_response)


@catch_mlflow_exception
@_disable_if_artifacts_only
def _query_trace_metrics() -> Response:
    """
    A request handler to query aggregated trace metrics for a set of experiments.
    """
    request_message = _get_request_message(
        QueryTraceMetrics(),
        schema={
            "experiment_ids": [_assert_array, _assert_required, _assert_item_type_string],
            "view_type": [_assert_required],
            "metric_name": [_assert_string, _assert_required],
            "aggregations": [_assert_array, _assert_required],
            "dimensions": [_assert_array, _assert_item_type_string],
            "filters": [_assert_array, _assert_item_type_string],
            "time_interval_seconds": [_assert_intlike],
            "start_time_ms": [_assert_intlike],
            "end_time_ms": [_assert_intlike],
            "max_results": [_assert_intlike],
            "page_token": [_assert_string],
        },
    )
    # HasField distinguishes "unset" from the proto default 0 for the optional ints below.
    max_results = (
        request_message.max_results
        if request_message.HasField("max_results")
        else MAX_RESULTS_QUERY_TRACE_METRICS
    )
    time_interval_seconds = (
        request_message.time_interval_seconds
        if request_message.HasField("time_interval_seconds")
        else None
    )
    start_time_ms = (
        request_message.start_time_ms if request_message.HasField("start_time_ms") else None
    )
    end_time_ms = request_message.end_time_ms if request_message.HasField("end_time_ms") else None

    result = _get_tracking_store().query_trace_metrics(
        experiment_ids=request_message.experiment_ids,
        view_type=MetricViewType.from_proto(request_message.view_type),
        metric_name=request_message.metric_name,
        aggregations=[MetricAggregation.from_proto(agg) for agg in request_message.aggregations],
        dimensions=request_message.dimensions or None,
        filters=request_message.filters or None,
        time_interval_seconds=time_interval_seconds,
        start_time_ms=start_time_ms,
        end_time_ms=end_time_ms,
        max_results=max_results,
        page_token=request_message.page_token or None,
    )
    response_message = QueryTraceMetrics.Response()
    response_message.data_points.extend([dp.to_proto() for dp in result])
    if result.token:
        response_message.next_page_token =
result.token
    return _wrap_response(response_message)


# Assessments API handlers
@catch_mlflow_exception
@_disable_if_artifacts_only
def _create_assessment(trace_id):
    """
    A request handler for `POST /mlflow/traces/{assessment.trace_id}/assessments`
    to create a new assessment.
    """
    request_message = _get_request_message(
        CreateAssessment(),
        schema={
            "assessment": [_assert_required],
        },
    )

    # The trace id from the URL path overrides whatever is in the payload.
    assessment = Assessment.from_proto(request_message.assessment)
    assessment.trace_id = trace_id
    created_assessment = _get_tracking_store().create_assessment(assessment)

    response_message = CreateAssessment.Response(assessment=created_assessment.to_proto())
    return _wrap_response(response_message)


@catch_mlflow_exception
@_disable_if_artifacts_only
def _get_assessment(trace_id, assessment_id):
    """
    A request handler for `GET /mlflow/traces/{trace_id}/assessments/{assessment_id}`
    to get an assessment.
    """
    assessment = _get_tracking_store().get_assessment(trace_id, assessment_id)

    response_message = GetAssessmentRequest.Response(assessment=assessment.to_proto())
    return _wrap_response(response_message)


@catch_mlflow_exception
@_disable_if_artifacts_only
def _update_assessment(trace_id, assessment_id):
    """
    A request handler for `PATCH /mlflow/traces/{trace_id}/assessments/{assessment_id}`
    to update an assessment.
    """
    request_message = _get_request_message(
        UpdateAssessment(),
        schema={
            "assessment": [_assert_required],
            "update_mask": [_assert_required],
        },
    )

    assessment_proto = request_message.assessment
    update_mask = request_message.update_mask

    kwargs = {}

    # Only fields named in the update mask are forwarded to the store;
    # unrecognized mask paths are silently ignored.
    for path in update_mask.paths:
        if path == "assessment_name":
            kwargs["name"] = assessment_proto.assessment_name
        elif path == "expectation":
            kwargs["expectation"] = Expectation.from_proto(assessment_proto)
        elif path == "feedback":
            kwargs["feedback"] = Feedback.from_proto(assessment_proto)
        elif path == "rationale":
            kwargs["rationale"] = assessment_proto.rationale
        elif path == "metadata":
            kwargs["metadata"] = dict(assessment_proto.metadata)
        elif path == "valid":
            kwargs["valid"] = assessment_proto.valid

    updated_assessment = _get_tracking_store().update_assessment(
        trace_id=trace_id, assessment_id=assessment_id, **kwargs
    )

    response_message = UpdateAssessment.Response(assessment=updated_assessment.to_proto())
    return _wrap_response(response_message)


@catch_mlflow_exception
@_disable_if_artifacts_only
def _delete_assessment(trace_id, assessment_id):
    """
    A request handler for `DELETE /mlflow/traces/{trace_id}/assessments/{assessment_id}`
    to delete an assessment.
    """
    _get_tracking_store().delete_assessment(trace_id, assessment_id)

    response_message = DeleteAssessment.Response()
    return _wrap_response(response_message)


@catch_mlflow_exception
@_disable_if_artifacts_only
def _create_issue():
    """
    A request handler for `POST /mlflow/issues` to create a new issue.
    """
    request_message = _get_request_message(
        CreateIssue(),
        schema={
            "name": [_assert_required, _assert_string],
            "description": [_assert_required, _assert_string],
            "experiment_id": [_assert_required, _assert_string],
        },
    )

    # Build kwargs for create_issue
    create_kwargs = {
        "experiment_id": request_message.experiment_id,
        "name": request_message.name,
        "description": request_message.description,
        # `or None` maps proto defaults (empty string / empty list) to None.
        "source_run_id": request_message.source_run_id or None,
        "root_causes": list(request_message.root_causes) or None,
        "categories": list(request_message.categories) or None,
        "created_by": request_message.created_by or None,
    }

    # Enum fields are only forwarded when explicitly set in the request.
    if request_message.HasField("status"):
        create_kwargs["status"] = IssueStatus(request_message.status)
    if request_message.HasField("severity"):
        create_kwargs["severity"] = IssueSeverity(request_message.severity)

    created_issue = _get_tracking_store().create_issue(**create_kwargs)

    response_message = CreateIssue.Response(issue=created_issue.to_proto())
    return _wrap_response(response_message)


@catch_mlflow_exception
@_disable_if_artifacts_only
def _update_issue(issue_id):
    """
    A request handler for `PATCH /mlflow/issues/{issue_id}` to update an issue.
    """
    request_message = _get_request_message(
        UpdateIssue(),
        schema={
            # NOTE(review): issue_id is required in the request body, but the handler
            # only uses the path parameter `issue_id` — confirm clients send both.
            "issue_id": [_assert_required],
        },
    )

    status = IssueStatus(request_message.status) if request_message.HasField("status") else None
    severity = (
        IssueSeverity(request_message.severity) if request_message.HasField("severity") else None
    )

    updated_issue = _get_tracking_store().update_issue(
        issue_id=issue_id,
        status=status,
        name=request_message.name or None,
        description=request_message.description or None,
        severity=severity,
    )

    response_message = UpdateIssue.Response(issue=updated_issue.to_proto())
    return _wrap_response(response_message)


@catch_mlflow_exception
@_disable_if_artifacts_only
def _get_issue(issue_id):
    """
    A request handler for `GET /mlflow/issues/{issue_id}` to get an issue.
    """
    issue = _get_tracking_store().get_issue(issue_id)

    response_message = GetIssue.Response(issue=issue.to_proto())
    return _wrap_response(response_message)


@catch_mlflow_exception
@_disable_if_artifacts_only
def _search_issues():
    """
    A request handler for `POST /mlflow/issues/search` to search for issues.
    """
    request_message = _get_request_message(SearchIssues())

    # Build kwargs for search_issues
    search_kwargs = {
        "experiment_id": request_message.experiment_id or None,
        "filter_string": request_message.filter_string or None,
        "page_token": request_message.page_token or None,
    }

    if request_message.HasField("max_results"):
        search_kwargs["max_results"] = request_message.max_results

    if request_message.HasField("include_trace_count"):
        search_kwargs["include_trace_count"] = request_message.include_trace_count

    issues = _get_tracking_store().search_issues(**search_kwargs)

    issue_protos = [issue.to_proto() for issue in issues]
    response_message = SearchIssues.Response(
        issues=issue_protos, next_page_token=issues.token or ""
    )
    return _wrap_response(response_message)


@catch_mlflow_exception
@_disable_if_artifacts_only
def _invoke_issue_detection_handler():
    """
    Invoke issue detection on traces asynchronously.

    This is a UI-only AJAX endpoint for running issue detection from the frontend.
4250 """ 4251 from mlflow.genai.discovery.job import _fetch_provider_credentials, invoke_issue_detection_job 4252 from mlflow.server.jobs import submit_job 4253 4254 _validate_content_type(request, ["application/json"]) 4255 4256 request_json = _get_validated_flask_request_json( 4257 schema={ 4258 "experiment_id": [_assert_required, _assert_string], 4259 "trace_ids": [_assert_required, _assert_array], 4260 "categories": [_assert_required, _assert_array], 4261 "provider": [_assert_required, _assert_string], 4262 "model": [_assert_string], 4263 "secret_id": [_assert_string], 4264 "endpoint_name": [_assert_string], 4265 } 4266 ) 4267 4268 experiment_id = request_json.get("experiment_id") 4269 trace_ids = request_json.get("trace_ids", []) 4270 categories = request_json.get("categories", []) 4271 provider = request_json.get("provider") 4272 model = request_json.get("model") 4273 secret_id = request_json.get("secret_id") 4274 endpoint_name = request_json.get("endpoint_name") 4275 4276 if not endpoint_name and not (provider and model): 4277 raise MlflowException( 4278 "Either 'endpoint_name' or both 'provider' and 'model' must be provided" 4279 ) 4280 4281 # Fetch credentials required for executing the job 4282 if secret_id: 4283 store = _get_tracking_store() 4284 credentials = _fetch_provider_credentials(store, provider, secret_id) 4285 else: 4286 credentials = None 4287 4288 # Create the run upfront so we can return run_id immediately 4289 model_name = f"gateway:/{endpoint_name}" if endpoint_name else f"{provider}:/{model}" 4290 tags = { 4291 MLFLOW_RUN_TYPE: MLFLOW_RUN_TYPE_ISSUE_DETECTION, 4292 "categories": ",".join(categories), 4293 "model": model_name, 4294 "total_traces": len(trace_ids), 4295 } 4296 if endpoint_name: 4297 tags["endpoint_name"] = endpoint_name 4298 run = mlflow.start_run( 4299 experiment_id=experiment_id, 4300 tags=tags, 4301 ) 4302 run_id = run.info.run_id 4303 4304 job = submit_job( 4305 function=invoke_issue_detection_job, 4306 params={ 4307 
"experiment_id": experiment_id, 4308 "trace_ids": trace_ids, 4309 "categories": categories, 4310 "run_id": run_id, 4311 "model": model_name, 4312 }, 4313 extra_envs=credentials, 4314 ) 4315 # Tag the run with job ID for later retrieval 4316 mlflow.set_tag(MLFLOW_ISSUE_DETECTION_JOB_ID, job.job_id) 4317 mlflow.end_run(RunStatus.to_string(RunStatus.RUNNING)) 4318 4319 return jsonify({"job_id": job.job_id, "run_id": run_id}) 4320 4321 4322 @catch_mlflow_exception 4323 @_disable_if_artifacts_only 4324 def _get_job(job_id): 4325 from mlflow.server.jobs import get_job 4326 4327 job = get_job(job_id) 4328 return jsonify({ 4329 "status": str(job.status), 4330 "result": job.parsed_result, 4331 "status_details": job.status_details, 4332 }) 4333 4334 4335 @catch_mlflow_exception 4336 @_disable_if_artifacts_only 4337 def _cancel_job(job_id): 4338 from mlflow.server.jobs import cancel_job 4339 4340 job = cancel_job(job_id) 4341 return jsonify({ 4342 "status": str(job.status), 4343 "result": job.parsed_result, 4344 }) 4345 4346 4347 # Deprecated MLflow Tracing APIs. Kept for backward compatibility but do not use. 4348 4349 4350 @catch_mlflow_exception 4351 @_disable_if_artifacts_only 4352 def _deprecated_start_trace_v2(): 4353 """ 4354 A request handler for `POST /mlflow/traces` to create a new TraceInfo record in tracking store. 
    """
    request_message = _get_request_message(
        StartTrace(),
        schema={
            "experiment_id": [_assert_string],
            "timestamp_ms": [_assert_intlike],
            "request_metadata": [_assert_map_key_present],
            "tags": [_assert_map_key_present],
        },
    )
    request_metadata = {e.key: e.value for e in request_message.request_metadata}
    tags = {e.key: e.value for e in request_message.tags}

    trace_info = _get_tracking_store().deprecated_start_trace_v2(
        experiment_id=request_message.experiment_id,
        timestamp_ms=request_message.timestamp_ms,
        request_metadata=request_metadata,
        tags=tags,
    )
    response_message = StartTrace.Response(trace_info=trace_info.to_proto())
    return _wrap_response(response_message)


@catch_mlflow_exception
@_disable_if_artifacts_only
def _deprecated_end_trace_v2(request_id):
    """
    A request handler for `PATCH /mlflow/traces/{request_id}` to mark an existing TraceInfo
    record completed in tracking store.
    """
    request_message = _get_request_message(
        EndTrace(),
        schema={
            "timestamp_ms": [_assert_intlike],
            "status": [_assert_string],
            "request_metadata": [_assert_map_key_present],
            "tags": [_assert_map_key_present],
        },
    )
    request_metadata = {e.key: e.value for e in request_message.request_metadata}
    tags = {e.key: e.value for e in request_message.tags}

    trace_info = _get_tracking_store().deprecated_end_trace_v2(
        request_id=request_id,
        timestamp_ms=request_message.timestamp_ms,
        status=TraceStatus.from_proto(request_message.status),
        request_metadata=request_metadata,
        tags=tags,
    )

    # Stores may return the newer TraceInfo (v3); downgrade to the v2 response shape.
    if isinstance(trace_info, TraceInfo):
        trace_info = TraceInfoV2.from_v3(trace_info)

    response_message = EndTrace.Response(trace_info=trace_info.to_proto())
    return _wrap_response(response_message)


@catch_mlflow_exception
@_disable_if_artifacts_only
def _deprecated_get_trace_info_v2(request_id):
    """
    A request handler for `GET /mlflow/traces/{request_id}/info` to retrieve
    an existing TraceInfo record from tracking store.
    """
    trace_info = _get_tracking_store().get_trace_info(request_id)
    trace_info = TraceInfoV2.from_v3(trace_info)
    response_message = GetTraceInfo.Response(trace_info=trace_info.to_proto())
    return _wrap_response(response_message)


@catch_mlflow_exception
@_disable_if_artifacts_only
def _deprecated_search_traces_v2():
    """
    A request handler for `GET /mlflow/traces` to search for TraceInfo records in tracking store.
    """
    request_message = _get_request_message(
        SearchTraces(),
        schema={
            "experiment_ids": [
                _assert_array,
                _assert_item_type_string,
                _assert_required,
            ],
            "filter": [_assert_string],
            "max_results": [
                _assert_intlike,
                lambda x: _assert_less_than_or_equal(int(x), 500),
            ],
            "order_by": [_assert_array, _assert_item_type_string],
            "page_token": [_assert_string],
        },
    )

    traces, token = _get_tracking_store().search_traces(
        experiment_ids=request_message.experiment_ids,
        filter_string=request_message.filter,
        max_results=request_message.max_results,
        order_by=request_message.order_by,
        page_token=request_message.page_token or None,
    )
    # Downgrade v3 TraceInfo entities to the v2 response shape.
    traces = [TraceInfoV2.from_v3(t) for t in traces]
    response_message = SearchTraces.Response()
    response_message.traces.extend([e.to_proto() for e in traces])
    if token:
        response_message.next_page_token = token
    return _wrap_response(response_message)


# Logged Models APIs


@catch_mlflow_exception
@_disable_if_artifacts_only
def get_logged_model_artifact_handler(model_id: str):
    """
    A request handler to download a single artifact file of a logged model.
    Requires the `artifact_file_path` query parameter.
    """
    artifact_file_path = request.args.get("artifact_file_path")
    if not artifact_file_path:
        raise MlflowException(
            'Request must include the "artifact_file_path" query parameter.',
            error_code=BAD_REQUEST,
        )
    # Reject path traversal attempts before touching any storage backend.
    validate_path_is_safe(artifact_file_path)

    logged_model: LoggedModel = _get_tracking_store().get_logged_model(model_id)
    if _is_servable_proxied_run_artifact_root(logged_model.artifact_location):
        # Artifact root is proxied through this server; resolve the real destination.
        artifact_repo = _get_artifact_repo_mlflow_artifacts()
        artifact_path = _get_proxied_run_artifact_destination_path(
            proxied_artifact_root=logged_model.artifact_location,
            relative_path=artifact_file_path,
        )
        artifact_path = _get_workspace_scoped_repo_path_if_enabled(artifact_path)
    else:
        artifact_repo = get_artifact_repository(logged_model.artifact_location)
        artifact_path = artifact_file_path

    return _send_artifact(artifact_repo, artifact_path)


@catch_mlflow_exception
@_disable_if_artifacts_only
def _create_logged_model():
    """
    A request handler for creating a new logged model in the tracking store.
    """
    request_message = _get_request_message(
        CreateLoggedModel(),
        schema={
            "experiment_id": [_assert_string, _assert_required],
            "name": [_assert_string],
            "model_type": [_assert_string],
            "source_run_id": [_assert_string],
            "params": [_assert_array],
            "tags": [_assert_array],
        },
    )

    model = _get_tracking_store().create_logged_model(
        experiment_id=request_message.experiment_id,
        name=request_message.name or None,
        model_type=request_message.model_type,
        source_run_id=request_message.source_run_id,
        params=(
            [LoggedModelParameter.from_proto(param) for param in request_message.params]
            if request_message.params
            else None
        ),
        tags=(
            [LoggedModelTag(key=tag.key, value=tag.value) for tag in request_message.tags]
            if request_message.tags
            else None
        ),
    )
    response_message = CreateLoggedModel.Response(model=model.to_proto())
    return _wrap_response(response_message)


@catch_mlflow_exception
@_disable_if_artifacts_only
def _log_logged_model_params(model_id: str):
    """
    A request handler for logging parameters on an existing logged model.
    """
    request_message = _get_request_message(
        LogLoggedModelParamsRequest(),
        schema={
            "model_id": [_assert_string, _assert_required],
            "params": [_assert_array],
        },
    )
    params = (
        [LoggedModelParameter.from_proto(param) for param in request_message.params]
        if request_message.params
        else []
    )
    _get_tracking_store().log_logged_model_params(model_id, params)
    return _wrap_response(LogLoggedModelParamsRequest.Response())


@catch_mlflow_exception
@_disable_if_artifacts_only
def _get_logged_model(model_id: str):
    allow_deleted =
request.args.get("allow_deleted", "false").lower() == "true"
    model = _get_tracking_store().get_logged_model(model_id, allow_deleted=allow_deleted)
    response_message = GetLoggedModel.Response(model=model.to_proto())
    return _wrap_response(response_message)


@catch_mlflow_exception
@_disable_if_artifacts_only
def _finalize_logged_model(model_id: str):
    """
    A request handler for finalizing a logged model's status.
    """
    request_message = _get_request_message(
        FinalizeLoggedModel(),
        schema={
            "model_id": [_assert_string, _assert_required],
            "status": [_assert_intlike, _assert_required],
        },
    )
    model = _get_tracking_store().finalize_logged_model(
        request_message.model_id, LoggedModelStatus.from_int(request_message.status)
    )
    response_message = FinalizeLoggedModel.Response(model=model.to_proto())
    return _wrap_response(response_message)


@catch_mlflow_exception
@_disable_if_artifacts_only
def _delete_logged_model(model_id: str):
    """A request handler for deleting a logged model."""
    _get_tracking_store().delete_logged_model(model_id)
    return _wrap_response(DeleteLoggedModel.Response())


@catch_mlflow_exception
@_disable_if_artifacts_only
def _set_logged_model_tags(model_id: str):
    """A request handler for setting tags on a logged model."""
    request_message = _get_request_message(
        SetLoggedModelTags(),
        schema={"tags": [_assert_array]},
    )
    tags = [LoggedModelTag(key=tag.key, value=tag.value) for tag in request_message.tags]
    _get_tracking_store().set_logged_model_tags(model_id, tags)
    return _wrap_response(SetLoggedModelTags.Response())


@catch_mlflow_exception
@_disable_if_artifacts_only
def _delete_logged_model_tag(model_id: str, tag_key: str):
    """A request handler for deleting a single tag from a logged model."""
    _get_tracking_store().delete_logged_model_tag(model_id, tag_key)
    return _wrap_response(DeleteLoggedModelTag.Response())


@catch_mlflow_exception
@_disable_if_artifacts_only
def _search_logged_models():
    """
    A request handler for searching logged models across experiments.
    """
    request_message = _get_request_message(
        SearchLoggedModels(),
        schema={
            "experiment_ids": [
                _assert_array,
                _assert_item_type_string,
                _assert_required,
            ],
            "filter": [_assert_string],
            "datasets": [_assert_array],
            "max_results": [_assert_intlike],
            "order_by": [_assert_array],
            "page_token": [_assert_string],
        },
    )
    models = _get_tracking_store().search_logged_models(
        # Convert `RepeatedScalarContainer` objects (experiment_ids and order_by) to `list`
        # to avoid serialization issues
        experiment_ids=list(request_message.experiment_ids),
        filter_string=request_message.filter or None,
        datasets=(
            [
                {
                    "dataset_name": d.dataset_name,
                    "dataset_digest": d.dataset_digest or None,
                }
                for d in request_message.datasets
            ]
            if request_message.datasets
            else None
        ),
        max_results=request_message.max_results or None,
        order_by=(
            [
                {
                    "field_name": ob.field_name,
                    "ascending": ob.ascending,
                    "dataset_name": ob.dataset_name or None,
                    "dataset_digest": ob.dataset_digest or None,
                }
                for ob in request_message.order_by
            ]
            if request_message.order_by
            else None
        ),
        page_token=request_message.page_token or None,
    )
    response_message = SearchLoggedModels.Response()
    response_message.models.extend([e.to_proto() for e in models])
    if models.token:
        response_message.next_page_token = models.token
    return _wrap_response(response_message)


@catch_mlflow_exception
@_disable_if_artifacts_only
def _list_logged_model_artifacts(model_id: str):
    """
    A request handler for listing artifacts of a logged model, optionally scoped
    to a subdirectory via `artifact_directory_path`.
    """
    request_message = _get_request_message(
        ListLoggedModelArtifacts(),
        schema={"artifact_directory_path": [_assert_string]},
    )
    if request_message.HasField("artifact_directory_path"):
        # Reject path traversal attempts before listing.
        artifact_path = validate_path_is_safe(request_message.artifact_directory_path)
    else:
        artifact_path = None

    return _list_logged_model_artifacts_impl(model_id, artifact_path)


def _list_logged_model_artifacts_impl(
    model_id: str, artifact_directory_path: str | None
) -> Response:
    """
    List artifacts under a logged model's artifact location, handling both
    proxied and directly-accessible artifact roots.
    """
    response = ListLoggedModelArtifacts.Response()
    logged_model: LoggedModel = _get_tracking_store().get_logged_model(model_id)
    if _is_servable_proxied_run_artifact_root(logged_model.artifact_location):
        artifacts = _list_artifacts_for_proxied_run_artifact_root(
            proxied_artifact_root=logged_model.artifact_location,
            relative_path=artifact_directory_path,
        )
    else:
        artifacts = get_artifact_repository(logged_model.artifact_location).list_artifacts(
            artifact_directory_path
        )

    response.files.extend([a.to_proto() for a in artifacts])
    response.root_uri = logged_model.artifact_location
    return _wrap_response(response)


# =============================================================================
# Scorer Management Handlers
# =============================================================================


@catch_mlflow_exception
@_disable_if_artifacts_only
def _register_scorer():
    """
    A request handler for registering a new scorer version for an experiment.
    """
    request_message = _get_request_message(
        RegisterScorer(),
        schema={
            "experiment_id": [_assert_required, _assert_string],
            "name": [_assert_required, _assert_string],
            "serialized_scorer": [_assert_required, _assert_string],
        },
    )
    # Decorator scorers contain a `call_source` field that is executed via exec() during
    # deserialization. The Python client blocks this via `_check_can_be_registered()`, but
    # that check is client-side only and can be bypassed by calling the REST API directly.
    # Enforce the same restriction here in the server handler so it applies regardless of
    # how the request arrives.
    try:
        serialized_data = json.loads(request_message.serialized_scorer)
    except json.JSONDecodeError as e:
        raise MlflowException.invalid_parameter_value("serialized_scorer must be valid JSON") from e
    if serialized_data.get("call_source") is not None:
        raise MlflowException.invalid_parameter_value(
            DECORATOR_SCORER_REGISTRATION_NOT_SUPPORTED_ERROR
        )
    scorer_version = _get_tracking_store().register_scorer(
        request_message.experiment_id,
        request_message.name,
        request_message.serialized_scorer,
    )
    response_message = RegisterScorer.Response()
    response_message.version = scorer_version.scorer_version
    response_message.scorer_id = scorer_version.scorer_id
    response_message.experiment_id = scorer_version.experiment_id
    response_message.name = scorer_version.scorer_name
    response_message.serialized_scorer = scorer_version._serialized_scorer
    response_message.creation_time = scorer_version.creation_time
    response = Response(mimetype="application/json")
    response.set_data(message_to_json(response_message))
    return response


@catch_mlflow_exception
@_disable_if_artifacts_only
def _list_scorers():
    """
    A request handler for listing all scorers registered on an experiment.
    """
    request_message = _get_request_message(
        ListScorers(),
        schema={"experiment_id": [_assert_required, _assert_string]},
    )
    response_message = ListScorers.Response()
    scorers = _get_tracking_store().list_scorers(request_message.experiment_id)
    response_message.scorers.extend([scorer.to_proto() for scorer in scorers])
    response = Response(mimetype="application/json")
    response.set_data(message_to_json(response_message))
    return response


@catch_mlflow_exception
@_disable_if_artifacts_only
def _list_scorer_versions():
    request_message = _get_request_message(
        ListScorerVersions(),
        schema={
            "experiment_id": [_assert_required, _assert_string],
            "name": [_assert_required,
_assert_string], 4760 }, 4761 ) 4762 response_message = ListScorerVersions.Response() 4763 scorers = _get_tracking_store().list_scorer_versions( 4764 request_message.experiment_id, request_message.name 4765 ) 4766 response_message.scorers.extend([scorer.to_proto() for scorer in scorers]) 4767 response = Response(mimetype="application/json") 4768 response.set_data(message_to_json(response_message)) 4769 return response 4770 4771 4772 @catch_mlflow_exception 4773 @_disable_if_artifacts_only 4774 def _get_scorer(): 4775 request_message = _get_request_message( 4776 GetScorer(), 4777 schema={ 4778 "experiment_id": [_assert_required, _assert_string], 4779 "name": [_assert_required, _assert_string], 4780 "version": [_assert_intlike], 4781 }, 4782 ) 4783 response_message = GetScorer.Response() 4784 scorer_version = _get_tracking_store().get_scorer( 4785 request_message.experiment_id, 4786 request_message.name, 4787 request_message.version if request_message.HasField("version") else None, 4788 ) 4789 response_message.scorer.CopyFrom(scorer_version.to_proto()) 4790 response = Response(mimetype="application/json") 4791 response.set_data(message_to_json(response_message)) 4792 return response 4793 4794 4795 @catch_mlflow_exception 4796 @_disable_if_artifacts_only 4797 def _delete_scorer(): 4798 request_message = _get_request_message( 4799 DeleteScorer(), 4800 schema={ 4801 "experiment_id": [_assert_required, _assert_string], 4802 "name": [_assert_required, _assert_string], 4803 "version": [_assert_intlike], 4804 }, 4805 ) 4806 _get_tracking_store().delete_scorer( 4807 request_message.experiment_id, 4808 request_message.name, 4809 request_message.version if request_message.HasField("version") else None, 4810 ) 4811 response_message = DeleteScorer.Response() 4812 response = Response(mimetype="application/json") 4813 response.set_data(message_to_json(response_message)) 4814 return response 4815 4816 4817 @catch_mlflow_exception 4818 @_disable_if_artifacts_only 4819 def 
_get_online_scoring_configs(): 4820 """ 4821 Get online scoring configurations for a list of scorer IDs. 4822 4823 Query Parameters: 4824 scorer_ids: List of scorer IDs to fetch configurations for. 4825 4826 Returns: 4827 JSON response containing a list of configurations. 4828 """ 4829 request_json = _get_validated_flask_request_json( 4830 flask_request=request, 4831 schema={ 4832 "scorer_ids": [_assert_required, _assert_array, _assert_item_type_string], 4833 }, 4834 ) 4835 4836 scorer_ids = request_json["scorer_ids"] 4837 configs = _get_tracking_store().get_online_scoring_configs(scorer_ids) 4838 4839 response = Response(mimetype="application/json") 4840 response.set_data(json.dumps({"configs": [c.to_dict() for c in configs]})) 4841 return response 4842 4843 4844 @catch_mlflow_exception 4845 @_disable_if_artifacts_only 4846 def _upsert_online_scoring_config(): 4847 """ 4848 Update the online scoring configuration for a registered scorer. 4849 4850 Request Body (JSON): 4851 experiment_id: The ID of the Experiment containing the scorer. 4852 name: The scorer name. 4853 sample_rate: The sampling rate (0.0 to 1.0). 4854 filter_string: Optional filter string for trace selection. 4855 4856 Returns: 4857 JSON response containing the updated configuration. 4858 """ 4859 request_json = _get_validated_flask_request_json( 4860 flask_request=request, 4861 schema={ 4862 "experiment_id": [_assert_required, _assert_string], 4863 "name": [_assert_required, _assert_string], 4864 "sample_rate": [_assert_required], 4865 "filter_string": [], 4866 }, 4867 ) 4868 4869 filter_string = request_json.get("filter_string") 4870 if filter_string is not None and not isinstance(filter_string, str): 4871 raise MlflowException( 4872 f"Invalid value {filter_string!r} for parameter 'filter_string' supplied: " 4873 f"Value was of type '{type(filter_string).__name__}'. 
" 4874 "Expected type 'str' or None.", 4875 error_code=INVALID_PARAMETER_VALUE, 4876 ) 4877 4878 config = _get_tracking_store().upsert_online_scoring_config( 4879 experiment_id=request_json["experiment_id"], 4880 scorer_name=request_json["name"], 4881 sample_rate=float(request_json["sample_rate"]), 4882 filter_string=filter_string, 4883 ) 4884 4885 response = Response(mimetype="application/json") 4886 response.set_data(json.dumps({"config": config.to_dict()})) 4887 return response 4888 4889 4890 # ============================================================================= 4891 # Secrets Management Handlers 4892 # ============================================================================= 4893 4894 4895 @catch_mlflow_exception 4896 @_disable_if_artifacts_only 4897 def _create_gateway_secret(): 4898 request_message = _get_request_message( 4899 CreateGatewaySecret(), 4900 schema={ 4901 "secret_name": [_assert_required, _assert_string], 4902 "secret_value": [_assert_secret_value], 4903 "provider": [_assert_string], 4904 "created_by": [_assert_string], 4905 }, 4906 ) 4907 # Empty map means no auth_config was provided 4908 auth_config = dict(request_message.auth_config) or None 4909 4910 secret = _get_tracking_store().create_gateway_secret( 4911 secret_name=request_message.secret_name, 4912 secret_value=dict(request_message.secret_value), 4913 provider=request_message.provider or None, 4914 auth_config=auth_config, 4915 created_by=request_message.created_by or None, 4916 ) 4917 response_message = CreateGatewaySecret.Response() 4918 response_message.secret.CopyFrom(secret.to_proto()) 4919 return _wrap_response(response_message) 4920 4921 4922 @catch_mlflow_exception 4923 @_disable_if_artifacts_only 4924 def _get_gateway_secret_info(): 4925 request_message = _get_request_message( 4926 GetGatewaySecretInfo(), 4927 schema={ 4928 "secret_id": [_assert_required, _assert_string], 4929 }, 4930 ) 4931 secret = _get_tracking_store().get_secret_info(request_message.secret_id) 
    response_message = GetGatewaySecretInfo.Response()
    response_message.secret.CopyFrom(secret.to_proto())
    return _wrap_response(response_message)


@catch_mlflow_exception
@_disable_if_artifacts_only
def _update_gateway_secret():
    """Update a gateway secret's value and/or auth config; empty maps mean "no change"."""
    request_message = _get_request_message(
        UpdateGatewaySecret(),
        schema={
            "secret_id": [_assert_required, _assert_string],
            "updated_by": [_assert_string],
        },
    )
    # Empty map means no auth_config was provided
    auth_config = dict(request_message.auth_config) or None

    # Empty map means no update to secret_value
    secret_value = dict(request_message.secret_value) or None

    secret = _get_tracking_store().update_gateway_secret(
        secret_id=request_message.secret_id,
        secret_value=secret_value,
        auth_config=auth_config,
        updated_by=request_message.updated_by or None,
    )
    response_message = UpdateGatewaySecret.Response()
    response_message.secret.CopyFrom(secret.to_proto())
    return _wrap_response(response_message)


@catch_mlflow_exception
@_disable_if_artifacts_only
def _delete_gateway_secret():
    """Delete a gateway secret by ID."""
    request_message = _get_request_message(
        DeleteGatewaySecret(),
        schema={
            "secret_id": [_assert_required, _assert_string],
        },
    )
    _get_tracking_store().delete_gateway_secret(request_message.secret_id)
    response_message = DeleteGatewaySecret.Response()
    return _wrap_response(response_message)


@catch_mlflow_exception
@_disable_if_artifacts_only
def _list_gateway_secrets():
    """List gateway secret metadata, optionally filtered by provider."""
    request_message = _get_request_message(
        ListGatewaySecretInfos(),
        schema={
            "provider": [_assert_string],
        },
    )
    secrets = _get_tracking_store().list_secret_infos(
        provider=request_message.provider or None,
    )
    response_message = ListGatewaySecretInfos.Response()
    response_message.secrets.extend([s.to_proto() for s in secrets])
    return _wrap_response(response_message)


# =============================================================================
# Endpoints Management Handlers
# =============================================================================


@catch_mlflow_exception
@_disable_if_artifacts_only
def _create_gateway_endpoint():
    """Create a gateway endpoint with its model configs, optional routing strategy,
    fallback config, experiment association, and usage tracking flag."""
    request_message = _get_request_message(
        CreateGatewayEndpoint(),
        schema={
            "name": [_assert_required, _assert_string],
            "created_by": [_assert_string],
            "model_configs": [_assert_required],
            "routing_strategy": [_assert_string],
        },
    )
    if request_message.name and not is_valid_endpoint_name(request_message.name):
        raise MlflowException.invalid_parameter_value(
            f"Invalid endpoint name '{request_message.name}'. "
            "Name can only contain letters, numbers, underscores, hyphens, and dots."
        )
    # Convert proto fallback_config to entity FallbackConfig
    fallback_config = None
    if request_message.HasField("fallback_config"):
        fallback_config = FallbackConfig(
            strategy=FallbackStrategy.from_proto(request_message.fallback_config.strategy)
            if request_message.fallback_config.HasField("strategy")
            else None,
            max_attempts=request_message.fallback_config.max_attempts
            if request_message.fallback_config.HasField("max_attempts")
            else None,
        )

    model_configs = [
        GatewayEndpointModelConfig.from_proto(config) for config in request_message.model_configs
    ]

    # Determine experiment_id and usage_tracking
    experiment_id = (
        request_message.experiment_id if request_message.HasField("experiment_id") else None
    )
    # Usage tracking defaults to enabled when the client did not set the field.
    usage_tracking = (
        request_message.usage_tracking if request_message.HasField("usage_tracking") else True
    )

    endpoint = _get_tracking_store().create_gateway_endpoint(
        name=request_message.name or None,
        model_configs=model_configs,
        created_by=request_message.created_by or None,
        routing_strategy=RoutingStrategyEntity.from_proto(request_message.routing_strategy)
        if request_message.HasField("routing_strategy")
        else None,
        fallback_config=fallback_config,
        experiment_id=experiment_id,
        usage_tracking=usage_tracking,
    )
    response_message = CreateGatewayEndpoint.Response()
    response_message.endpoint.CopyFrom(endpoint.to_proto())
    return _wrap_response(response_message)


@catch_mlflow_exception
@_disable_if_artifacts_only
def _get_gateway_endpoint():
    """Fetch a gateway endpoint by ID or by name (either selector may be supplied)."""
    request_message = _get_request_message(
        GetGatewayEndpoint(),
        schema={
            "endpoint_id": [_assert_string],
            "name": [_assert_string],
        },
    )
    endpoint = _get_tracking_store().get_gateway_endpoint(
        endpoint_id=request_message.endpoint_id or None,
        name=request_message.name or None,
    )
    response_message = GetGatewayEndpoint.Response()
    response_message.endpoint.CopyFrom(endpoint.to_proto())
    return _wrap_response(response_message)


@catch_mlflow_exception
@_disable_if_artifacts_only
def _update_gateway_endpoint():
    """Update a gateway endpoint; unset optional fields are passed as None (no change)."""
    request_message = _get_request_message(
        UpdateGatewayEndpoint(),
        schema={
            "endpoint_id": [_assert_required, _assert_string],
            "name": [_assert_string],
            "updated_by": [_assert_string],
            "routing_strategy": [_assert_string],
        },
    )
    if request_message.name and not is_valid_endpoint_name(request_message.name):
        raise MlflowException.invalid_parameter_value(
            f"Invalid endpoint name '{request_message.name}'. "
            "Name can only contain letters, numbers, underscores, hyphens, and dots."
    )
    # Convert proto fallback_config to entity FallbackConfig
    fallback_config = None
    if request_message.HasField("fallback_config"):
        fallback_config = FallbackConfig(
            strategy=FallbackStrategy.from_proto(request_message.fallback_config.strategy)
            if request_message.fallback_config.HasField("strategy")
            else None,
            max_attempts=request_message.fallback_config.max_attempts
            if request_message.fallback_config.HasField("max_attempts")
            else None,
        )

    # Convert proto model_configs to entity GatewayEndpointModelConfig list
    model_configs = None
    if request_message.model_configs:
        model_configs = [
            GatewayEndpointModelConfig.from_proto(config)
            for config in request_message.model_configs
        ]

    # Determine experiment_id and usage_tracking
    experiment_id = (
        request_message.experiment_id if request_message.HasField("experiment_id") else None
    )
    # Unlike create (defaults to True), update leaves usage_tracking unchanged (None)
    # when the field was not set.
    usage_tracking = (
        request_message.usage_tracking if request_message.HasField("usage_tracking") else None
    )

    endpoint = _get_tracking_store().update_gateway_endpoint(
        endpoint_id=request_message.endpoint_id,
        name=request_message.name or None,
        model_configs=model_configs,
        updated_by=request_message.updated_by or None,
        routing_strategy=RoutingStrategyEntity.from_proto(request_message.routing_strategy)
        if request_message.HasField("routing_strategy")
        else None,
        fallback_config=fallback_config,
        experiment_id=experiment_id,
        usage_tracking=usage_tracking,
    )
    response_message = UpdateGatewayEndpoint.Response()
    response_message.endpoint.CopyFrom(endpoint.to_proto())
    return _wrap_response(response_message)


@catch_mlflow_exception
@_disable_if_artifacts_only
def _delete_gateway_endpoint():
    """Delete a gateway endpoint by ID."""
    request_message = _get_request_message(
        DeleteGatewayEndpoint(),
        schema={
            "endpoint_id": [_assert_required, _assert_string],
        },
    )
    _get_tracking_store().delete_gateway_endpoint(request_message.endpoint_id)
    response_message = DeleteGatewayEndpoint.Response()
    return _wrap_response(response_message)


@catch_mlflow_exception
@_disable_if_artifacts_only
def _list_gateway_endpoints():
    """List gateway endpoints, optionally filtered by provider."""
    request_message = _get_request_message(
        ListGatewayEndpoints(),
        schema={
            "provider": [_assert_string],
        },
    )
    endpoints = _get_tracking_store().list_gateway_endpoints(
        provider=request_message.provider or None,
    )
    response_message = ListGatewayEndpoints.Response()
    response_message.endpoints.extend([e.to_proto() for e in endpoints])
    return _wrap_response(response_message)


# =============================================================================
# Model Definitions Management Handlers
# =============================================================================


@catch_mlflow_exception
@_disable_if_artifacts_only
def _create_gateway_model_definition():
    """Create a gateway model definition referencing a stored secret."""
    request_message = _get_request_message(
        CreateGatewayModelDefinition(),
        schema={
            "name": [_assert_required, _assert_string],
            "secret_id": [_assert_required, _assert_string],
            "provider": [_assert_required, _assert_string],
            "model_name": [_assert_required, _assert_string],
            "created_by": [_assert_string],
        },
    )
    model_definition = _get_tracking_store().create_gateway_model_definition(
        name=request_message.name,
        secret_id=request_message.secret_id,
        provider=request_message.provider,
        model_name=request_message.model_name,
        created_by=request_message.created_by or None,
    )
    response_message = CreateGatewayModelDefinition.Response()
    response_message.model_definition.CopyFrom(model_definition.to_proto())
    return _wrap_response(response_message)


@catch_mlflow_exception
@_disable_if_artifacts_only
def _get_gateway_model_definition():
    """Fetch a gateway model definition by ID."""
    request_message = _get_request_message(
        GetGatewayModelDefinition(),
        schema={
            "model_definition_id": [_assert_required, _assert_string],
        },
    )
    model_definition = _get_tracking_store().get_gateway_model_definition(
        request_message.model_definition_id
    )
    response_message = GetGatewayModelDefinition.Response()
    response_message.model_definition.CopyFrom(model_definition.to_proto())
    return _wrap_response(response_message)


@catch_mlflow_exception
@_disable_if_artifacts_only
def _list_gateway_model_definitions():
    """List gateway model definitions, optionally filtered by provider and/or secret."""
    request_message = _get_request_message(
        ListGatewayModelDefinitions(),
        schema={
            "provider": [_assert_string],
            "secret_id": [_assert_string],
        },
    )
    model_definitions = _get_tracking_store().list_gateway_model_definitions(
        provider=request_message.provider or None,
        secret_id=request_message.secret_id or None,
    )
    response_message = ListGatewayModelDefinitions.Response()
    response_message.model_definitions.extend([m.to_proto() for m in model_definitions])
    return _wrap_response(response_message)


@catch_mlflow_exception
@_disable_if_artifacts_only
def _update_gateway_model_definition():
    """Update a gateway model definition; empty optional fields mean "no change"."""
    request_message = _get_request_message(
        UpdateGatewayModelDefinition(),
        schema={
            "model_definition_id": [_assert_required, _assert_string],
            "name": [_assert_string],
            "secret_id": [_assert_string],
            "model_name": [_assert_string],
            "updated_by": [_assert_string],
            "provider": [_assert_string],
        },
    )
    model_definition = _get_tracking_store().update_gateway_model_definition(
        model_definition_id=request_message.model_definition_id,
        name=request_message.name or None,
        secret_id=request_message.secret_id or None,
        model_name=request_message.model_name or None,
        updated_by=request_message.updated_by or None,
        provider=request_message.provider or None,
    )
    response_message = UpdateGatewayModelDefinition.Response()
    response_message.model_definition.CopyFrom(model_definition.to_proto())
    return _wrap_response(response_message)


@catch_mlflow_exception
@_disable_if_artifacts_only
def _delete_gateway_model_definition():
    """Delete a gateway model definition by ID."""
    request_message = _get_request_message(
        DeleteGatewayModelDefinition(),
        schema={
            "model_definition_id": [_assert_required, _assert_string],
        },
    )
    _get_tracking_store().delete_gateway_model_definition(request_message.model_definition_id)
    response_message = DeleteGatewayModelDefinition.Response()
    return _wrap_response(response_message)


# =============================================================================
# Endpoint Model Mappings Management Handlers
# =============================================================================


@catch_mlflow_exception
@_disable_if_artifacts_only
def _attach_model_to_gateway_endpoint():
    """Attach a model configuration to an existing gateway endpoint."""
    request_message = _get_request_message(
        AttachModelToGatewayEndpoint(),
        schema={
            "endpoint_id": [_assert_required, _assert_string],
            "model_config": [_assert_required],
            "created_by": [_assert_string],
        },
    )

    model_config = GatewayEndpointModelConfig.from_proto(request_message.model_config)

    mapping = _get_tracking_store().attach_model_to_endpoint(
        endpoint_id=request_message.endpoint_id,
        model_config=model_config,
        created_by=request_message.created_by or None,
    )
    response_message = AttachModelToGatewayEndpoint.Response()
    response_message.mapping.CopyFrom(mapping.to_proto())
    return _wrap_response(response_message)


@catch_mlflow_exception
@_disable_if_artifacts_only
def _detach_model_from_gateway_endpoint():
    """Detach a model definition from a gateway endpoint."""
    request_message = _get_request_message(
        DetachModelFromGatewayEndpoint(),
        schema={
            "endpoint_id": [_assert_required, _assert_string],
            "model_definition_id": [_assert_required, _assert_string],
        },
    )
    _get_tracking_store().detach_model_from_endpoint(
        endpoint_id=request_message.endpoint_id,
        model_definition_id=request_message.model_definition_id,
    )
    response_message = DetachModelFromGatewayEndpoint.Response()
    return _wrap_response(response_message)


# =============================================================================
# Endpoint Bindings Management Handlers
# =============================================================================


@catch_mlflow_exception
@_disable_if_artifacts_only
def _create_gateway_endpoint_binding():
    """Bind a gateway endpoint to a resource identified by type and ID."""
    request_message = _get_request_message(
        CreateGatewayEndpointBinding(),
        schema={
            "endpoint_id": [_assert_required, _assert_string],
            "resource_type": [_assert_required, _assert_string],
            "resource_id": [_assert_required, _assert_string],
            "created_by": [_assert_string],
        },
    )
    binding = _get_tracking_store().create_endpoint_binding(
        endpoint_id=request_message.endpoint_id,
        # The raw string is converted to the GatewayResourceType enum here, which
        # rejects unrecognized values before the store is called.
        resource_type=GatewayResourceType(request_message.resource_type),
        resource_id=request_message.resource_id,
        created_by=request_message.created_by or None,
    )
    response_message = CreateGatewayEndpointBinding.Response()
    response_message.binding.CopyFrom(binding.to_proto())
    return _wrap_response(response_message)


@catch_mlflow_exception
@_disable_if_artifacts_only
def _delete_gateway_endpoint_binding():
    """Delete the binding between a gateway endpoint and a resource."""
    request_message = _get_request_message(
        DeleteGatewayEndpointBinding(),
        schema={
            "endpoint_id": [_assert_required, _assert_string],
            "resource_type": [_assert_required, _assert_string],
            "resource_id": [_assert_required, _assert_string],
        },
) 5362 _get_tracking_store().delete_endpoint_binding( 5363 endpoint_id=request_message.endpoint_id, 5364 resource_type=request_message.resource_type, 5365 resource_id=request_message.resource_id, 5366 ) 5367 response_message = DeleteGatewayEndpointBinding.Response() 5368 return _wrap_response(response_message) 5369 5370 5371 @catch_mlflow_exception 5372 @_disable_if_artifacts_only 5373 def _list_gateway_endpoint_bindings(): 5374 request_message = _get_request_message( 5375 ListGatewayEndpointBindings(), 5376 schema={ 5377 "endpoint_id": [_assert_string], 5378 "resource_type": [_assert_string], 5379 "resource_id": [_assert_string], 5380 }, 5381 ) 5382 bindings = _get_tracking_store().list_endpoint_bindings( 5383 endpoint_id=request_message.endpoint_id or None, 5384 resource_type=request_message.resource_type or None, 5385 resource_id=request_message.resource_id or None, 5386 ) 5387 response_message = ListGatewayEndpointBindings.Response() 5388 response_message.bindings.extend([b.to_proto() for b in bindings]) 5389 return _wrap_response(response_message) 5390 5391 5392 @catch_mlflow_exception 5393 @_disable_if_artifacts_only 5394 def _set_gateway_endpoint_tag(): 5395 request_message = _get_request_message( 5396 SetGatewayEndpointTag(), 5397 schema={ 5398 "endpoint_id": [_assert_required, _assert_string], 5399 "key": [_assert_required, _assert_string], 5400 "value": [_assert_string], 5401 }, 5402 ) 5403 tag = GatewayEndpointTag(request_message.key, request_message.value) 5404 _get_tracking_store().set_gateway_endpoint_tag(request_message.endpoint_id, tag) 5405 response_message = SetGatewayEndpointTag.Response() 5406 response = Response(mimetype="application/json") 5407 response.set_data(message_to_json(response_message)) 5408 return response 5409 5410 5411 @catch_mlflow_exception 5412 @_disable_if_artifacts_only 5413 def _delete_gateway_endpoint_tag(): 5414 request_message = _get_request_message( 5415 DeleteGatewayEndpointTag(), 5416 schema={ 5417 "endpoint_id": 
[_assert_required, _assert_string], 5418 "key": [_assert_required, _assert_string], 5419 }, 5420 ) 5421 _get_tracking_store().delete_gateway_endpoint_tag( 5422 request_message.endpoint_id, request_message.key 5423 ) 5424 response_message = DeleteGatewayEndpointTag.Response() 5425 response = Response(mimetype="application/json") 5426 response.set_data(message_to_json(response_message)) 5427 return response 5428 5429 5430 # ============================================================================= 5431 # Budget Policy Management Handlers 5432 # ============================================================================= 5433 5434 5435 @catch_mlflow_exception 5436 @_disable_if_artifacts_only 5437 def _create_budget_policy(): 5438 request_message = _get_request_message( 5439 CreateGatewayBudgetPolicy(), 5440 schema={ 5441 "budget_unit": [_assert_required], 5442 "budget_amount": [_assert_required], 5443 "duration": [_assert_required], 5444 "target_scope": [_assert_required], 5445 "budget_action": [_assert_required], 5446 "created_by": [_assert_string], 5447 }, 5448 ) 5449 budget_unit = BudgetUnit.from_proto(request_message.budget_unit) 5450 if budget_unit is None: 5451 raise MlflowException( 5452 message=f"Invalid budget_unit: {request_message.budget_unit}", 5453 error_code=INVALID_PARAMETER_VALUE, 5454 ) 5455 duration_unit = BudgetDurationUnit.from_proto(request_message.duration.unit) 5456 if duration_unit is None: 5457 raise MlflowException( 5458 message=f"Invalid duration.unit: {request_message.duration.unit}", 5459 error_code=INVALID_PARAMETER_VALUE, 5460 ) 5461 if request_message.duration.value <= 0: 5462 raise MlflowException( 5463 message=f"duration.value must be a positive integer, got " 5464 f"{request_message.duration.value}", 5465 error_code=INVALID_PARAMETER_VALUE, 5466 ) 5467 target_scope = BudgetTargetScope.from_proto(request_message.target_scope) 5468 if target_scope is None: 5469 raise MlflowException( 5470 message=f"Invalid target_scope: 
{request_message.target_scope}", 5471 error_code=INVALID_PARAMETER_VALUE, 5472 ) 5473 budget_action = BudgetAction.from_proto(request_message.budget_action) 5474 if budget_action is None: 5475 raise MlflowException( 5476 message=f"Invalid budget_action: {request_message.budget_action}", 5477 error_code=INVALID_PARAMETER_VALUE, 5478 ) 5479 store = _get_tracking_store() 5480 policy = store.create_budget_policy( 5481 budget_unit=budget_unit, 5482 budget_amount=request_message.budget_amount, 5483 duration=BudgetDuration(unit=duration_unit, value=request_message.duration.value), 5484 target_scope=target_scope, 5485 budget_action=budget_action, 5486 created_by=request_message.created_by or None, 5487 ) 5488 get_budget_tracker().invalidate() 5489 maybe_refresh_budget_policies(store) 5490 response_message = CreateGatewayBudgetPolicy.Response() 5491 response_message.budget_policy.CopyFrom(policy.to_proto()) 5492 return _wrap_response(response_message) 5493 5494 5495 @catch_mlflow_exception 5496 @_disable_if_artifacts_only 5497 def _get_budget_policy(): 5498 request_message = _get_request_message( 5499 GetGatewayBudgetPolicy(), 5500 schema={ 5501 "budget_policy_id": [_assert_required, _assert_string], 5502 }, 5503 ) 5504 policy = _get_tracking_store().get_budget_policy( 5505 budget_policy_id=request_message.budget_policy_id, 5506 ) 5507 response_message = GetGatewayBudgetPolicy.Response() 5508 response_message.budget_policy.CopyFrom(policy.to_proto()) 5509 return _wrap_response(response_message) 5510 5511 5512 @catch_mlflow_exception 5513 @_disable_if_artifacts_only 5514 def _update_budget_policy(): 5515 request_message = _get_request_message( 5516 UpdateGatewayBudgetPolicy(), 5517 schema={ 5518 "budget_policy_id": [_assert_required, _assert_string], 5519 "updated_by": [_assert_string], 5520 }, 5521 ) 5522 budget_unit = None 5523 if request_message.HasField("budget_unit"): 5524 budget_unit = BudgetUnit.from_proto(request_message.budget_unit) 5525 if budget_unit is None: 5526 
            raise MlflowException(
                message=f"Invalid budget_unit: {request_message.budget_unit}",
                error_code=INVALID_PARAMETER_VALUE,
            )
    duration = None
    if request_message.HasField("duration"):
        duration_unit = BudgetDurationUnit.from_proto(request_message.duration.unit)
        if duration_unit is None:
            raise MlflowException(
                message=f"Invalid duration.unit: {request_message.duration.unit}",
                error_code=INVALID_PARAMETER_VALUE,
            )
        if request_message.duration.value <= 0:
            raise MlflowException(
                message=f"duration.value must be a positive integer, got "
                f"{request_message.duration.value}",
                error_code=INVALID_PARAMETER_VALUE,
            )
        duration = BudgetDuration(unit=duration_unit, value=request_message.duration.value)
    target_scope = None
    if request_message.HasField("target_scope"):
        target_scope = BudgetTargetScope.from_proto(request_message.target_scope)
        if target_scope is None:
            raise MlflowException(
                message=f"Invalid target_scope: {request_message.target_scope}",
                error_code=INVALID_PARAMETER_VALUE,
            )
    budget_action = None
    if request_message.HasField("budget_action"):
        budget_action = BudgetAction.from_proto(request_message.budget_action)
        if budget_action is None:
            raise MlflowException(
                message=f"Invalid budget_action: {request_message.budget_action}",
                error_code=INVALID_PARAMETER_VALUE,
            )
    store = _get_tracking_store()
    policy = store.update_budget_policy(
        budget_policy_id=request_message.budget_policy_id,
        budget_unit=budget_unit,
        budget_amount=request_message.budget_amount
        if request_message.HasField("budget_amount")
        else None,
        duration=duration,
        target_scope=target_scope,
        budget_action=budget_action,
        updated_by=request_message.updated_by or None,
    )
    # Policy definitions changed: drop cached tracker state and reload policies.
    get_budget_tracker().invalidate()
    maybe_refresh_budget_policies(store)
    response_message = UpdateGatewayBudgetPolicy.Response()
    response_message.budget_policy.CopyFrom(policy.to_proto())
    return _wrap_response(response_message)


@catch_mlflow_exception
@_disable_if_artifacts_only
def _delete_budget_policy():
    """Delete a gateway budget policy and refresh budget tracker state."""
    request_message = _get_request_message(
        DeleteGatewayBudgetPolicy(),
        schema={
            "budget_policy_id": [_assert_required, _assert_string],
        },
    )
    store = _get_tracking_store()
    store.delete_budget_policy(request_message.budget_policy_id)
    get_budget_tracker().invalidate()
    maybe_refresh_budget_policies(store)
    response_message = DeleteGatewayBudgetPolicy.Response()
    return _wrap_response(response_message)


@catch_mlflow_exception
@_disable_if_artifacts_only
def _list_budget_policies():
    """List gateway budget policies with pagination."""
    request_message = _get_request_message(
        ListGatewayBudgetPolicies(),
        schema={
            "max_results": [_assert_intlike],
            "page_token": [_assert_string],
        },
    )
    budget_policies = _get_tracking_store().list_budget_policies(
        max_results=request_message.max_results or SEARCH_MAX_RESULTS_DEFAULT,
        page_token=request_message.page_token or None,
    )
    response_message = ListGatewayBudgetPolicies.Response()
    response_message.budget_policies.extend([p.to_proto() for p in budget_policies])
    if budget_policies.token:
        response_message.next_page_token = budget_policies.token
    return _wrap_response(response_message)


@catch_mlflow_exception
@_disable_if_artifacts_only
def _list_budget_windows():
    """List the budget tracker's current windows (per policy) with their spend."""
    _get_request_message(ListGatewayBudgetWindows())
    store = _get_tracking_store()
    maybe_refresh_budget_policies(store)
    windows = get_budget_tracker().get_all_windows()
    response_message = ListGatewayBudgetWindows.Response()
    for w in windows:
        window_msg = ListGatewayBudgetWindows.BudgetWindow(
            budget_policy_id=w.policy.budget_policy_id,
            # Window boundaries are converted to epoch milliseconds for the proto.
            window_start_ms=int(w.window_start.timestamp() * 1000),
            window_end_ms=int(w.window_end.timestamp() * 1000),
            current_spend=w.cumulative_spend,
        )
        response_message.windows.append(window_msg)
    return _wrap_response(response_message)


@catch_mlflow_exception
@_disable_if_artifacts_only
def _create_gateway_guardrail():
    """Create a gateway guardrail pairing a scorer version with a stage and action.

    Stage and action enums are validated (``from_proto`` yields None for unrecognized
    values) and rejected with INVALID_PARAMETER_VALUE.
    """
    request_message = _get_request_message(
        CreateGatewayGuardrail(),
        schema={
            "name": [_assert_required, _assert_string],
            "scorer_id": [_assert_required, _assert_string],
            "scorer_version": [_assert_required, _assert_intlike],
            "stage": [_assert_required],
            "action": [_assert_required],
            "action_endpoint_id": [_assert_string],
        },
    )
    from mlflow.entities.gateway_guardrail import GuardrailAction, GuardrailStage

    stage = GuardrailStage.from_proto(request_message.stage)
    if stage is None:
        raise MlflowException(
            message=f"Invalid stage: {request_message.stage}",
            error_code=INVALID_PARAMETER_VALUE,
        )
    action = GuardrailAction.from_proto(request_message.action)
    if action is None:
        raise MlflowException(
            message=f"Invalid action: {request_message.action}",
            error_code=INVALID_PARAMETER_VALUE,
        )
    guardrail = _get_tracking_store().create_gateway_guardrail(
        name=request_message.name,
        scorer_id=request_message.scorer_id,
        scorer_version=request_message.scorer_version,
        stage=stage,
        action=action,
        action_endpoint_id=request_message.action_endpoint_id or None,
        created_by=_get_user(),
    )
    response_message = CreateGatewayGuardrail.Response()
    response_message.guardrail.CopyFrom(guardrail.to_proto())
    return _wrap_response(response_message)


@catch_mlflow_exception
@_disable_if_artifacts_only
def _get_gateway_guardrail():
    request_message = _get_request_message(
        GetGatewayGuardrail(),
        schema={"guardrail_id": [_assert_required,
_assert_string]}, 5685 ) 5686 guardrail = _get_tracking_store().get_gateway_guardrail( 5687 guardrail_id=request_message.guardrail_id, 5688 ) 5689 response_message = GetGatewayGuardrail.Response() 5690 response_message.guardrail.CopyFrom(guardrail.to_proto()) 5691 return _wrap_response(response_message) 5692 5693 5694 @catch_mlflow_exception 5695 @_disable_if_artifacts_only 5696 def _delete_gateway_guardrail(): 5697 request_message = _get_request_message( 5698 DeleteGatewayGuardrail(), 5699 schema={"guardrail_id": [_assert_required, _assert_string]}, 5700 ) 5701 _get_tracking_store().delete_gateway_guardrail(request_message.guardrail_id) 5702 return _wrap_response(DeleteGatewayGuardrail.Response()) 5703 5704 5705 @catch_mlflow_exception 5706 @_disable_if_artifacts_only 5707 def _list_gateway_guardrails(): 5708 request_message = _get_request_message( 5709 ListGatewayGuardrails(), 5710 schema={ 5711 "max_results": [_assert_intlike], 5712 "page_token": [_assert_string], 5713 }, 5714 ) 5715 guardrails = _get_tracking_store().list_gateway_guardrails( 5716 max_results=request_message.max_results or SEARCH_MAX_RESULTS_DEFAULT, 5717 page_token=request_message.page_token or None, 5718 ) 5719 response_message = ListGatewayGuardrails.Response() 5720 response_message.guardrails.extend([g.to_proto() for g in guardrails]) 5721 if guardrails.token: 5722 response_message.next_page_token = guardrails.token 5723 return _wrap_response(response_message) 5724 5725 5726 @catch_mlflow_exception 5727 @_disable_if_artifacts_only 5728 def _add_guardrail_to_endpoint(): 5729 request_message = _get_request_message( 5730 AddGuardrailToEndpoint(), 5731 schema={ 5732 "endpoint_id": [_assert_required, _assert_string], 5733 "guardrail_id": [_assert_required, _assert_string], 5734 "execution_order": [_assert_intlike], 5735 }, 5736 ) 5737 config = _get_tracking_store().add_guardrail_to_endpoint( 5738 endpoint_id=request_message.endpoint_id, 5739 guardrail_id=request_message.guardrail_id, 5740 
execution_order=( 5741 request_message.execution_order if request_message.HasField("execution_order") else None 5742 ), 5743 created_by=_get_user(), 5744 ) 5745 response_message = AddGuardrailToEndpoint.Response() 5746 response_message.config.CopyFrom(config.to_proto()) 5747 return _wrap_response(response_message) 5748 5749 5750 @catch_mlflow_exception 5751 @_disable_if_artifacts_only 5752 def _remove_guardrail_from_endpoint(): 5753 request_message = _get_request_message( 5754 RemoveGuardrailFromEndpoint(), 5755 schema={ 5756 "endpoint_id": [_assert_required, _assert_string], 5757 "guardrail_id": [_assert_required, _assert_string], 5758 }, 5759 ) 5760 _get_tracking_store().remove_guardrail_from_endpoint( 5761 endpoint_id=request_message.endpoint_id, 5762 guardrail_id=request_message.guardrail_id, 5763 ) 5764 return _wrap_response(RemoveGuardrailFromEndpoint.Response()) 5765 5766 5767 @catch_mlflow_exception 5768 @_disable_if_artifacts_only 5769 def _list_endpoint_guardrail_configs(): 5770 request_message = _get_request_message( 5771 ListEndpointGuardrailConfigs(), 5772 schema={"endpoint_id": [_assert_required, _assert_string]}, 5773 ) 5774 configs = _get_tracking_store().list_endpoint_guardrail_configs( 5775 endpoint_id=request_message.endpoint_id, 5776 ) 5777 response_message = ListEndpointGuardrailConfigs.Response() 5778 response_message.configs.extend([c.to_proto() for c in configs]) 5779 return _wrap_response(response_message) 5780 5781 5782 @catch_mlflow_exception 5783 def _update_endpoint_guardrail_config(): 5784 request_message = _get_request_message( 5785 UpdateEndpointGuardrailConfig(), 5786 schema={ 5787 "endpoint_id": [_assert_required, _assert_string], 5788 "guardrail_id": [_assert_required, _assert_string], 5789 }, 5790 ) 5791 kwargs = { 5792 "endpoint_id": request_message.endpoint_id, 5793 "guardrail_id": request_message.guardrail_id, 5794 } 5795 if request_message.HasField("execution_order"): 5796 kwargs["execution_order"] = 
request_message.execution_order 5797 config = _get_tracking_store().update_endpoint_guardrail_config(**kwargs) 5798 response_message = UpdateEndpointGuardrailConfig.Response() 5799 response_message.config.CopyFrom(config.to_proto()) 5800 return _wrap_response(response_message) 5801 5802 5803 @catch_mlflow_exception 5804 def _get_server_info(): 5805 from mlflow.store.tracking.file_store import FileStore 5806 from mlflow.store.tracking.sqlalchemy_store import SqlAlchemyStore 5807 5808 store = _get_tracking_store() 5809 5810 if isinstance(store, FileStore): 5811 store_type = "FileStore" 5812 elif isinstance(store, SqlAlchemyStore): 5813 store_type = "SqlStore" 5814 else: 5815 store_type = None 5816 return jsonify({ 5817 "store_type": store_type, 5818 "workspaces_enabled": MLFLOW_ENABLE_WORKSPACES.get(), 5819 }) 5820 5821 5822 @catch_mlflow_exception 5823 @_disable_if_artifacts_only 5824 def _list_supported_providers(): 5825 try: 5826 providers = get_all_providers() 5827 return jsonify({"providers": sorted(providers)}) 5828 except ImportError as e: 5829 raise MlflowException(str(e), error_code=INVALID_PARAMETER_VALUE) 5830 5831 5832 @catch_mlflow_exception 5833 @_disable_if_artifacts_only 5834 def _list_supported_models(): 5835 try: 5836 provider_filter = request.args.get("provider") 5837 models = get_models(provider=provider_filter) 5838 return jsonify({"models": models}) 5839 except ImportError as e: 5840 raise MlflowException(str(e), error_code=INVALID_PARAMETER_VALUE) 5841 5842 5843 @catch_mlflow_exception 5844 @_disable_if_artifacts_only 5845 def _get_provider_config(): 5846 try: 5847 provider = request.args.get("provider") 5848 config = get_provider_config_response(provider) 5849 return jsonify(config) 5850 except (ImportError, ValueError) as e: 5851 raise MlflowException(str(e), error_code=INVALID_PARAMETER_VALUE) 5852 5853 5854 @catch_mlflow_exception 5855 @_disable_if_artifacts_only 5856 def _get_secrets_config(): 5857 using_default_passphrase = not 
os.environ.get(CRYPTO_KEK_PASSPHRASE_ENV_VAR) 5858 return jsonify({ 5859 "secrets_available": True, 5860 "using_default_passphrase": using_default_passphrase, 5861 }) 5862 5863 5864 @catch_mlflow_exception 5865 @_disable_if_artifacts_only 5866 def _invoke_scorer_handler(): 5867 """ 5868 Invoke a scorer on traces asynchronously. 5869 5870 This is a UI-only AJAX endpoint for invoking scorers from the frontend. 5871 """ 5872 _validate_content_type(request, ["application/json"]) 5873 5874 args = request.json 5875 experiment_id = args.get("experiment_id") 5876 serialized_scorer = args.get("serialized_scorer") 5877 trace_ids = args.get("trace_ids", []) 5878 log_assessments = args.get("log_assessments", False) 5879 5880 if not experiment_id: 5881 raise MlflowException( 5882 "Missing required parameter: experiment_id", 5883 error_code=INVALID_PARAMETER_VALUE, 5884 ) 5885 if not serialized_scorer: 5886 raise MlflowException( 5887 "Missing required parameter: serialized_scorer", 5888 error_code=INVALID_PARAMETER_VALUE, 5889 ) 5890 if not trace_ids: 5891 raise MlflowException( 5892 "Please select at least one trace to evaluate.", 5893 error_code=INVALID_PARAMETER_VALUE, 5894 ) 5895 5896 from mlflow.genai.scorers.base import Scorer 5897 from mlflow.genai.scorers.job import get_trace_batches_for_scorer, invoke_scorer_job 5898 from mlflow.server.jobs import submit_job 5899 5900 scorer = Scorer.model_validate_json(serialized_scorer) 5901 5902 tracking_store = _get_tracking_store() 5903 batches = get_trace_batches_for_scorer(trace_ids, scorer, tracking_store) 5904 5905 # Extract the authenticated username so that job subprocesses can make 5906 # gateway requests authorised as the original user (not the admin). 
5907 username = request.authorization.username if request.authorization else None 5908 5909 jobs = [] 5910 for batch_trace_ids in batches: 5911 job = submit_job( 5912 function=invoke_scorer_job, 5913 params={ 5914 "experiment_id": experiment_id, 5915 "serialized_scorer": serialized_scorer, 5916 "trace_ids": batch_trace_ids, 5917 "log_assessments": log_assessments, 5918 "username": username, 5919 }, 5920 ) 5921 jobs.append({"job_id": job.job_id, "trace_ids": batch_trace_ids}) 5922 5923 return jsonify({"jobs": jobs}) 5924 5925 5926 def _get_rest_path(base_path, version=2): 5927 return _add_static_prefix(f"/api/{version}.0{base_path}") 5928 5929 5930 def _get_ajax_path(base_path, version=2): 5931 return _add_static_prefix(f"/ajax-api/{version}.0{base_path}") 5932 5933 5934 def _add_static_prefix(route: str) -> str: 5935 if prefix := os.environ.get(STATIC_PREFIX_ENV_VAR): 5936 return prefix.rstrip("/") + route 5937 return route 5938 5939 5940 def _get_paths(base_path, version=2): 5941 """ 5942 A service endpoints base path is typically something like /mlflow/experiment. 5943 We should register paths like /api/2.0/mlflow/experiment and 5944 /ajax-api/2.0/mlflow/experiment in the Flask router. 5945 """ 5946 base_path = _convert_path_parameter_to_flask_format(base_path) 5947 return [_get_rest_path(base_path, version), _get_ajax_path(base_path, version)] 5948 5949 5950 def _convert_path_parameter_to_flask_format(path): 5951 """ 5952 Converts path parameter format to Flask compatible format. 5953 5954 Some protobuf endpoint paths contain parameters like /mlflow/trace/{request_id}. 5955 This can be interpreted correctly by gRPC framework like Armeria, but Flask does 5956 not understand it. Instead, we need to specify it with a different format, 5957 like /mlflow/trace/<request_id>. 
5958 """ 5959 # Handle simple parameters like {trace_id} 5960 path = re.sub(r"{(\w+)}", r"<\1>", path) 5961 5962 # Handle Databricks-specific syntax like {assessment.trace_id} -> <trace_id> 5963 # This is needed because Databricks can extract trace_id from request body, 5964 # but Flask needs it in the URL path 5965 return re.sub(r"{assessment\.trace_id}", r"<trace_id>", path) 5966 5967 5968 def get_handler(request_class): 5969 """ 5970 Args: 5971 request_class: The type of protobuf message 5972 """ 5973 return HANDLERS.get(request_class, _not_implemented) 5974 5975 5976 def get_service_endpoints(service, get_handler): 5977 ret = [] 5978 for service_method in service.DESCRIPTOR.methods: 5979 endpoints = service_method.GetOptions().Extensions[databricks_pb2.rpc].endpoints 5980 for endpoint in endpoints: 5981 for http_path in _get_paths(endpoint.path, version=endpoint.since.major): 5982 handler = get_handler(service().GetRequestClass(service_method)) 5983 ret.append((http_path, handler, [endpoint.method])) 5984 return ret 5985 5986 5987 def get_endpoints(get_handler=get_handler): 5988 """ 5989 Returns: 5990 List of tuples (path, handler, methods) 5991 """ 5992 return ( 5993 get_service_endpoints(MlflowService, get_handler) 5994 + get_internal_online_scoring_endpoints() 5995 + get_service_endpoints(ModelRegistryService, get_handler) 5996 + get_service_endpoints(MlflowArtifactsService, get_handler) 5997 + get_service_endpoints(WebhookService, get_handler) 5998 + [(_add_static_prefix("/graphql"), _graphql, ["GET", "POST"])] 5999 # NB: Use _get_paths() so that the endpoint is reachable at both 6000 # <static-prefix>/api/3.0/mlflow/server-info (for the Python client) 6001 # and <static-prefix>/ajax-api/3.0/mlflow/server-info (for the frontend). 
6002 + [ 6003 (_path, _get_server_info, ["GET"]) 6004 for _path in _get_paths("/mlflow/server-info", version=3) 6005 ] 6006 + get_gateway_endpoints() 6007 + get_demo_endpoints() 6008 + get_issues_detection_endpoints() 6009 + get_job_endpoints() 6010 ) 6011 6012 6013 def get_gateway_endpoints(): 6014 """Returns endpoint tuples for gateway provider/model discovery APIs and scorer invocation.""" 6015 return [ 6016 ( 6017 _get_ajax_path("/mlflow/gateway/supported-providers", version=3), 6018 _list_supported_providers, 6019 ["GET"], 6020 ), 6021 ( 6022 _get_ajax_path("/mlflow/gateway/supported-models", version=3), 6023 _list_supported_models, 6024 ["GET"], 6025 ), 6026 ( 6027 _get_ajax_path("/mlflow/gateway/provider-config", version=3), 6028 _get_provider_config, 6029 ["GET"], 6030 ), 6031 ( 6032 _get_ajax_path("/mlflow/gateway/secrets/config", version=3), 6033 _get_secrets_config, 6034 ["GET"], 6035 ), 6036 ( 6037 _get_ajax_path("/mlflow/scorer/invoke", version=3), 6038 _invoke_scorer_handler, 6039 ["POST"], 6040 ), 6041 ] 6042 6043 6044 def get_issues_detection_endpoints(): 6045 return [ 6046 ( 6047 _get_ajax_path("/mlflow/issues/invoke", version=3), 6048 _invoke_issue_detection_handler, 6049 ["POST"], 6050 ), 6051 ] 6052 6053 6054 def get_job_endpoints(): 6055 return [ 6056 ( 6057 _get_ajax_path("/mlflow/jobs/cancel/<job_id>", version=3), 6058 _cancel_job, 6059 ["PATCH"], 6060 ), 6061 ( 6062 _get_ajax_path("/mlflow/jobs/<job_id>", version=3), 6063 _get_job, 6064 ["GET"], 6065 ), 6066 ] 6067 6068 6069 # Demo APIs 6070 6071 # Serialize demo generation so concurrent requests (e.g. FastAPI running Flask 6072 # handlers in a thread pool) cannot race on the process-wide MLFLOW_WORKSPACE 6073 # env var that WorkspaceContext temporarily sets during generate_all_demos. 
_demo_generate_lock = threading.Lock()


def get_demo_endpoints():
    """Returns endpoint tuples for demo data generation and deletion APIs."""
    return [
        (
            _get_ajax_path("/mlflow/demo/generate", version=3),
            _generate_demo,
            ["POST"],
        ),
        (
            _get_ajax_path("/mlflow/demo/delete", version=3),
            _delete_demo,
            ["POST"],
        ),
    ]


@catch_mlflow_exception
@_disable_if_artifacts_only
def _generate_demo():
    """Generate demo data for registered demo generators.

    Accepts an optional JSON body with a ``features`` list to generate only specific
    features (e.g. ``{"features": ["traces", "prompts"]}``). When omitted, all features
    are generated.
    """
    from mlflow.demo import generate_all_demos
    from mlflow.demo.base import DEMO_EXPERIMENT_NAME
    from mlflow.demo.registry import demo_registry

    request_json = request.get_json(silent=True) or {}
    features = request_json.get("features")

    store = _get_tracking_store()
    experiment = store.get_experiment_by_name(DEMO_EXPERIMENT_NAME)

    generator_names = demo_registry.list_generators()
    if features is not None:
        generator_names = [n for n in generator_names if n in features]

    # Short-circuit when the demo experiment is live and every requested
    # generator already produced its data.
    # NOTE(review): this check runs outside _demo_generate_lock, so two
    # concurrent requests may both proceed to generation — confirm the
    # generators tolerate being re-run.
    all_exist = False
    if experiment and experiment.lifecycle_stage == "active":
        all_exist = all(demo_registry.get(name)().is_generated() for name in generator_names)

    if experiment and all_exist:
        return jsonify({
            "status": "exists",
            "experiment_id": experiment.experiment_id,
            "features_generated": [],
            "navigation_url": f"/experiments/{experiment.experiment_id}",
        })

    with _demo_generate_lock:
        results = generate_all_demos(features=features)

    # Re-read the experiment: generation may have (re)created it.
    experiment = store.get_experiment_by_name(DEMO_EXPERIMENT_NAME)
    experiment_id = experiment.experiment_id if experiment else None
    navigation_url = f"/experiments/{experiment_id}" if experiment_id else "/experiments"

    return jsonify({
        "status": "created",
        "experiment_id": experiment_id,
        "features_generated": [r.feature for r in results],
        "navigation_url": navigation_url,
    })


@catch_mlflow_exception
@_disable_if_artifacts_only
def _delete_demo():
    """Delete demo data for all registered demo generators.

    Performs a full hard delete of the demo experiment and all associated data,
    equivalent to what `mlflow gc` would do. This ensures the demo data is
    completely removed rather than just soft-deleted.
    """
    from mlflow.demo.base import DEMO_EXPERIMENT_NAME
    from mlflow.demo.registry import demo_registry

    deleted_features = []
    for name in demo_registry.list_generators():
        generator = demo_registry.get(name)()
        if generator._data_exists():
            generator.delete_demo()
            deleted_features.append(name)

    store = _get_tracking_store()
    experiment = store.get_experiment_by_name(DEMO_EXPERIMENT_NAME)
    if experiment and experiment.lifecycle_stage == "active":
        store.delete_experiment(experiment.experiment_id)

    return jsonify({
        "status": "deleted",
        "features_deleted": deleted_features,
    })


def get_internal_online_scoring_endpoints():
    """Returns endpoint definitions for internal (non public) online scoring APIs."""
    return [
        (
            _get_ajax_path("/mlflow/scorers/online-configs", version=3),
            _get_online_scoring_configs,
            ["GET"],
        ),
        (
            _get_rest_path("/mlflow/scorers/online-configs", version=3),
            _get_online_scoring_configs,
            ["GET"],
        ),
        (
            _get_ajax_path("/mlflow/scorers/online-config", version=3),
            _upsert_online_scoring_config,
            ["PUT"],
        ),
        (
            _get_rest_path("/mlflow/scorers/online-config", version=3),
            _upsert_online_scoring_config,
            ["PUT"],
        ),
    ]
# Evaluation Dataset APIs


@catch_mlflow_exception
@_disable_if_artifacts_only
def _create_dataset_handler():
    """Create an evaluation dataset, optionally linked to experiments and tagged."""
    request_message = _get_request_message(
        CreateDataset(),
        schema={
            "name": [_assert_required, _assert_string],
            "experiment_ids": [_assert_array],
            "tags": [_assert_string],
        },
    )

    # Tags arrive as a JSON-encoded string over the wire.
    tags = None
    if hasattr(request_message, "tags") and request_message.tags:
        tags = json.loads(request_message.tags)

    dataset = _get_tracking_store().create_dataset(
        name=request_message.name,
        experiment_ids=list(request_message.experiment_ids)
        if request_message.experiment_ids
        else None,
        tags=tags,
    )

    response_message = CreateDataset.Response()
    response_message.dataset.CopyFrom(dataset.to_proto())
    return _wrap_response(response_message)


@catch_mlflow_exception
@_disable_if_artifacts_only
def _get_dataset_handler(dataset_id):
    """Fetch one evaluation dataset by id (id comes from the URL path)."""
    dataset = _get_tracking_store().get_dataset(dataset_id)

    response_message = GetDataset.Response()
    response_message.dataset.CopyFrom(dataset.to_proto())
    return _wrap_response(response_message)


@catch_mlflow_exception
@_disable_if_artifacts_only
def _delete_dataset_handler(dataset_id):
    """Delete one evaluation dataset by id."""
    _get_tracking_store().delete_dataset(dataset_id)

    response_message = DeleteDataset.Response()
    return _wrap_response(response_message)


@catch_mlflow_exception
@_disable_if_artifacts_only
def _search_evaluation_datasets_handler():
    """Search evaluation datasets with filter/order/pagination options."""
    request_message = _get_request_message(
        SearchEvaluationDatasets(),
        schema={
            "experiment_ids": [_assert_array],
            "filter_string": [_assert_string],
            "max_results": [_assert_intlike],
            "order_by": [_assert_array],
            "page_token": [_assert_string],
        },
    )

    datasets = _get_tracking_store().search_datasets(
        experiment_ids=list(request_message.experiment_ids)
        if request_message.experiment_ids
        else None,
        filter_string=request_message.filter_string or None,
        max_results=request_message.max_results or None,
        order_by=list(request_message.order_by) if request_message.order_by else None,
        page_token=request_message.page_token or None,
    )

    response_message = SearchEvaluationDatasets.Response()
    response_message.datasets.extend([d.to_proto() for d in datasets])
    if datasets.token:
        response_message.next_page_token = datasets.token

    return _wrap_response(response_message)


@catch_mlflow_exception
@_disable_if_artifacts_only
def _set_dataset_tags_handler(dataset_id):
    """Set tags on an evaluation dataset; tags are a JSON-encoded string."""
    request_message = _get_request_message(
        SetDatasetTags(),
        schema={
            "tags": [_assert_required, _assert_string],
        },
    )

    tags = json.loads(request_message.tags)

    _get_tracking_store().set_dataset_tags(
        dataset_id=dataset_id,
        tags=tags,
    )

    response_message = SetDatasetTags.Response()
    return _wrap_response(response_message)


@catch_mlflow_exception
@_disable_if_artifacts_only
def _delete_dataset_tag_handler(dataset_id, key):
    """Delete a single tag (by key) from an evaluation dataset."""
    _get_tracking_store().delete_dataset_tag(
        dataset_id=dataset_id,
        key=key,
    )

    response_message = DeleteDatasetTag.Response()
    return _wrap_response(response_message)


@catch_mlflow_exception
@_disable_if_artifacts_only
def _upsert_dataset_records_handler(dataset_id):
    """Insert-or-update dataset records; records are a JSON-encoded string."""
    request_message = _get_request_message(
        UpsertDatasetRecords(),
        schema={
            "records": [_assert_required, _assert_string],
        },
    )

    records = json.loads(request_message.records)

    result = _get_tracking_store().upsert_dataset_records(
        dataset_id=dataset_id,
        records=records,
    )

    response_message = UpsertDatasetRecords.Response()
response_message.inserted_count = result["inserted"] 6334 response_message.updated_count = result["updated"] 6335 6336 return _wrap_response(response_message) 6337 6338 6339 def _get_dataset_experiment_ids_handler(dataset_id): 6340 """ 6341 Get experiment IDs associated with an evaluation dataset. 6342 """ 6343 experiment_ids = _get_tracking_store().get_dataset_experiment_ids(dataset_id=dataset_id) 6344 6345 response_message = GetDatasetExperimentIds.Response() 6346 response_message.experiment_ids.extend(experiment_ids) 6347 6348 return _wrap_response(response_message) 6349 6350 6351 @catch_mlflow_exception 6352 @_disable_if_artifacts_only 6353 def _add_dataset_to_experiments_handler(dataset_id): 6354 request_message = _get_request_message( 6355 AddDatasetToExperiments(), 6356 schema={ 6357 "experiment_ids": [_assert_array], 6358 }, 6359 ) 6360 6361 dataset = _get_tracking_store().add_dataset_to_experiments( 6362 dataset_id=dataset_id, 6363 experiment_ids=request_message.experiment_ids, 6364 ) 6365 6366 response_message = AddDatasetToExperiments.Response() 6367 response_message.dataset.CopyFrom(dataset.to_proto()) 6368 return _wrap_response(response_message) 6369 6370 6371 @catch_mlflow_exception 6372 @_disable_if_artifacts_only 6373 def _remove_dataset_from_experiments_handler(dataset_id): 6374 request_message = _get_request_message( 6375 RemoveDatasetFromExperiments(), 6376 schema={ 6377 "experiment_ids": [_assert_array], 6378 }, 6379 ) 6380 6381 dataset = _get_tracking_store().remove_dataset_from_experiments( 6382 dataset_id=dataset_id, 6383 experiment_ids=request_message.experiment_ids, 6384 ) 6385 6386 response_message = RemoveDatasetFromExperiments.Response() 6387 response_message.dataset.CopyFrom(dataset.to_proto()) 6388 return _wrap_response(response_message) 6389 6390 6391 def _get_dataset_records_handler(dataset_id): 6392 request_message = _get_request_message( 6393 GetDatasetRecords(), 6394 schema={ 6395 "max_results": [_assert_intlike], 6396 
"page_token": [_assert_string], 6397 }, 6398 ) 6399 6400 max_results = request_message.max_results or 1000 6401 page_token = request_message.page_token or None 6402 6403 # Use the pagination-aware method 6404 records, next_page_token = _get_tracking_store()._load_dataset_records( 6405 dataset_id, max_results=max_results, page_token=page_token 6406 ) 6407 6408 response_message = GetDatasetRecords.Response() 6409 6410 records_dicts = [record.to_dict() for record in records] 6411 response_message.records = json.dumps(records_dicts) 6412 6413 if next_page_token: 6414 response_message.next_page_token = next_page_token 6415 6416 return _wrap_response(response_message) 6417 6418 6419 @catch_mlflow_exception 6420 @_disable_if_artifacts_only 6421 def _delete_dataset_records_handler(dataset_id): 6422 request_message = _get_request_message( 6423 DeleteDatasetRecords(), 6424 schema={ 6425 "dataset_record_ids": [_assert_array], 6426 }, 6427 ) 6428 6429 deleted_count = _get_tracking_store().delete_dataset_records( 6430 dataset_id=dataset_id, 6431 dataset_record_ids=list(request_message.dataset_record_ids), 6432 ) 6433 6434 response_message = DeleteDatasetRecords.Response() 6435 response_message.deleted_count = deleted_count 6436 return _wrap_response(response_message) 6437 6438 6439 # Cache for telemetry config with 3 hour TTL 6440 _telemetry_config_cache = TTLCache(maxsize=1, ttl=10800) 6441 6442 6443 def _get_or_fetch_ui_telemetry_config(): 6444 if (config := _telemetry_config_cache.get("config")) is None: 6445 config = fetch_ui_telemetry_config() 6446 _telemetry_config_cache["config"] = config 6447 return config 6448 6449 6450 @catch_mlflow_exception 6451 def get_ui_telemetry_handler(): 6452 """ 6453 GET handler for /telemetry endpoint. 6454 Returns the telemetry client configuration by fetching it directly. 
6455 """ 6456 if is_telemetry_disabled(): 6457 return jsonify(FALLBACK_UI_CONFIG) 6458 6459 config = _get_or_fetch_ui_telemetry_config() 6460 6461 # UI telemetry should be also disabled if overall telemetry is disabled 6462 disable_ui_telemetry = config.get("disable_ui_telemetry", True) or config.get( 6463 "disable_telemetry", True 6464 ) 6465 response = { 6466 "disable_ui_telemetry": disable_ui_telemetry, 6467 "disable_ui_events": config.get("disable_ui_events", []), 6468 "ui_rollout_percentage": config.get("ui_rollout_percentage", 0), 6469 } 6470 return jsonify(response) 6471 6472 6473 @catch_mlflow_exception 6474 def post_ui_telemetry_handler(): 6475 """ 6476 POST handler for /telemetry endpoint. 6477 Accepts telemetry records and adds them to the telemetry client. 6478 """ 6479 try: 6480 if is_telemetry_disabled(): 6481 return jsonify({"status": "disabled"}) 6482 6483 data = request.json.get("records", []) 6484 6485 if not data: 6486 return jsonify({"status": "success"}) 6487 6488 if (client := get_telemetry_client()) is None: 6489 return jsonify({"status": "disabled"}) 6490 6491 # check cached config to see if telemetry is disabled 6492 # if so, don't process the records. we don't rely on the 6493 # config from the telemetry client because it is only fetched 6494 # once, so it won't be updated unless the server is restarted. 
6495 config = _get_or_fetch_ui_telemetry_config() 6496 6497 # if updated telemetry config is disabled / missing, tell the UI to stop sending records 6498 if config.get("disable_ui_telemetry", True) or config.get("disable_telemetry", True): 6499 return jsonify({"status": "disabled"}) 6500 6501 server_installation_id = get_or_create_installation_id() 6502 records = [ 6503 Record( 6504 event_name=event["event_name"], 6505 timestamp_ns=event["timestamp_ns"], 6506 params=event["params"], 6507 status=Status.SUCCESS, 6508 installation_id=event["installation_id"], 6509 session_id=event["session_id"], 6510 server_installation_id=server_installation_id, 6511 duration_ms=0, 6512 ) 6513 for event in data 6514 ] 6515 6516 client.add_records(records) 6517 6518 return jsonify({"status": "success"}) 6519 except Exception as e: 6520 _logger.debug(f"Failed to process UI telemetry records: {e}") 6521 # if we run into unexpected errors, likely something is wrong 6522 # with the data format. if we return success, the UI will continue 6523 # to send records. if we return an error, the UI will retry sending 6524 # records. the safest thing to do is to tell the UI to stop sending 6525 return jsonify({"status": "disabled"}) 6526 6527 6528 def _parse_prompt_uri(prompt_uri: str) -> tuple[str, str]: 6529 """ 6530 Parse a prompt URI to extract the prompt name and version. 6531 6532 Args: 6533 prompt_uri: Prompt URI in the format "prompts:/prompt_name/version" 6534 6535 Returns: 6536 A tuple of (prompt_name, version). Returns empty strings if parsing fails. 
6537 """ 6538 try: 6539 # Format: "prompts:/prompt_name/version" 6540 if prompt_uri.startswith("prompts:/"): 6541 parts = prompt_uri.replace("prompts:/", "").split("/") 6542 if len(parts) >= 2: 6543 return parts[0], parts[1] 6544 except Exception: 6545 pass 6546 return "", "" 6547 6548 6549 @catch_mlflow_exception 6550 @_disable_if_artifacts_only 6551 def _create_prompt_optimization_job(): 6552 # These imports must be local to avoid circular import with mlflow.server.jobs 6553 from mlflow.genai.datasets import get_dataset as get_genai_dataset 6554 from mlflow.genai.optimize.job import OptimizerType, optimize_prompts_job 6555 from mlflow.server.jobs import submit_job 6556 6557 request_message = _get_request_message( 6558 CreatePromptOptimizationJob(), 6559 schema={ 6560 "experiment_id": [_assert_string], 6561 "source_prompt_uri": [_assert_string, _assert_required], 6562 "config": [_assert_required], 6563 "tags": [_assert_array], 6564 }, 6565 ) 6566 6567 prompt_uri = request_message.source_prompt_uri or "" 6568 if not prompt_uri: 6569 raise MlflowException( 6570 "source_prompt_uri is required for optimization job", 6571 error_code=INVALID_PARAMETER_VALUE, 6572 ) 6573 6574 config = request_message.config 6575 dataset_id = config.dataset_id or "" 6576 6577 scorers = list(config.scorers) if config.scorers else [] 6578 6579 optimizer_type = OptimizerType.from_proto(config.optimizer_type) 6580 6581 experiment_id = (request_message.experiment_id or "").strip() 6582 if not experiment_id: 6583 raise MlflowException( 6584 "experiment_id is required for optimization job", 6585 error_code=INVALID_PARAMETER_VALUE, 6586 ) 6587 6588 # Parse optimizer_config_json to dict for the job function 6589 # Validate before creating run to avoid creating unused runs on validation failure 6590 optimizer_config = None 6591 if config.optimizer_config_json: 6592 try: 6593 optimizer_config = json.loads(config.optimizer_config_json) 6594 except json.JSONDecodeError as e: 6595 raise 
MlflowException( 6596 f"Invalid JSON in optimizer_config_json: {e}", 6597 error_code=INVALID_PARAMETER_VALUE, 6598 ) 6599 6600 # Create MLflow run upfront so run_id is immediately available 6601 # The job will resume this run when it starts executing 6602 tracking_store = _get_tracking_store() 6603 start_time = int(time.time() * 1000) 6604 6605 # Parse prompt name and version from URI for more descriptive run name 6606 prompt_name, prompt_version = _parse_prompt_uri(prompt_uri) 6607 run_name = f"optimize_prompt_{optimizer_type}_{prompt_name}_{prompt_version}_{start_time}" 6608 6609 run = tracking_store.create_run( 6610 experiment_id=experiment_id, 6611 user_id=_get_user(), 6612 start_time=start_time, 6613 tags=[], 6614 run_name=run_name, 6615 ) 6616 run_id = run.info.run_id 6617 6618 # Log optimization config as run parameters 6619 params_to_log = [ 6620 Param("source_prompt_uri", prompt_uri), 6621 Param("optimizer_type", optimizer_type), 6622 Param("dataset_id", dataset_id), 6623 Param("scorer_names", json.dumps(scorers)), 6624 ] 6625 if config.optimizer_config_json: 6626 params_to_log.append(Param("optimizer_config_json", config.optimizer_config_json)) 6627 tracking_store.log_batch(run_id=run_id, metrics=[], params=params_to_log, tags=[]) 6628 6629 # Link the evaluation dataset to the run for lineage tracking (if dataset_id is provided) 6630 if dataset_id: 6631 dataset = get_genai_dataset(dataset_id=dataset_id) 6632 dataset_input = DatasetInput( 6633 dataset=dataset._to_mlflow_entity(), 6634 tags=[InputTag(key="mlflow.data.context", value="optimization")], 6635 ) 6636 tracking_store.log_inputs(run_id=run_id, datasets=[dataset_input]) 6637 6638 params = { 6639 "run_id": run_id, 6640 "experiment_id": experiment_id, 6641 "prompt_uri": prompt_uri, 6642 "dataset_id": dataset_id, 6643 "optimizer_type": optimizer_type, 6644 "optimizer_config": optimizer_config, 6645 "scorer_names": scorers, 6646 } 6647 6648 job_entity = submit_job(optimize_prompts_job, params) 6649 6650 
response_message = CreatePromptOptimizationJob.Response() 6651 optimization_job = PromptOptimizationJobProto() 6652 optimization_job.job_id = job_entity.job_id 6653 optimization_job.run_id = run_id 6654 optimization_job.state.status = JobStatus.JOB_STATUS_PENDING 6655 optimization_job.creation_timestamp_ms = job_entity.creation_time 6656 optimization_job.experiment_id = experiment_id 6657 optimization_job.config.CopyFrom(config) 6658 optimization_job.source_prompt_uri = prompt_uri 6659 6660 for tag in request_message.tags: 6661 job_tag = optimization_job.tags.add() 6662 job_tag.key = tag.key 6663 job_tag.value = tag.value 6664 6665 response_message.job.CopyFrom(optimization_job) 6666 return _wrap_response(response_message) 6667 6668 6669 def _build_prompt_optimization_job_from_entity(job_entity): 6670 from mlflow.genai.optimize.job import OptimizerType 6671 6672 optimization_job = PromptOptimizationJobProto() 6673 optimization_job.job_id = job_entity.job_id 6674 optimization_job.state.status = job_entity.status.to_proto() 6675 optimization_job.creation_timestamp_ms = job_entity.creation_time 6676 6677 params = json.loads(job_entity.params) 6678 if "experiment_id" in params: 6679 optimization_job.experiment_id = params["experiment_id"] 6680 if "prompt_uri" in params: 6681 optimization_job.source_prompt_uri = params["prompt_uri"] 6682 6683 if run_id := params.get("run_id"): 6684 optimization_job.run_id = run_id 6685 6686 # Populate config from job params 6687 config = optimization_job.config 6688 if "optimizer_type" in params: 6689 try: 6690 optimizer_type = OptimizerType(params["optimizer_type"]) 6691 config.optimizer_type = optimizer_type.to_proto() 6692 except (ValueError, KeyError): 6693 pass 6694 if params.get("dataset_id"): 6695 config.dataset_id = params["dataset_id"] 6696 if "scorer_names" in params: 6697 try: 6698 scorer_names = params["scorer_names"] 6699 if isinstance(scorer_names, str): 6700 scorer_names = json.loads(scorer_names) 6701 if 
isinstance(scorer_names, list): 6702 config.scorers.extend(scorer_names) 6703 except (json.JSONDecodeError, TypeError): 6704 pass 6705 if params.get("optimizer_config"): 6706 optimizer_config = params["optimizer_config"] 6707 if isinstance(optimizer_config, dict): 6708 config.optimizer_config_json = json.dumps(optimizer_config) 6709 elif isinstance(optimizer_config, str): 6710 config.optimizer_config_json = optimizer_config 6711 6712 # Get optimized_prompt_uri from job result (only available when job succeeds) 6713 if job_entity.status.name == "SUCCEEDED" and job_entity.parsed_result: 6714 result = job_entity.parsed_result 6715 if isinstance(result, dict) and result.get("optimized_prompt_uri"): 6716 optimization_job.optimized_prompt_uri = result["optimized_prompt_uri"] 6717 6718 # If job failed, add error message to state 6719 if job_entity.status.name == "FAILED" and job_entity.parsed_result: 6720 optimization_job.state.error_message = str(job_entity.parsed_result) 6721 6722 return optimization_job 6723 6724 6725 @catch_mlflow_exception 6726 @_disable_if_artifacts_only 6727 def _get_prompt_optimization_job(job_id): 6728 from mlflow.server.jobs import get_job 6729 6730 job_entity = get_job(job_id) 6731 optimization_job = _build_prompt_optimization_job_from_entity(job_entity) 6732 6733 # Fetch MLflow run to get evaluation scores from metrics 6734 try: 6735 mlflow_run = _get_tracking_store().get_run(optimization_job.run_id) 6736 run_metrics = mlflow_run.data.metrics 6737 6738 # Populate evaluation scores from run metrics 6739 # Aggregated scores are logged as "initial_eval_score" and "final_eval_score" 6740 # Per-scorer scores are logged as "initial_eval_score.<scorer_name>" and 6741 # "final_eval_score.<scorer_name>" 6742 total_metric_calls = None 6743 for metric_name, metric_value in run_metrics.items(): 6744 match metric_name.split(".", 1): 6745 case ["initial_eval_score"]: 6746 optimization_job.initial_eval_scores["aggregate"] = metric_value 6747 case 
["final_eval_score"]: 6748 optimization_job.final_eval_scores["aggregate"] = metric_value 6749 case ["initial_eval_score", scorer_name]: 6750 optimization_job.initial_eval_scores[scorer_name] = metric_value 6751 case ["final_eval_score", scorer_name]: 6752 optimization_job.final_eval_scores[scorer_name] = metric_value 6753 case ["total_metric_calls"]: 6754 total_metric_calls = metric_value 6755 6756 if total_metric_calls is not None: 6757 params = json.loads(job_entity.params) 6758 optimizer_config = params.get("optimizer_config", {}) 6759 if max_metric_calls := optimizer_config.get("max_metric_calls"): 6760 progress = round(min(total_metric_calls / max_metric_calls, 1.0), 2) 6761 optimization_job.state.metadata["progress"] = str(progress) 6762 6763 except Exception as e: 6764 _logger.debug("Failed to fetch run details for optimization job %s: %s", job_id, e) 6765 6766 response_message = GetPromptOptimizationJob.Response() 6767 response_message.job.CopyFrom(optimization_job) 6768 return _wrap_response(response_message) 6769 6770 6771 @catch_mlflow_exception 6772 @_disable_if_artifacts_only 6773 def _search_prompt_optimization_jobs(): 6774 request_message = _get_request_message( 6775 SearchPromptOptimizationJobs(), 6776 schema={ 6777 "experiment_id": [_assert_required, _assert_string], 6778 }, 6779 ) 6780 6781 job_store = _get_job_store() 6782 6783 # Search for optimize_prompts jobs in the specified experiment 6784 jobs = job_store.list_jobs( 6785 job_name="optimize_prompts", 6786 params={"experiment_id": request_message.experiment_id}, 6787 ) 6788 6789 response_message = SearchPromptOptimizationJobs.Response() 6790 6791 for job_entity in jobs: 6792 optimization_job = _build_prompt_optimization_job_from_entity(job_entity) 6793 response_message.jobs.append(optimization_job) 6794 6795 return _wrap_response(response_message) 6796 6797 6798 @catch_mlflow_exception 6799 @_disable_if_artifacts_only 6800 def _cancel_prompt_optimization_job(job_id): 6801 # This import must 
be local to avoid circular import with mlflow.server.jobs 6802 from mlflow.server.jobs import cancel_job 6803 6804 job_entity = cancel_job(job_id) 6805 optimization_job = _build_prompt_optimization_job_from_entity(job_entity) 6806 # Override status to CANCELED since cancel_job may not update the entity status immediately 6807 optimization_job.state.status = JobStatus.JOB_STATUS_CANCELED 6808 6809 # Terminate the underlying MLflow run if it exists 6810 if optimization_job.run_id: 6811 try: 6812 _get_tracking_store().update_run_info( 6813 run_id=optimization_job.run_id, 6814 run_status=RunStatus.KILLED, 6815 end_time=get_current_time_millis(), 6816 run_name=None, 6817 ) 6818 except Exception: 6819 # If the run doesn't exist or is already terminated, log warning and continue 6820 _logger.warning( 6821 "Failed to terminate MLflow run '%s' when canceling job '%s'", 6822 optimization_job.run_id, 6823 job_id, 6824 ) 6825 6826 response_message = CancelPromptOptimizationJob.Response() 6827 response_message.job.CopyFrom(optimization_job) 6828 return _wrap_response(response_message) 6829 6830 6831 @catch_mlflow_exception 6832 @_disable_if_artifacts_only 6833 def _delete_prompt_optimization_job(job_id): 6834 job_store = _get_job_store() 6835 job_entity = job_store.get_job(job_id) 6836 optimization_job = _build_prompt_optimization_job_from_entity(job_entity) 6837 run_id = optimization_job.run_id 6838 6839 job_store.delete_jobs(job_ids=[job_id]) 6840 6841 # Delete the associated MLflow run if it exists. 6842 # Check if run exists before attempting deletion - user may have 6843 # deleted it manually before the job deletion request. 
6844 if run_id: 6845 try: 6846 _get_tracking_store().get_run(run_id) 6847 _get_tracking_store().delete_run(run_id) 6848 except MlflowException: 6849 pass 6850 6851 response_message = DeletePromptOptimizationJob.Response() 6852 return _wrap_response(response_message) 6853 6854 6855 HANDLERS = { 6856 # Tracking Server APIs 6857 CreateExperiment: _create_experiment, 6858 GetExperiment: _get_experiment, 6859 GetExperimentByName: _get_experiment_by_name, 6860 DeleteExperiment: _delete_experiment, 6861 RestoreExperiment: _restore_experiment, 6862 UpdateExperiment: _update_experiment, 6863 CreateRun: _create_run, 6864 UpdateRun: _update_run, 6865 DeleteRun: _delete_run, 6866 RestoreRun: _restore_run, 6867 LogParam: _log_param, 6868 LogMetric: _log_metric, 6869 SetExperimentTag: _set_experiment_tag, 6870 DeleteExperimentTag: _delete_experiment_tag, 6871 SetTag: _set_tag, 6872 DeleteTag: _delete_tag, 6873 LogBatch: _log_batch, 6874 LogModel: _log_model, 6875 GetRun: _get_run, 6876 SearchRuns: _search_runs, 6877 ListArtifacts: _list_artifacts, 6878 CreatePresignedUploadUrl: _create_presigned_upload_url, 6879 GetMetricHistory: _get_metric_history, 6880 GetMetricHistoryBulkInterval: get_metric_history_bulk_interval_handler, 6881 SearchExperiments: _search_experiments, 6882 LogInputs: _log_inputs, 6883 LogOutputs: _log_outputs, 6884 # Evaluation Dataset APIs 6885 CreateDataset: _create_dataset_handler, 6886 GetDataset: _get_dataset_handler, 6887 DeleteDataset: _delete_dataset_handler, 6888 SearchEvaluationDatasets: _search_evaluation_datasets_handler, 6889 SetDatasetTags: _set_dataset_tags_handler, 6890 DeleteDatasetTag: _delete_dataset_tag_handler, 6891 UpsertDatasetRecords: _upsert_dataset_records_handler, 6892 GetDatasetExperimentIds: _get_dataset_experiment_ids_handler, 6893 GetDatasetRecords: _get_dataset_records_handler, 6894 DeleteDatasetRecords: _delete_dataset_records_handler, 6895 AddDatasetToExperiments: _add_dataset_to_experiments_handler, 6896 
RemoveDatasetFromExperiments: _remove_dataset_from_experiments_handler, 6897 # Model Registry APIs 6898 CreateRegisteredModel: _create_registered_model, 6899 GetRegisteredModel: _get_registered_model, 6900 DeleteRegisteredModel: _delete_registered_model, 6901 UpdateRegisteredModel: _update_registered_model, 6902 RenameRegisteredModel: _rename_registered_model, 6903 SearchRegisteredModels: _search_registered_models, 6904 GetLatestVersions: _get_latest_versions, 6905 CreateModelVersion: _create_model_version, 6906 GetModelVersion: _get_model_version, 6907 DeleteModelVersion: _delete_model_version, 6908 UpdateModelVersion: _update_model_version, 6909 TransitionModelVersionStage: _transition_stage, 6910 GetModelVersionDownloadUri: _get_model_version_download_uri, 6911 SearchModelVersions: _search_model_versions, 6912 SetRegisteredModelTag: _set_registered_model_tag, 6913 DeleteRegisteredModelTag: _delete_registered_model_tag, 6914 SetModelVersionTag: _set_model_version_tag, 6915 DeleteModelVersionTag: _delete_model_version_tag, 6916 SetRegisteredModelAlias: _set_registered_model_alias, 6917 DeleteRegisteredModelAlias: _delete_registered_model_alias, 6918 GetModelVersionByAlias: _get_model_version_by_alias, 6919 # Webhook APIs 6920 CreateWebhook: _create_webhook, 6921 ListWebhooks: _list_webhooks, 6922 GetWebhook: _get_webhook, 6923 UpdateWebhook: _update_webhook, 6924 DeleteWebhook: _delete_webhook, 6925 TestWebhook: _test_webhook, 6926 # MLflow Artifacts APIs 6927 DownloadArtifact: _download_artifact, 6928 UploadArtifact: _upload_artifact, 6929 ListArtifactsMlflowArtifacts: _list_artifacts_mlflow_artifacts, 6930 DeleteArtifact: _delete_artifact_mlflow_artifacts, 6931 CreateMultipartUpload: _create_multipart_upload_artifact, 6932 CompleteMultipartUpload: _complete_multipart_upload_artifact, 6933 AbortMultipartUpload: _abort_multipart_upload_artifact, 6934 GetPresignedDownloadUrl: _get_presigned_download_url, 6935 # MLflow Tracing APIs (V3) 6936 StartTraceV3: 
_start_trace_v3, 6937 GetTraceInfoV3: _get_trace_info_v3, 6938 SearchTracesV3: _search_traces_v3, 6939 DeleteTracesV3: _delete_traces, 6940 CalculateTraceFilterCorrelation: _calculate_trace_filter_correlation, 6941 SetTraceTagV3: _set_trace_tag_v3, 6942 DeleteTraceTagV3: _delete_trace_tag_v3, 6943 LinkTracesToRun: _link_traces_to_run, 6944 LinkPromptsToTrace: _link_prompts_to_trace, 6945 BatchGetTraces: _batch_get_traces, 6946 BatchGetTraceInfos: _batch_get_trace_infos, 6947 GetTrace: _get_trace, 6948 QueryTraceMetrics: _query_trace_metrics, 6949 # Assessment APIs 6950 CreateAssessment: _create_assessment, 6951 GetAssessmentRequest: _get_assessment, 6952 UpdateAssessment: _update_assessment, 6953 DeleteAssessment: _delete_assessment, 6954 # Issue APIs 6955 CreateIssue: _create_issue, 6956 UpdateIssue: _update_issue, 6957 GetIssue: _get_issue, 6958 SearchIssues: _search_issues, 6959 # Legacy MLflow Tracing V2 APIs. Kept for backward compatibility but do not use. 6960 StartTrace: _deprecated_start_trace_v2, 6961 EndTrace: _deprecated_end_trace_v2, 6962 GetTraceInfo: _deprecated_get_trace_info_v2, 6963 SearchTraces: _deprecated_search_traces_v2, 6964 DeleteTraces: _delete_traces, 6965 SetTraceTag: _set_trace_tag, 6966 DeleteTraceTag: _delete_trace_tag, 6967 # Logged Models APIs 6968 CreateLoggedModel: _create_logged_model, 6969 GetLoggedModel: _get_logged_model, 6970 FinalizeLoggedModel: _finalize_logged_model, 6971 DeleteLoggedModel: _delete_logged_model, 6972 SetLoggedModelTags: _set_logged_model_tags, 6973 DeleteLoggedModelTag: _delete_logged_model_tag, 6974 SearchLoggedModels: _search_logged_models, 6975 ListLoggedModelArtifacts: _list_logged_model_artifacts, 6976 LogLoggedModelParamsRequest: _log_logged_model_params, 6977 # Scorer APIs 6978 RegisterScorer: _register_scorer, 6979 ListScorers: _list_scorers, 6980 ListScorerVersions: _list_scorer_versions, 6981 GetScorer: _get_scorer, 6982 DeleteScorer: _delete_scorer, 6983 # Secrets APIs 6984 CreateGatewaySecret: 
_create_gateway_secret, 6985 GetGatewaySecretInfo: _get_gateway_secret_info, 6986 UpdateGatewaySecret: _update_gateway_secret, 6987 DeleteGatewaySecret: _delete_gateway_secret, 6988 ListGatewaySecretInfos: _list_gateway_secrets, 6989 # Endpoints APIs 6990 CreateGatewayEndpoint: _create_gateway_endpoint, 6991 GetGatewayEndpoint: _get_gateway_endpoint, 6992 UpdateGatewayEndpoint: _update_gateway_endpoint, 6993 DeleteGatewayEndpoint: _delete_gateway_endpoint, 6994 ListGatewayEndpoints: _list_gateway_endpoints, 6995 # Model Definitions APIs 6996 CreateGatewayModelDefinition: _create_gateway_model_definition, 6997 GetGatewayModelDefinition: _get_gateway_model_definition, 6998 ListGatewayModelDefinitions: _list_gateway_model_definitions, 6999 UpdateGatewayModelDefinition: _update_gateway_model_definition, 7000 DeleteGatewayModelDefinition: _delete_gateway_model_definition, 7001 # Endpoint Model Mappings APIs 7002 AttachModelToGatewayEndpoint: _attach_model_to_gateway_endpoint, 7003 DetachModelFromGatewayEndpoint: _detach_model_from_gateway_endpoint, 7004 # Endpoint Bindings APIs 7005 CreateGatewayEndpointBinding: _create_gateway_endpoint_binding, 7006 DeleteGatewayEndpointBinding: _delete_gateway_endpoint_binding, 7007 ListGatewayEndpointBindings: _list_gateway_endpoint_bindings, 7008 # Endpoint Tags APIs 7009 SetGatewayEndpointTag: _set_gateway_endpoint_tag, 7010 DeleteGatewayEndpointTag: _delete_gateway_endpoint_tag, 7011 # Budget Policy APIs 7012 CreateGatewayBudgetPolicy: _create_budget_policy, 7013 GetGatewayBudgetPolicy: _get_budget_policy, 7014 UpdateGatewayBudgetPolicy: _update_budget_policy, 7015 DeleteGatewayBudgetPolicy: _delete_budget_policy, 7016 ListGatewayBudgetPolicies: _list_budget_policies, 7017 ListGatewayBudgetWindows: _list_budget_windows, 7018 # Guardrail APIs 7019 CreateGatewayGuardrail: _create_gateway_guardrail, 7020 GetGatewayGuardrail: _get_gateway_guardrail, 7021 DeleteGatewayGuardrail: _delete_gateway_guardrail, 7022 ListGatewayGuardrails: 
_list_gateway_guardrails, 7023 AddGuardrailToEndpoint: _add_guardrail_to_endpoint, 7024 RemoveGuardrailFromEndpoint: _remove_guardrail_from_endpoint, 7025 ListEndpointGuardrailConfigs: _list_endpoint_guardrail_configs, 7026 UpdateEndpointGuardrailConfig: _update_endpoint_guardrail_config, 7027 # Prompt Optimization APIs 7028 CreatePromptOptimizationJob: _create_prompt_optimization_job, 7029 GetPromptOptimizationJob: _get_prompt_optimization_job, 7030 SearchPromptOptimizationJobs: _search_prompt_optimization_jobs, 7031 CancelPromptOptimizationJob: _cancel_prompt_optimization_job, 7032 DeletePromptOptimizationJob: _delete_prompt_optimization_job, 7033 # Workspace APIs 7034 ListWorkspaces: _list_workspaces_handler, 7035 CreateWorkspace: _create_workspace_handler, 7036 GetWorkspace: _get_workspace_handler, 7037 UpdateWorkspace: _update_workspace_handler, 7038 DeleteWorkspace: _delete_workspace_handler, 7039 }