/ mlflow / server / handlers.py
handlers.py
   1  # Define all the service endpoint handlers here.
   2  import io
   3  import json
   4  import logging
   5  import os
   6  import pathlib
   7  import posixpath
   8  import re
   9  import tempfile
  10  import threading
  11  import time
  12  import urllib
  13  from functools import partial, wraps
  14  from typing import Any, Callable
  15  
  16  import requests
  17  from cachetools import TTLCache
  18  from flask import Request, Response, current_app, jsonify, request, send_file
  19  from google.protobuf import descriptor
  20  from google.protobuf.json_format import ParseError
  21  
  22  import mlflow
  23  from mlflow.entities import (
  24      Assessment,
  25      DatasetInput,
  26      Expectation,
  27      ExperimentTag,
  28      FallbackConfig,
  29      FallbackStrategy,
  30      Feedback,
  31      FileInfo,
  32      GatewayEndpointModelConfig,
  33      GatewayEndpointTag,
  34      GatewayResourceType,
  35      InputTag,
  36      IssueSeverity,
  37      IssueStatus,
  38      Metric,
  39      Param,
  40      RunStatus,
  41      RunTag,
  42      ViewType,
  43      Workspace,
  44      WorkspaceDeletionMode,
  45  )
  46  from mlflow.entities import (
  47      RoutingStrategy as RoutingStrategyEntity,
  48  )
  49  from mlflow.entities.gateway_budget_policy import (
  50      BudgetAction,
  51      BudgetDuration,
  52      BudgetDurationUnit,
  53      BudgetTargetScope,
  54      BudgetUnit,
  55  )
  56  from mlflow.entities.logged_model import LoggedModel
  57  from mlflow.entities.logged_model_input import LoggedModelInput
  58  from mlflow.entities.logged_model_output import LoggedModelOutput
  59  from mlflow.entities.logged_model_parameter import LoggedModelParameter
  60  from mlflow.entities.logged_model_status import LoggedModelStatus
  61  from mlflow.entities.logged_model_tag import LoggedModelTag
  62  from mlflow.entities.model_registry import ModelVersionTag, RegisteredModelTag
  63  from mlflow.entities.model_registry.prompt_version import IS_PROMPT_TAG_KEY
  64  from mlflow.entities.multipart_upload import MultipartUploadPart
  65  from mlflow.entities.trace_info import TraceInfo
  66  from mlflow.entities.trace_info_v2 import TraceInfoV2
  67  from mlflow.entities.trace_metrics import MetricAggregation, MetricViewType
  68  from mlflow.entities.trace_status import TraceStatus
  69  from mlflow.entities.webhook import WebhookAction, WebhookEntity, WebhookEvent, WebhookStatus
  70  from mlflow.environment_variables import (
  71      MLFLOW_CREATE_MODEL_VERSION_SOURCE_VALIDATION_REGEX,
  72      MLFLOW_DEPLOYMENTS_TARGET,
  73      MLFLOW_ENABLE_WORKSPACES,
  74      MLFLOW_PRESIGNED_DOWNLOAD_URL_TTL_SECONDS,
  75  )
  76  from mlflow.exceptions import (
  77      MlflowException,
  78      MlflowNotImplementedException,
  79      MlflowTracingException,
  80      _UnsupportedMultipartDownloadException,
  81      _UnsupportedMultipartUploadException,
  82      _UnsupportedPresignedUploadException,
  83  )
  84  from mlflow.gateway.budget import maybe_refresh_budget_policies
  85  from mlflow.gateway.budget_tracker import get_budget_tracker
  86  from mlflow.gateway.utils import is_valid_endpoint_name
  87  from mlflow.genai.scorers.scorer_utils import DECORATOR_SCORER_REGISTRATION_NOT_SUPPORTED_ERROR
  88  from mlflow.models import Model
  89  from mlflow.prompt.constants import PROMPT_TEXT_TAG_KEY, PROMPT_TYPE_TAG_KEY
  90  from mlflow.protos import databricks_pb2
  91  from mlflow.protos.databricks_pb2 import (
  92      BAD_REQUEST,
  93      FEATURE_DISABLED,
  94      INTERNAL_ERROR,
  95      INVALID_PARAMETER_VALUE,
  96      INVALID_STATE,
  97      RESOURCE_DOES_NOT_EXIST,
  98  )
  99  from mlflow.protos.issues_pb2 import (
 100      CreateIssue,
 101      GetIssue,
 102      SearchIssues,
 103      UpdateIssue,
 104  )
 105  from mlflow.protos.jobs_pb2 import JobStatus
 106  from mlflow.protos.mlflow_artifacts_pb2 import (
 107      AbortMultipartUpload,
 108      CompleteMultipartUpload,
 109      CreateMultipartUpload,
 110      DeleteArtifact,
 111      DownloadArtifact,
 112      GetPresignedDownloadUrl,
 113      MlflowArtifactsService,
 114      UploadArtifact,
 115  )
 116  from mlflow.protos.mlflow_artifacts_pb2 import (
 117      ListArtifacts as ListArtifactsMlflowArtifacts,
 118  )
 119  from mlflow.protos.model_registry_pb2 import (
 120      CreateModelVersion,
 121      CreateRegisteredModel,
 122      DeleteModelVersion,
 123      DeleteModelVersionTag,
 124      DeleteRegisteredModel,
 125      DeleteRegisteredModelAlias,
 126      DeleteRegisteredModelTag,
 127      GetLatestVersions,
 128      GetModelVersion,
 129      GetModelVersionByAlias,
 130      GetModelVersionDownloadUri,
 131      GetRegisteredModel,
 132      ModelRegistryService,
 133      RenameRegisteredModel,
 134      SearchModelVersions,
 135      SearchRegisteredModels,
 136      SetModelVersionTag,
 137      SetRegisteredModelAlias,
 138      SetRegisteredModelTag,
 139      TransitionModelVersionStage,
 140      UpdateModelVersion,
 141      UpdateRegisteredModel,
 142  )
 143  from mlflow.protos.prompt_optimization_pb2 import (
 144      PromptOptimizationJob as PromptOptimizationJobProto,
 145  )
 146  from mlflow.protos.service_pb2 import (
 147      AddDatasetToExperiments,
 148      AddGuardrailToEndpoint,
 149      AttachModelToGatewayEndpoint,
 150      BatchGetTraceInfos,
 151      BatchGetTraces,
 152      CalculateTraceFilterCorrelation,
 153      CancelPromptOptimizationJob,
 154      CreateAssessment,
 155      CreateDataset,
 156      CreateExperiment,
 157      CreateGatewayBudgetPolicy,
 158      CreateGatewayEndpoint,
 159      CreateGatewayEndpointBinding,
 160      CreateGatewayGuardrail,
 161      CreateGatewayModelDefinition,
 162      CreateGatewaySecret,
 163      CreateLoggedModel,
 164      CreatePresignedUploadUrl,
 165      CreatePromptOptimizationJob,
 166      CreateRun,
 167      CreateWorkspace,
 168      DeleteAssessment,
 169      DeleteDataset,
 170      DeleteDatasetRecords,
 171      DeleteDatasetTag,
 172      DeleteExperiment,
 173      DeleteExperimentTag,
 174      DeleteGatewayBudgetPolicy,
 175      DeleteGatewayEndpoint,
 176      DeleteGatewayEndpointBinding,
 177      DeleteGatewayEndpointTag,
 178      DeleteGatewayGuardrail,
 179      DeleteGatewayModelDefinition,
 180      DeleteGatewaySecret,
 181      DeleteLoggedModel,
 182      DeleteLoggedModelTag,
 183      DeletePromptOptimizationJob,
 184      DeleteRun,
 185      DeleteScorer,
 186      DeleteTag,
 187      DeleteTraces,
 188      DeleteTracesV3,
 189      DeleteTraceTag,
 190      DeleteTraceTagV3,
 191      DeleteWorkspace,
 192      DetachModelFromGatewayEndpoint,
 193      EndTrace,
 194      FinalizeLoggedModel,
 195      GetAssessmentRequest,
 196      GetDataset,
 197      GetDatasetExperimentIds,
 198      GetDatasetRecords,
 199      GetExperiment,
 200      GetExperimentByName,
 201      GetGatewayBudgetPolicy,
 202      GetGatewayEndpoint,
 203      GetGatewayGuardrail,
 204      GetGatewayModelDefinition,
 205      GetGatewaySecretInfo,
 206      GetLoggedModel,
 207      GetMetricHistory,
 208      GetMetricHistoryBulkInterval,
 209      GetPromptOptimizationJob,
 210      GetRun,
 211      GetScorer,
 212      GetTrace,
 213      GetTraceInfo,
 214      GetTraceInfoV3,
 215      GetWorkspace,
 216      LinkPromptsToTrace,
 217      LinkTracesToRun,
 218      ListArtifacts,
 219      ListEndpointGuardrailConfigs,
 220      ListGatewayBudgetPolicies,
 221      ListGatewayBudgetWindows,
 222      ListGatewayEndpointBindings,
 223      ListGatewayEndpoints,
 224      ListGatewayGuardrails,
 225      ListGatewayModelDefinitions,
 226      ListGatewaySecretInfos,
 227      ListLoggedModelArtifacts,
 228      ListScorers,
 229      ListScorerVersions,
 230      ListWorkspaces,
 231      LogBatch,
 232      LogInputs,
 233      LogLoggedModelParamsRequest,
 234      LogMetric,
 235      LogModel,
 236      LogOutputs,
 237      LogParam,
 238      MlflowService,
 239      QueryTraceMetrics,
 240      RegisterScorer,
 241      RemoveDatasetFromExperiments,
 242      RemoveGuardrailFromEndpoint,
 243      RestoreExperiment,
 244      RestoreRun,
 245      SearchDatasets,
 246      SearchEvaluationDatasets,
 247      SearchExperiments,
 248      SearchLoggedModels,
 249      SearchPromptOptimizationJobs,
 250      SearchRuns,
 251      SearchTraces,
 252      SearchTracesV3,
 253      SetDatasetTags,
 254      SetExperimentTag,
 255      SetGatewayEndpointTag,
 256      SetLoggedModelTags,
 257      SetTag,
 258      SetTraceTag,
 259      SetTraceTagV3,
 260      StartTrace,
 261      StartTraceV3,
 262      UpdateAssessment,
 263      UpdateEndpointGuardrailConfig,
 264      UpdateExperiment,
 265      UpdateGatewayBudgetPolicy,
 266      UpdateGatewayEndpoint,
 267      UpdateGatewayModelDefinition,
 268      UpdateGatewaySecret,
 269      UpdateRun,
 270      UpdateWorkspace,
 271      UpsertDatasetRecords,
 272  )
 273  from mlflow.protos.service_pb2 import Trace as ProtoTrace
 274  from mlflow.protos.webhooks_pb2 import (
 275      CreateWebhook,
 276      DeleteWebhook,
 277      GetWebhook,
 278      ListWebhooks,
 279      TestWebhook,
 280      UpdateWebhook,
 281      WebhookService,
 282  )
 283  from mlflow.server.validation import _validate_content_type
 284  from mlflow.server.workspace_helpers import (
 285      _get_workspace_store,
 286  )
 287  from mlflow.store.artifact.artifact_repo import (
 288      MultipartDownloadMixin,
 289      MultipartUploadMixin,
 290      PresignedUploadMixin,
 291  )
 292  from mlflow.store.artifact.artifact_repository_registry import get_artifact_repository
 293  from mlflow.store.db.db_types import DATABASE_ENGINES
 294  from mlflow.store.jobs.abstract_store import AbstractJobStore
 295  from mlflow.store.model_registry.abstract_store import AbstractStore as AbstractModelRegistryStore
 296  from mlflow.store.model_registry.rest_store import RestStore as ModelRegistryRestStore
 297  from mlflow.store.tracking import MAX_RESULTS_QUERY_TRACE_METRICS, SEARCH_MAX_RESULTS_DEFAULT
 298  from mlflow.store.tracking.abstract_store import AbstractStore as AbstractTrackingStore
 299  from mlflow.store.tracking.databricks_rest_store import DatabricksTracingRestStore
 300  from mlflow.store.workspace.abstract_store import WorkspaceNameValidator
 301  from mlflow.telemetry import get_telemetry_client
 302  from mlflow.telemetry.installation_id import get_or_create_installation_id
 303  from mlflow.telemetry.schemas import Record, Status
 304  from mlflow.telemetry.utils import (
 305      FALLBACK_UI_CONFIG,
 306      fetch_ui_telemetry_config,
 307      is_telemetry_disabled,
 308  )
 309  from mlflow.tracing.utils.artifact_utils import (
 310      TRACE_DATA_FILE_NAME,
 311      get_artifact_uri_for_trace,
 312  )
 313  from mlflow.tracking._model_registry import utils as registry_utils
 314  from mlflow.tracking._model_registry.registry import ModelRegistryStoreRegistry
 315  from mlflow.tracking._tracking_service import utils
 316  from mlflow.tracking._tracking_service.registry import TrackingStoreRegistry
 317  from mlflow.tracking.context.default_context import _get_user
 318  from mlflow.tracking.registry import UnsupportedModelRegistryStoreURIException
 319  from mlflow.utils import workspace_context
 320  from mlflow.utils.crypto import CRYPTO_KEK_PASSPHRASE_ENV_VAR
 321  from mlflow.utils.databricks_utils import get_databricks_host_creds
 322  from mlflow.utils.file_utils import local_file_uri_to_path
 323  from mlflow.utils.mime_type_utils import _guess_mime_type
 324  from mlflow.utils.mlflow_tags import (
 325      MLFLOW_ISSUE_DETECTION_JOB_ID,
 326      MLFLOW_RUN_TYPE,
 327      MLFLOW_RUN_TYPE_ISSUE_DETECTION,
 328  )
 329  from mlflow.utils.promptlab_utils import _create_promptlab_run_impl
 330  from mlflow.utils.proto_json_utils import message_to_json, parse_dict
 331  from mlflow.utils.providers import (
 332      get_all_providers,
 333      get_models,
 334      get_provider_config_response,
 335  )
 336  from mlflow.utils.string_utils import is_string_type
 337  from mlflow.utils.time import get_current_time_millis
 338  from mlflow.utils.uri import is_local_uri, validate_path_is_safe, validate_query_string
 339  from mlflow.utils.validation import (
 340      _validate_batch_log_api_req,
 341      _validate_experiment_artifact_location,
 342      _validate_experiment_artifact_location_length,
 343      invalid_value,
 344      missing_value,
 345  )
 346  from mlflow.utils.workspace_utils import DEFAULT_WORKSPACE_NAME
 347  from mlflow.webhooks.delivery import deliver_webhook, test_webhook
 348  from mlflow.webhooks.types import (
 349      ModelVersionAliasCreatedPayload,
 350      ModelVersionAliasDeletedPayload,
 351      ModelVersionCreatedPayload,
 352      ModelVersionTagDeletedPayload,
 353      ModelVersionTagSetPayload,
 354      PromptAliasCreatedPayload,
 355      PromptAliasDeletedPayload,
 356      PromptCreatedPayload,
 357      PromptTagDeletedPayload,
 358      PromptTagSetPayload,
 359      PromptVersionCreatedPayload,
 360      PromptVersionTagDeletedPayload,
 361      PromptVersionTagSetPayload,
 362      RegisteredModelCreatedPayload,
 363  )
 364  
_logger = logging.getLogger(__name__)
# Lazily-initialized process-wide singletons, populated on first use by the
# _get_*_store / _get_artifact_repo_mlflow_artifacts helpers below.
_tracking_store = None
_model_registry_store = None
_job_store = None
_artifact_repo = None
# Env var carrying an optional URL prefix under which the UI/static assets are served.
STATIC_PREFIX_ENV_VAR = "_MLFLOW_STATIC_PREFIX"
# Cap on the number of runs accepted by the bulk metric-history endpoint.
MAX_RUNS_GET_METRIC_HISTORY_BULK = 100
MAX_RESULTS_PER_RUN = 2500
# Chunk size for streaming artifact uploads and downloads (1 MB)
ARTIFACT_STREAM_CHUNK_SIZE = 1024 * 1024
 375  
 376  
 377  class TrackingStoreRegistryWrapper(TrackingStoreRegistry):
 378      def __init__(self):
 379          super().__init__()
 380          self.register("", self._get_file_store)
 381          self.register("file", self._get_file_store)
 382          for scheme in DATABASE_ENGINES:
 383              self.register(scheme, self._get_sqlalchemy_store)
 384          # Add support for Databricks tracking store
 385          self.register("databricks", self._get_databricks_rest_store)
 386          self.register_entrypoints()
 387  
 388      @classmethod
 389      def _get_file_store(cls, store_uri, artifact_uri):
 390          from mlflow.store.tracking.file_store import FileStore
 391  
 392          return FileStore(store_uri, artifact_uri)
 393  
 394      @classmethod
 395      def _get_sqlalchemy_store(cls, store_uri, artifact_uri):
 396          from mlflow.store.tracking.sqlalchemy_store import SqlAlchemyStore
 397          from mlflow.store.tracking.sqlalchemy_workspace_store import (
 398              WorkspaceAwareSqlAlchemyStore,
 399          )
 400  
 401          store_cls = (
 402              WorkspaceAwareSqlAlchemyStore if MLFLOW_ENABLE_WORKSPACES.get() else SqlAlchemyStore
 403          )
 404          return store_cls(store_uri, artifact_uri)
 405  
 406      @classmethod
 407      def _get_databricks_rest_store(cls, store_uri, artifact_uri):
 408          return DatabricksTracingRestStore(partial(get_databricks_host_creds, store_uri))
 409  
 410  
 411  class ModelRegistryStoreRegistryWrapper(ModelRegistryStoreRegistry):
 412      def __init__(self):
 413          super().__init__()
 414          self.register("", self._get_file_store)
 415          self.register("file", self._get_file_store)
 416          for scheme in DATABASE_ENGINES:
 417              self.register(scheme, self._get_sqlalchemy_store)
 418          # Add support for Databricks registries
 419          self.register("databricks", self._get_databricks_rest_store)
 420          self.register("databricks-uc", self._get_databricks_uc_rest_store)
 421          self.register_entrypoints()
 422  
 423      @classmethod
 424      def _get_file_store(cls, store_uri):
 425          from mlflow.store.model_registry.file_store import FileStore
 426  
 427          return FileStore(store_uri)
 428  
 429      @classmethod
 430      def _get_sqlalchemy_store(cls, store_uri):
 431          from mlflow.store.model_registry.sqlalchemy_store import SqlAlchemyStore
 432          from mlflow.store.model_registry.sqlalchemy_workspace_store import (
 433              WorkspaceAwareSqlAlchemyStore,
 434          )
 435  
 436          store_cls = (
 437              WorkspaceAwareSqlAlchemyStore if MLFLOW_ENABLE_WORKSPACES.get() else SqlAlchemyStore
 438          )
 439          return store_cls(store_uri)
 440  
 441      @classmethod
 442      def _get_databricks_rest_store(cls, store_uri):
 443          return ModelRegistryRestStore(partial(get_databricks_host_creds, store_uri))
 444  
 445      @classmethod
 446      def _get_databricks_uc_rest_store(cls, store_uri):
 447          from mlflow.environment_variables import MLFLOW_TRACKING_URI
 448          from mlflow.store._unity_catalog.registry.rest_store import UcModelRegistryStore
 449  
 450          # Get tracking URI from environment or use "databricks-uc" as default
 451          tracking_uri = MLFLOW_TRACKING_URI.get() or "databricks-uc"
 452          return UcModelRegistryStore(store_uri, tracking_uri)
 453  
 454  
# Module-level registry singletons consulted by the _get_*_store helpers below
# to resolve a store implementation from a store URI scheme.
_tracking_store_registry = TrackingStoreRegistryWrapper()
_model_registry_store_registry = ModelRegistryStoreRegistryWrapper()
 457  
 458  
 459  def _get_artifact_repo_mlflow_artifacts():
 460      """
 461      Get an artifact repository specified by ``--artifacts-destination`` option for ``mlflow server``
 462      command.
 463      """
 464      from mlflow.server import ARTIFACTS_DESTINATION_ENV_VAR
 465  
 466      global _artifact_repo
 467      if _artifact_repo is None:
 468          _artifact_repo = get_artifact_repository(os.environ[ARTIFACTS_DESTINATION_ENV_VAR])
 469      return _artifact_repo
 470  
 471  
 472  def _get_trace_artifact_repo(trace_info: TraceInfo):
 473      """
 474      Resolve the artifact repository for fetching data for the given trace.
 475  
 476      Args:
 477          trace_info: The trace info object containing metadata about the trace.
 478      """
 479      artifact_uri = get_artifact_uri_for_trace(trace_info)
 480  
 481      if _is_servable_proxied_run_artifact_root(artifact_uri):
 482          # If the artifact location is a proxied run artifact root (e.g. mlflow-artifacts://...),
 483          # we need to resolve it to the actual artifact location.
 484          from mlflow.server import ARTIFACTS_DESTINATION_ENV_VAR
 485  
 486          path = _get_proxied_run_artifact_destination_path(artifact_uri)
 487          if not path:
 488              raise MlflowException(
 489                  f"Failed to resolve the proxied run artifact URI: {artifact_uri}. ",
 490                  "Trace artifact URI must contain subpath to the trace data directory.",
 491                  error_code=BAD_REQUEST,
 492              )
 493          root = os.environ[ARTIFACTS_DESTINATION_ENV_VAR]
 494          artifact_uri = posixpath.join(root, path)
 495  
 496          # We don't set it to global var unlike run artifact, because the artifact repo has
 497          # to be created with full trace artifact URI including request_id.
 498          # e.g. s3://<experiment_id>/traces/<request_id>
 499          artifact_repo = get_artifact_repository(artifact_uri)
 500      else:
 501          artifact_repo = get_artifact_repository(artifact_uri)
 502      return artifact_repo
 503  
 504  
 505  def _is_serving_proxied_artifacts():
 506      """
 507      Returns:
 508          True if the MLflow server is serving proxied artifacts (i.e. acting as a proxy for
 509          artifact upload / download / list operations), as would be enabled by specifying the
 510          --serve-artifacts configuration option. False otherwise.
 511      """
 512      from mlflow.server import SERVE_ARTIFACTS_ENV_VAR
 513  
 514      return os.environ.get(SERVE_ARTIFACTS_ENV_VAR, "false") == "true"
 515  
 516  
 517  def _is_servable_proxied_run_artifact_root(run_artifact_root):
 518      """
 519      Determines whether or not the following are true:
 520  
 521      - The specified Run artifact root is a proxied artifact root (i.e. an artifact root with scheme
 522        ``http``, ``https``, or ``mlflow-artifacts``).
 523  
 524      - The MLflow server is capable of resolving and accessing the underlying storage location
 525        corresponding to the proxied artifact root, allowing it to fulfill artifact list and
 526        download requests by using this storage location directly.
 527  
 528      Args:
 529          run_artifact_root: The Run artifact root location (URI).
 530  
 531      Returns:
 532          True if the specified Run artifact root refers to proxied artifacts that can be
 533          served by this MLflow server (i.e. the server has access to the destination and
 534          can respond to list and download requests for the artifact). False otherwise.
 535      """
 536      parsed_run_artifact_root = urllib.parse.urlparse(run_artifact_root)
 537      # NB: If the run artifact root is a proxied artifact root (has scheme `http`, `https`, or
 538      # `mlflow-artifacts`) *and* the MLflow server is configured to serve artifacts, the MLflow
 539      # server always assumes that it has access to the underlying storage location for the proxied
 540      # artifacts. This may not always be accurate. For example:
 541      #
 542      # An organization may initially use the MLflow server to serve Tracking API requests and proxy
 543      # access to artifacts stored in Location A (via `mlflow server --serve-artifacts`). Then, for
 544      # scalability and / or security purposes, the organization may decide to store artifacts in a
 545      # new location B and set up a separate server (e.g. `mlflow server --artifacts-only`) to proxy
 546      # access to artifacts stored in Location B.
 547      #
 548      # In this scenario, requests for artifacts stored in Location B that are sent to the original
 549      # MLflow server will fail if the original MLflow server does not have access to Location B
 550      # because it will assume that it can serve all proxied artifacts regardless of the underlying
 551      # location. Such failures can be remediated by granting the original MLflow server access to
 552      # Location B.
 553      return (
 554          parsed_run_artifact_root.scheme in ["http", "https", "mlflow-artifacts"]
 555          and _is_serving_proxied_artifacts()
 556      )
 557  
 558  
 559  def _get_proxied_run_artifact_destination_path(proxied_artifact_root, relative_path=None):
 560      """
 561      Resolves the specified proxied artifact location within a Run to a concrete storage location.
 562  
 563      Args:
 564          proxied_artifact_root: The Run artifact root location (URI) with scheme ``http``,
 565              ``https``, or `mlflow-artifacts` that can be resolved by the MLflow server to a
 566              concrete storage location.
 567          relative_path: The relative path of the destination within the specified
 568              ``proxied_artifact_root``. If ``None``, the destination is assumed to be
 569              the resolved ``proxied_artifact_root``.
 570  
 571      Returns:
 572          The storage location of the specified artifact.
 573      """
 574      parsed_proxied_artifact_root = urllib.parse.urlparse(proxied_artifact_root)
 575      assert parsed_proxied_artifact_root.scheme in ["http", "https", "mlflow-artifacts"]
 576  
 577      if parsed_proxied_artifact_root.scheme == "mlflow-artifacts":
 578          # If the proxied artifact root is an `mlflow-artifacts` URI, the run artifact root path is
 579          # simply the path component of the URI, since the fully-qualified format of an
 580          # `mlflow-artifacts` URI is `mlflow-artifacts://<netloc>/path/to/artifact`
 581          proxied_run_artifact_root_path = parsed_proxied_artifact_root.path.lstrip("/")
 582      else:
 583          # In this case, the proxied artifact root is an HTTP(S) URL referring to an mlflow-artifacts
 584          # API route that can be used to download the artifact. These routes are always anchored at
 585          # `/api/2.0/mlflow-artifacts/artifacts`. Accordingly, we split the path on this route anchor
 586          # and interpret the rest of the path (everything after the route anchor) as the run artifact
 587          # root path
 588          mlflow_artifacts_http_route_anchor = "/api/2.0/mlflow-artifacts/artifacts/"
 589          assert mlflow_artifacts_http_route_anchor in parsed_proxied_artifact_root.path
 590  
 591          proxied_run_artifact_root_path = parsed_proxied_artifact_root.path.split(
 592              mlflow_artifacts_http_route_anchor
 593          )[1].lstrip("/")
 594  
 595      return (
 596          posixpath.join(proxied_run_artifact_root_path, relative_path)
 597          if relative_path is not None
 598          else proxied_run_artifact_root_path
 599      )
 600  
 601  
 602  def _get_tracking_store(
 603      backend_store_uri: str | None = None,
 604      default_artifact_root: str | None = None,
 605  ) -> AbstractTrackingStore:
 606      from mlflow.server import ARTIFACT_ROOT_ENV_VAR, BACKEND_STORE_URI_ENV_VAR
 607  
 608      global _tracking_store
 609      if _tracking_store is None:
 610          store_uri = backend_store_uri or os.environ.get(BACKEND_STORE_URI_ENV_VAR, None)
 611          artifact_root = default_artifact_root or os.environ.get(ARTIFACT_ROOT_ENV_VAR, None)
 612          _tracking_store = _tracking_store_registry.get_store(store_uri, artifact_root)
 613          utils.set_tracking_uri(store_uri)
 614      return _tracking_store
 615  
 616  
 617  def _get_model_registry_store(registry_store_uri: str | None = None) -> AbstractModelRegistryStore:
 618      from mlflow.server import BACKEND_STORE_URI_ENV_VAR, REGISTRY_STORE_URI_ENV_VAR
 619  
 620      global _model_registry_store
 621      if _model_registry_store is None:
 622          store_uri = (
 623              registry_store_uri
 624              or os.environ.get(REGISTRY_STORE_URI_ENV_VAR, None)
 625              or os.environ.get(BACKEND_STORE_URI_ENV_VAR, None)
 626          )
 627          _model_registry_store = _model_registry_store_registry.get_store(store_uri)
 628          registry_utils.set_registry_uri(store_uri)
 629      return _model_registry_store
 630  
 631  
 632  def _get_job_store(backend_store_uri: str | None = None) -> AbstractJobStore:
 633      """
 634      Get a job store instance based on the backend store URI.
 635  
 636      Args:
 637          backend_store_uri: Optional backend store URI. If not provided,
 638                            uses environment variable.
 639  
 640      Returns:
 641          An instance of AbstractJobStore
 642      """
 643      from mlflow.server import BACKEND_STORE_URI_ENV_VAR
 644      from mlflow.store.jobs.sqlalchemy_store import SqlAlchemyJobStore
 645      from mlflow.store.jobs.sqlalchemy_workspace_store import WorkspaceAwareSqlAlchemyJobStore
 646      from mlflow.utils.uri import extract_db_type_from_uri
 647  
 648      global _job_store
 649      if _job_store is None:
 650          store_uri = backend_store_uri or os.environ.get(BACKEND_STORE_URI_ENV_VAR, None)
 651          if not store_uri:
 652              raise MlflowException.invalid_parameter_value("Job store requires a backend store URI")
 653          try:
 654              extract_db_type_from_uri(store_uri)
 655          except (MlflowException, ValueError):
 656              # Require a database backend URI for the job store
 657              # Raise MlflowException so the CLI/REST layer returns a structured 400
 658              # instead of surfacing a generic 500 from ValueError
 659              raise MlflowException.invalid_parameter_value("Job store requires a backend store URI")
 660  
 661          store_cls = (
 662              WorkspaceAwareSqlAlchemyJobStore
 663              if MLFLOW_ENABLE_WORKSPACES.get()
 664              else SqlAlchemyJobStore
 665          )
 666          _job_store = store_cls(store_uri)
 667  
 668          if MLFLOW_ENABLE_WORKSPACES.get():
 669              _verify_job_store_workspace_support(_job_store)
 670  
 671      return _job_store
 672  
 673  
 674  def initialize_backend_stores(
 675      backend_store_uri: str | None = None,
 676      registry_store_uri: str | None = None,
 677      default_artifact_root: str | None = None,
 678      workspace_store_uri: str | None = None,
 679  ) -> None:
 680      tracking_store = _get_tracking_store(backend_store_uri, default_artifact_root)
 681      registry_store = None
 682      try:
 683          registry_store = _get_model_registry_store(registry_store_uri)
 684      except UnsupportedModelRegistryStoreURIException:
 685          pass
 686  
 687      if MLFLOW_ENABLE_WORKSPACES.get():
 688          # Initialize the workspace store to verify it's correctly configured
 689          _get_workspace_store(
 690              workspace_uri=workspace_store_uri,
 691              tracking_uri=backend_store_uri,
 692          )
 693          _verify_tracking_store_workspace_support(tracking_store)
 694          _verify_model_registry_store_workspace_support(registry_store)
 695  
 696  
 697  def _store_supports_workspaces(
 698      store: AbstractTrackingStore | AbstractModelRegistryStore | AbstractJobStore,
 699  ) -> bool:
 700      """Return whether the provided store reports workspace support."""
 701      return bool(getattr(store, "supports_workspaces", False))
 702  
 703  
 704  def _verify_tracking_store_workspace_support(tracking_store: AbstractTrackingStore) -> None:
 705      if not _store_supports_workspaces(tracking_store):
 706          raise MlflowException(
 707              "The configured tracking store does not support workspace-aware operations. "
 708              "Remove the --enable-workspaces flag or configure a workspace-capable backend store.",
 709              error_code=INVALID_STATE,
 710          )
 711  
 712  
 713  def _verify_model_registry_store_workspace_support(
 714      registry_store: AbstractModelRegistryStore,
 715  ) -> None:
 716      if registry_store is None:
 717          return
 718  
 719      if not _store_supports_workspaces(registry_store):
 720          raise MlflowException(
 721              "The configured model registry store does not support workspace-aware operations. "
 722              "Remove the --enable-workspaces flag or configure a workspace-capable backend store.",
 723              error_code=INVALID_STATE,
 724          )
 725  
 726  
 727  def _verify_job_store_workspace_support(job_store: AbstractJobStore) -> None:
 728      if not _store_supports_workspaces(job_store):
 729          raise MlflowException(
 730              "The configured job store does not support workspace-aware operations. "
 731              "Remove the --enable-workspaces flag or configure a workspace-capable backend store.",
 732              error_code=INVALID_STATE,
 733          )
 734  
 735  
 736  def _assert_string(x):
 737      assert isinstance(x, str)
 738  
 739  
 740  def _assert_intlike(x):
 741      try:
 742          x = int(x)
 743      except ValueError:
 744          pass
 745  
 746      assert isinstance(x, int)
 747  
 748  
 749  def _assert_bool(x):
 750      assert isinstance(x, bool)
 751  
 752  
 753  def _assert_floatlike(x):
 754      try:
 755          x = float(x)
 756      except ValueError:
 757          pass
 758  
 759      assert isinstance(x, float)
 760  
 761  
 762  def _assert_array(x):
 763      assert isinstance(x, list)
 764  
 765  
 766  def _assert_map_key_present(x):
 767      _assert_array(x)
 768      for entry in x:
 769          _assert_required(entry.get("key"))
 770  
 771  
 772  def _assert_required(x, path=None):
 773      if path is None:
 774          assert x is not None
 775          # When parsing JSON payloads via proto, absent string fields
 776          # are expressed as empty strings
 777          assert x != ""
 778      else:
 779          assert x is not None, missing_value(path)
 780          assert x != "", missing_value(path)
 781  
 782  
 783  def _assert_less_than_or_equal(x, max_value, message=None):
 784      if x > max_value:
 785          raise AssertionError(message) if message else AssertionError()
 786  
 787  
 788  def _assert_intlike_within_range(x, min_value, max_value, message=None):
 789      if not min_value <= x <= max_value:
 790          raise AssertionError(message) if message else AssertionError()
 791  
 792  
 793  def _assert_item_type_string(x):
 794      assert all(isinstance(item, str) for item in x)
 795  
 796  
 797  def _assert_secret_value(x):
 798      """Validate secret_value is present. Does not print values in errors."""
 799      if x is None:
 800          raise MlflowException(
 801              message="Missing value for required parameter 'secret_value'.",
 802              error_code=INVALID_PARAMETER_VALUE,
 803          )
 804  
 805  
# Validators that assert only a parameter's *type* (not presence or range).
# _validate_param_against_schema skips these when protobuf parsing of the
# request body succeeded, since a successful parse already implies the field
# types were well-formed.
_TYPE_VALIDATORS = {
    _assert_intlike,
    _assert_string,
    _assert_bool,
    _assert_floatlike,
    _assert_array,
    _assert_item_type_string,
}
 814  
 815  
 816  def _validate_param_against_schema(schema, param, value, proto_parsing_succeeded=False):
 817      """
 818      Attempts to validate a single parameter against a specified schema. Examples of the elements of
 819      the schema are type assertions and checks for required parameters. Returns None on validation
 820      success.  Otherwise, raises an MLFlowException if an assertion fails. This method is intended
 821      to be called for side effects.
 822  
 823      Args:
 824          schema: A list of functions to validate the parameter against.
 825          param: The string name of the parameter being validated.
 826          value: The corresponding value of the `param` being validated.
 827          proto_parsing_succeeded: A boolean value indicating whether proto parsing succeeded.
 828              If the proto was successfully parsed, we assume all of the types of the parameters in
 829              the request body were correctly specified, and thus we skip validating types. If proto
 830              parsing failed, then we validate types in addition to the rest of the schema. For
 831              details, see https://github.com/mlflow/mlflow/pull/5458#issuecomment-1080880870.
 832      """
 833  
 834      for f in schema:
 835          if f in _TYPE_VALIDATORS and proto_parsing_succeeded:
 836              continue
 837  
 838          try:
 839              f(value)
 840          except AssertionError as e:
 841              if e.args:
 842                  message = e.args[0]
 843              elif f == _assert_required:
 844                  message = f"Missing value for required parameter '{param}'."
 845              else:
 846                  message = invalid_value(
 847                      param, value, f" Hint: Value was of type '{type(value).__name__}'."
 848                  )
 849              raise MlflowException(
 850                  message=(
 851                      message + " See the API docs for more information about request parameters."
 852                  ),
 853                  error_code=INVALID_PARAMETER_VALUE,
 854              )
 855  
 856      return None
 857  
 858  
 859  def _get_request_json(flask_request=request):
 860      _validate_content_type(flask_request, ["application/json"])
 861      return flask_request.get_json(force=True, silent=True)
 862  
 863  
 864  def _get_normalized_request_json(flask_request: Request = request) -> dict[str, Any]:
 865      """
 866      Get request JSON with normalization for legacy clients.
 867  
 868      Handles double-encoded JSON strings from older clients and empty request bodies.
 869  
 870      Args:
 871          flask_request: The Flask request object.
 872  
 873      Returns:
 874          The request data as a dictionary (empty dict if no body).
 875      """
 876      request_json = _get_request_json(flask_request)
 877  
 878      # Older clients may post their JSON double-encoded as strings, so the get_json
 879      # above actually converts it to a string. Therefore, we check this condition
 880      # (which we can tell for sure because any proper request should be a dictionary),
 881      # and decode it a second time.
 882      if is_string_type(request_json):
 883          request_json = json.loads(request_json)
 884  
 885      # If request doesn't have json body then assume it's empty.
 886      if request_json is None:
 887          request_json = {}
 888  
 889      return request_json
 890  
 891  
 892  def _validate_request_json_with_schema(
 893      request_json: dict[str, Any],
 894      schema: dict[str, list[Callable[..., Any]]] | None,
 895      proto_parsing_succeeded: bool | None,
 896  ) -> None:
 897      """
 898      Validate request JSON against a schema without requiring protobuf messages.
 899  
 900      Args:
 901          request_json: The request data as a dictionary.
 902          schema: Dictionary mapping parameter names to lists of validation functions.
 903          proto_parsing_succeeded: Whether protobuf parsing succeeded. None indicates the
 904              request was not parsed from protobuf.
 905      """
 906      schema = schema or {}
 907      for schema_key, schema_validation_fns in schema.items():
 908          if schema_key in request_json or _assert_required in schema_validation_fns:
 909              value = request_json.get(schema_key)
 910              if schema_key == "run_id" and value is None and "run_uuid" in request_json:
 911                  value = request_json.get("run_uuid")
 912              _validate_param_against_schema(
 913                  schema=schema_validation_fns,
 914                  param=schema_key,
 915                  value=value,
 916                  proto_parsing_succeeded=proto_parsing_succeeded,
 917              )
 918  
 919  
def _get_request_message(request_message, flask_request=request, schema=None):
    """Populate a protobuf request message from the current Flask request.

    For GET requests, query parameters are collected into a dict (with
    repeated proto fields gathered via getlist); otherwise the JSON body is
    used. The dict is then deserialized into ``request_message`` and
    validated against ``schema``.

    Args:
        request_message: An empty protobuf request message to fill in.
        flask_request: The Flask request object (defaults to the active request).
        schema: Optional mapping of parameter name -> list of validator
            functions; see _validate_request_json_with_schema.

    Returns:
        The populated ``request_message``.
    """
    if flask_request.method == "GET" and flask_request.args:
        # Convert atomic values of repeated fields to lists before calling protobuf deserialization.
        # Context: We parse the parameter string into a dictionary outside of protobuf since
        # protobuf does not know how to read the query parameters directly. The query parser above
        # has no type information and hence any parameter that occurs exactly once is parsed as an
        # atomic value. Since protobuf requires that the values of repeated fields are lists,
        # deserialization will fail unless we do the fix below.
        request_json = {}
        for field in request_message.DESCRIPTOR.fields:
            if field.name not in flask_request.args:
                continue

            # Use is_repeated property (preferred) with fallback to deprecated label
            try:
                is_repeated = field.is_repeated
            except AttributeError:
                is_repeated = field.label == descriptor.FieldDescriptor.LABEL_REPEATED

            if is_repeated:
                request_json[field.name] = flask_request.args.getlist(field.name)
            else:
                value = flask_request.args.get(field.name)
                # Query parameters are strings; coerce booleans explicitly and
                # reject anything other than "true"/"false".
                if field.type == descriptor.FieldDescriptor.TYPE_BOOL and isinstance(value, str):
                    if value.lower() not in ["true", "false"]:
                        raise MlflowException.invalid_parameter_value(
                            f"Invalid boolean value: {value}, must be 'true' or 'false'.",
                        )
                    value = value.lower() == "true"
                request_json[field.name] = value
    else:
        request_json = _get_normalized_request_json(flask_request)

    # If proto parsing fails, the schema validation below additionally checks
    # types; on success, type validators are skipped.
    proto_parsing_succeeded = True
    try:
        parse_dict(request_json, request_message)
    except ParseError:
        proto_parsing_succeeded = False

    _validate_request_json_with_schema(request_json, schema, proto_parsing_succeeded)

    return request_message
 962  
 963  
 964  def _get_validated_flask_request_json(
 965      flask_request: Request = request,
 966      schema: dict[str, list[Callable[..., Any]]] | None = None,
 967  ) -> dict[str, Any]:
 968      """
 969      Get and validate request data without protobuf parsing.
 970  
 971      This is an alternative to _get_request_message for endpoints that don't
 972      use protobuf message definitions. Supports both GET and POST/PUT requests.
 973  
 974      Args:
 975          flask_request: The Flask request object.
 976          schema: Dictionary mapping parameter names to lists of validation functions.
 977  
 978      Returns:
 979          The validated request data as a dictionary.
 980  
 981      Raises:
 982          MlflowException: If validation fails.
 983      """
 984      if flask_request.method == "GET" and flask_request.args:
 985          # Extract query parameters for GET requests
 986          request_json = {}
 987          schema = schema or {}
 988          for key in flask_request.args:
 989              # Get all values for this key (supports repeated parameters)
 990              values = flask_request.args.getlist(key)
 991              # Check if this field is a list type by looking for _assert_array validator
 992              is_list_type = _assert_array in schema.get(key, [])
 993              # If list type, always keep as list; otherwise use scalar if only one value
 994              request_json[key] = (
 995                  values if is_list_type else (values[0] if len(values) == 1 else values)
 996              )
 997      else:
 998          request_json = _get_normalized_request_json(flask_request)
 999  
1000      _validate_request_json_with_schema(request_json, schema, proto_parsing_succeeded=None)
1001  
1002      return request_json
1003  
1004  
def _response_with_file_attachment_headers(file_path, response):
    """Force download semantics on a file response.

    Sets Content-Disposition (when not already present), nosniff, and the
    guessed MIME type so browsers download the artifact instead of rendering
    it on this server's domain.

    Args:
        file_path: Local path of the file being served; its basename becomes
            the suggested download filename.
        response: The Flask response to mutate.

    Returns:
        The mutated response.
    """
    mime_type = _guess_mime_type(file_path)
    filename = pathlib.Path(file_path).name
    response.mimetype = mime_type
    content_disposition_header_name = "Content-Disposition"
    if content_disposition_header_name not in response.headers:
        # Bug fix: previously the literal placeholder "(unknown)" was emitted
        # here instead of the computed filename, so every download was saved
        # under the wrong name.
        response.headers[content_disposition_header_name] = f"attachment; filename={filename}"
    response.headers["X-Content-Type-Options"] = "nosniff"
    response.headers["Content-Type"] = mime_type
    return response
1015  
1016  
def _send_artifact(artifact_repository, path):
    """Download the artifact at ``path`` and serve it as an attachment.

    Artifacts are always sent as attachments to prevent the browser from
    displaying them on our web server's domain, which might enable XSS.
    """
    local_path = os.path.abspath(artifact_repository.download_artifacts(path))
    response = send_file(
        local_path, mimetype=_guess_mime_type(local_path), as_attachment=True
    )
    return _response_with_file_attachment_headers(local_path, response)
1024  
1025  
def catch_mlflow_exception(func):
    """Decorator converting MlflowException into a JSON error Response.

    5xx errors are additionally logged; tracebacks are included only when
    DEBUG logging is enabled.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except MlflowException as exc:
            response = Response(mimetype="application/json")
            response.set_data(exc.serialize_as_json())
            response.status_code = exc.get_http_status_code()
            if response.status_code >= 500:
                debug_enabled = _logger.isEnabledFor(logging.DEBUG)
                message = f"Error in {func.__name__}: {exc}"
                if not debug_enabled:
                    message += ". Set MLFLOW_LOGGING_LEVEL=DEBUG for traceback."
                _logger.error(message, exc_info=debug_enabled)
            return response

    return wrapper
1044  
1045  
def _disable_unless_serve_artifacts(func):
    """Decorator returning 503 unless the server proxies artifact requests."""

    @wraps(func)
    def wrapper(*args, **kwargs):
        if _is_serving_proxied_artifacts():
            return func(*args, **kwargs)
        body = (
            f"Endpoint: {request.url_rule} disabled due to the mlflow server running "
            "with `--no-serve-artifacts`. To enable artifacts server functionality, "
            "run `mlflow server` with `--serve-artifacts`"
        )
        return Response(body, 503)

    return wrapper
1061  
1062  
def _disable_if_artifacts_only(func):
    """Decorator returning 503 when the server runs in artifacts-only mode."""

    @wraps(func)
    def wrapper(*args, **kwargs):
        from mlflow.server import ARTIFACTS_ONLY_ENV_VAR

        if not os.environ.get(ARTIFACTS_ONLY_ENV_VAR):
            return func(*args, **kwargs)
        body = (
            f"Endpoint: {request.url_rule} disabled due to the mlflow server running "
            "in `--artifacts-only` mode. To enable tracking server functionality, run "
            "`mlflow server` without `--artifacts-only`"
        )
        return Response(body, 503)

    return wrapper
1080  
1081  
def _disable_if_workspaces_disabled(func):
    """Decorator returning 503 when workspace support is not enabled."""

    @wraps(func)
    def wrapper(*args, **kwargs):
        if MLFLOW_ENABLE_WORKSPACES.get():
            return func(*args, **kwargs)
        body = (
            f"Endpoint: {request.url_rule} disabled because the server is running "
            "without workspaces support. To enable workspace, run "
            "`mlflow server` with `--enable-workspaces`"
        )
        return Response(body, 503)

    return wrapper
1097  
1098  
def _workspace_not_supported(message: str) -> MlflowException:
    """Build the FEATURE_DISABLED exception raised when a store lacks workspace support."""
    exc = MlflowException(message, FEATURE_DISABLED)
    return exc
1101  
1102  
def _validate_artifact_root_uri(value: str, field_name: str) -> str:
    """Validate that ``value`` is usable as an artifact root URI and return it.

    Rejects URIs with params or fragments, then applies the shared query
    string and artifact location checks.
    """
    parts = urllib.parse.urlparse(value)
    if parts.params or parts.fragment:
        raise MlflowException.invalid_parameter_value(
            f"'{field_name}' URL can't include fragments or params."
        )

    validate_query_string(parts.query)
    _validate_experiment_artifact_location(value)
    _validate_experiment_artifact_location_length(value)
    return value
1114  
1115  
1116  def _validate_workspace_default_artifact_root(value: str | None) -> str | None:
1117      if value is None:
1118          return None
1119  
1120      trimmed = value.strip()
1121      if not trimmed:
1122          return ""
1123  
1124      return _validate_artifact_root_uri(trimmed, "default_artifact_root")
1125  
1126  
1127  def _ensure_artifact_root_available(workspace_artifact_root: str | None) -> None:
1128      """Ensure an artifact root is available either at workspace or server level.
1129  
1130      Args:
1131          workspace_artifact_root: The workspace's default_artifact_root value.
1132              - None means "not specified" (fallback to server default)
1133              - "" means "clear/unset" (fallback to server default)
1134              - non-empty string means "use this workspace-specific root"
1135  
1136      Raises:
1137          MlflowException: If neither workspace nor server has an artifact root configured.
1138      """
1139      # If workspace has a non-empty artifact root, it's valid
1140      if workspace_artifact_root:
1141          return
1142  
1143      # Otherwise, check if server has a default artifact root
1144      server_artifact_root = _get_tracking_store().artifact_root_uri
1145      if not server_artifact_root:
1146          raise MlflowException.invalid_parameter_value(
1147              "Cannot create or update workspace without an artifact root. Either specify "
1148              "'default_artifact_root' for this workspace or start the server with "
1149              "'--default-artifact-root'."
1150          )
1151  
1152  
@catch_mlflow_exception
@_disable_if_workspaces_disabled
def _list_workspaces_handler():
    """Return every workspace known to the workspace store."""
    _get_request_message(ListWorkspaces())
    response_message = ListWorkspaces.Response()
    for workspace in _get_workspace_store().list_workspaces():
        response_message.workspaces.append(workspace.to_proto())
    return _wrap_response(response_message)
1161  
1162  
@catch_mlflow_exception
@_disable_if_workspaces_disabled
def _create_workspace_handler():
    """Create a new workspace and return it with HTTP 201.

    Rejects the reserved default workspace name, validates the requested
    name, resolves optional fields, and verifies an artifact root is
    available before delegating to the workspace store.
    """
    request_message = _get_request_message(
        CreateWorkspace(),
        schema={
            "name": [_assert_required, _assert_string],
            "description": [_assert_string],
            "default_artifact_root": [_assert_string],
        },
    )

    if request_message.name == DEFAULT_WORKSPACE_NAME:
        raise MlflowException.invalid_parameter_value(
            f"The '{DEFAULT_WORKSPACE_NAME}' workspace is reserved and cannot be created"
        )
    WorkspaceNameValidator.validate(request_message.name)
    # Optional proto fields: HasField distinguishes "unset" from an empty string.
    description = request_message.description if request_message.HasField("description") else None
    default_artifact_root = (
        request_message.default_artifact_root
        if request_message.HasField("default_artifact_root")
        else None
    )
    default_artifact_root = _validate_workspace_default_artifact_root(default_artifact_root)
    # Either the workspace or the server must provide an artifact root.
    _ensure_artifact_root_available(default_artifact_root)
    store = _get_workspace_store()
    try:
        workspace = store.create_workspace(
            Workspace(
                name=request_message.name,
                description=description,
                default_artifact_root=default_artifact_root,
            )
        )
    except NotImplementedError:
        # Backend store has no workspace support; surface as FEATURE_DISABLED.
        raise _workspace_not_supported("Workspace creation is not supported by this provider")

    response_message = CreateWorkspace.Response()
    response_message.workspace.MergeFrom(workspace.to_proto())
    response = _wrap_response(response_message)
    # 201 Created for successful resource creation.
    response.status_code = 201
    return response
1205  
1206  
@catch_mlflow_exception
@_disable_if_workspaces_disabled
def _get_workspace_handler(workspace_name: str):
    """Fetch a single workspace by name."""
    # The reserved default workspace name is exempt from name validation.
    if workspace_name != DEFAULT_WORKSPACE_NAME:
        WorkspaceNameValidator.validate(workspace_name)
    response_message = GetWorkspace.Response()
    response_message.workspace.MergeFrom(
        _get_workspace_store().get_workspace(workspace_name).to_proto()
    )
    return _wrap_response(response_message)
1216  
1217  
@catch_mlflow_exception
@_disable_if_workspaces_disabled
def _update_workspace_handler(workspace_name: str):
    """Update a workspace's description and/or default artifact root.

    At least one of the two fields must be supplied. An empty-string
    default_artifact_root clears the workspace-level root (falling back to
    the server default, which must then exist).
    """
    if workspace_name != DEFAULT_WORKSPACE_NAME:
        WorkspaceNameValidator.validate(workspace_name)
    request_message = _get_request_message(
        UpdateWorkspace(),
        schema={
            "description": [_assert_string],
            "default_artifact_root": [_assert_string],
        },
    )

    # HasField distinguishes "field omitted" from "field set to empty string".
    has_description = request_message.HasField("description")
    has_artifact_root = request_message.HasField("default_artifact_root")

    if not has_description and not has_artifact_root:
        raise MlflowException.invalid_parameter_value("Workspace update must have at least one key")

    description = request_message.description if has_description else None
    default_artifact_root = request_message.default_artifact_root if has_artifact_root else None
    default_artifact_root = _validate_workspace_default_artifact_root(default_artifact_root)

    # If the user is clearing the workspace artifact root (empty string), ensure the server
    # has a default artifact root configured
    if default_artifact_root == "":
        _ensure_artifact_root_available(default_artifact_root)

    store = _get_workspace_store()
    try:
        workspace = store.update_workspace(
            Workspace(
                name=workspace_name,
                description=description,
                default_artifact_root=default_artifact_root,
            )
        )
    except NotImplementedError:
        # Backend store has no workspace support; surface as FEATURE_DISABLED.
        raise _workspace_not_supported("Workspace updates are not supported by this provider")

    response_message = UpdateWorkspace.Response()
    response_message.workspace.MergeFrom(workspace.to_proto())
    return _wrap_response(response_message)
1261  
1262  
@catch_mlflow_exception
@_disable_if_workspaces_disabled
def _delete_workspace_handler(workspace_name: str):
    """Delete a workspace; deletion mode comes from the ``mode`` query parameter."""
    if workspace_name == DEFAULT_WORKSPACE_NAME:
        raise MlflowException.invalid_parameter_value(
            f"The '{DEFAULT_WORKSPACE_NAME}' workspace is reserved and cannot be deleted"
        )
    WorkspaceNameValidator.validate(workspace_name)
    raw_mode = request.args.get("mode", WorkspaceDeletionMode.RESTRICT.value)
    try:
        mode = WorkspaceDeletionMode(raw_mode)
    except ValueError:
        raise MlflowException.invalid_parameter_value(
            f"Invalid deletion mode '{raw_mode}'. "
            f"Must be one of: {', '.join(m.value for m in WorkspaceDeletionMode)}"
        )
    try:
        _get_workspace_store().delete_workspace(workspace_name, mode=mode)
    except NotImplementedError:
        raise _workspace_not_supported("Workspace deletion is not supported by this provider")
    # 204 No Content on successful deletion.
    return Response(status=204)
1285  
1286  
@catch_mlflow_exception
def get_artifact_handler():
    """Serve a run artifact as a file-attachment download.

    Reads ``run_id`` (or legacy ``run_uuid``) and ``path`` from query
    parameters, resolves the artifact repository (proxied or direct), and
    streams the file back with download-safe headers.
    """
    run_id = request.args.get("run_id") or request.args.get("run_uuid")
    path = request.args["path"]
    # Reject path-traversal attempts before touching the filesystem.
    path = validate_path_is_safe(path)
    run = _get_tracking_store().get_run(run_id)

    if _is_servable_proxied_run_artifact_root(run.info.artifact_uri):
        # The run's artifact root is proxied through this server: resolve the
        # destination path relative to the proxied root.
        artifact_repo = _get_artifact_repo_mlflow_artifacts()
        artifact_path = _get_proxied_run_artifact_destination_path(
            proxied_artifact_root=run.info.artifact_uri,
            relative_path=path,
        )
        artifact_path = _get_workspace_scoped_repo_path_if_enabled(artifact_path)
    else:
        # Direct artifact access: use the run's own repository.
        artifact_repo = _get_artifact_repo(run)
        artifact_path = path

    return _send_artifact(artifact_repo, artifact_path)
1306  
1307  
def _not_implemented():
    """Return an empty 404 response for unimplemented endpoints."""
    return Response(status=404)
1312  
1313  
1314  # Tracking Server APIs
1315  
1316  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _create_experiment():
    """Create an experiment and return its new id."""
    request_message = _get_request_message(
        CreateExperiment(),
        schema={
            "name": [_assert_required, _assert_string],
            "artifact_location": [_assert_string],
            "tags": [_assert_array],
        },
    )

    if request_message.artifact_location:
        _validate_artifact_root_uri(request_message.artifact_location, "artifact_location")
    experiment_tags = [ExperimentTag(t.key, t.value) for t in request_message.tags]
    experiment_id = _get_tracking_store().create_experiment(
        request_message.name, request_message.artifact_location, experiment_tags
    )
    response_message = CreateExperiment.Response()
    response_message.experiment_id = experiment_id
    response = Response(mimetype="application/json")
    response.set_data(message_to_json(response_message))
    return response
1341  
1342  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _get_experiment():
    """Fetch one experiment by id."""
    request_message = _get_request_message(
        GetExperiment(), schema={"experiment_id": [_assert_required, _assert_string]}
    )
    response = Response(mimetype="application/json")
    response.set_data(message_to_json(get_experiment_impl(request_message)))
    return response
1353  
1354  
def get_experiment_impl(request_message):
    """Build a GetExperiment.Response for the requested experiment id."""
    experiment = _get_tracking_store().get_experiment(request_message.experiment_id)
    response_message = GetExperiment.Response()
    response_message.experiment.MergeFrom(experiment.to_proto())
    return response_message
1360  
1361  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _get_experiment_by_name():
    """Look up one experiment by its name; 404-style error when absent."""
    request_message = _get_request_message(
        GetExperimentByName(),
        schema={"experiment_name": [_assert_required, _assert_string]},
    )
    store_exp = _get_tracking_store().get_experiment_by_name(request_message.experiment_name)
    if store_exp is None:
        raise MlflowException(
            f"Could not find experiment with name '{request_message.experiment_name}'",
            error_code=RESOURCE_DOES_NOT_EXIST,
        )
    response_message = GetExperimentByName.Response()
    response_message.experiment.MergeFrom(store_exp.to_proto())
    response = Response(mimetype="application/json")
    response.set_data(message_to_json(response_message))
    return response
1381  
1382  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _delete_experiment():
    """Soft-delete the given experiment."""
    request_message = _get_request_message(
        DeleteExperiment(), schema={"experiment_id": [_assert_required, _assert_string]}
    )
    _get_tracking_store().delete_experiment(request_message.experiment_id)
    response = Response(mimetype="application/json")
    response.set_data(message_to_json(DeleteExperiment.Response()))
    return response
1394  
1395  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _restore_experiment():
    """Restore a previously deleted experiment."""
    request_message = _get_request_message(
        RestoreExperiment(),
        schema={"experiment_id": [_assert_required, _assert_string]},
    )
    _get_tracking_store().restore_experiment(request_message.experiment_id)
    response = Response(mimetype="application/json")
    response.set_data(message_to_json(RestoreExperiment.Response()))
    return response
1408  
1409  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _update_experiment():
    """Update an experiment; currently only renaming is supported."""
    request_message = _get_request_message(
        UpdateExperiment(),
        schema={
            "experiment_id": [_assert_required, _assert_string],
            "new_name": [_assert_string, _assert_required],
        },
    )
    new_name = request_message.new_name
    if new_name:
        _get_tracking_store().rename_experiment(request_message.experiment_id, new_name)
    response = Response(mimetype="application/json")
    response.set_data(message_to_json(UpdateExperiment.Response()))
    return response
1428  
1429  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _create_run():
    """Start a new run in the given experiment and return it."""
    request_message = _get_request_message(
        CreateRun(),
        schema={
            "experiment_id": [_assert_string],
            "start_time": [_assert_intlike],
            "run_name": [_assert_string],
        },
    )

    run = _get_tracking_store().create_run(
        experiment_id=request_message.experiment_id,
        user_id=request_message.user_id,
        start_time=request_message.start_time,
        tags=[RunTag(t.key, t.value) for t in request_message.tags],
        run_name=request_message.run_name,
    )

    response_message = CreateRun.Response()
    response_message.run.MergeFrom(run.to_proto())
    response = Response(mimetype="application/json")
    response.set_data(message_to_json(response_message))
    return response
1456  
1457  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _update_run():
    """Update a run's status, end time, and/or name."""
    request_message = _get_request_message(
        UpdateRun(),
        schema={
            "run_id": [_assert_required, _assert_string],
            "end_time": [_assert_intlike],
            "status": [_assert_string],
            "run_name": [_assert_string],
        },
    )
    # run_uuid is the legacy alias for run_id.
    run_id = request_message.run_id or request_message.run_uuid

    def field_or_none(name):
        # Optional proto scalar: distinguish "unset" from the default value.
        return getattr(request_message, name) if request_message.HasField(name) else None

    updated_info = _get_tracking_store().update_run_info(
        run_id, field_or_none("status"), field_or_none("end_time"), field_or_none("run_name")
    )
    response = Response(mimetype="application/json")
    response.set_data(message_to_json(UpdateRun.Response(run_info=updated_info.to_proto())))
    return response
1479  
1480  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _delete_run():
    """Soft-delete the given run."""
    request_message = _get_request_message(
        DeleteRun(), schema={"run_id": [_assert_required, _assert_string]}
    )
    _get_tracking_store().delete_run(request_message.run_id)
    response = Response(mimetype="application/json")
    response.set_data(message_to_json(DeleteRun.Response()))
    return response
1492  
1493  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _restore_run():
    """Restore a previously deleted run."""
    request_message = _get_request_message(
        RestoreRun(), schema={"run_id": [_assert_required, _assert_string]}
    )
    _get_tracking_store().restore_run(request_message.run_id)
    response = Response(mimetype="application/json")
    response.set_data(message_to_json(RestoreRun.Response()))
    return response
1505  
1506  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _log_metric():
    """Record one metric value against a run."""
    request_message = _get_request_message(
        LogMetric(),
        schema={
            "run_id": [_assert_required, _assert_string],
            "key": [_assert_required, _assert_string],
            "value": [_assert_required, _assert_floatlike],
            "timestamp": [_assert_intlike, _assert_required],
            "step": [_assert_intlike],
            "model_id": [_assert_string],
            "dataset_name": [_assert_string],
            "dataset_digest": [_assert_string],
        },
    )
    # run_uuid is the legacy alias for run_id.
    run_id = request_message.run_id or request_message.run_uuid
    # Empty proto strings for optional fields are normalized to None.
    metric = Metric(
        request_message.key,
        request_message.value,
        request_message.timestamp,
        request_message.step,
        request_message.model_id or None,
        request_message.dataset_name or None,
        request_message.dataset_digest or None,
        request_message.run_id or None,
    )
    _get_tracking_store().log_metric(run_id, metric)
    response = Response(mimetype="application/json")
    response.set_data(message_to_json(LogMetric.Response()))
    return response
1539  
1540  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _log_param():
    """Handler for the LogParam endpoint: records one key/value parameter on a run."""
    msg = _get_request_message(
        LogParam(),
        schema={
            "run_id": [_assert_required, _assert_string],
            "key": [_assert_required, _assert_string],
            "value": [_assert_string],
        },
    )
    # Fall back to the deprecated run_uuid field for older clients.
    target_run = msg.run_id or msg.run_uuid
    _get_tracking_store().log_param(target_run, Param(msg.key, msg.value))
    resp = Response(mimetype="application/json")
    resp.set_data(message_to_json(LogParam.Response()))
    return resp
1559  
1560  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _log_inputs():
    """Handler for the LogInputs endpoint: attaches dataset and model inputs to a run."""
    msg = _get_request_message(
        LogInputs(),
        schema={
            "run_id": [_assert_required, _assert_string],
            "datasets": [_assert_array],
            "models": [_assert_array],
        },
    )
    datasets = [DatasetInput.from_proto(d) for d in msg.datasets]
    # An empty repeated field is forwarded as None rather than an empty list.
    models = None
    if msg.models:
        models = [LoggedModelInput.from_proto(m) for m in msg.models]

    _get_tracking_store().log_inputs(msg.run_id, datasets=datasets, models=models)
    resp = Response(mimetype="application/json")
    resp.set_data(message_to_json(LogInputs.Response()))
    return resp
1591  
1592  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _log_outputs():
    """Handler for the LogOutputs endpoint: records model outputs produced by a run."""
    msg = _get_request_message(
        LogOutputs(),
        schema={
            "run_id": [_assert_required, _assert_string],
            "models": [_assert_required, _assert_array],
        },
    )
    outputs = [LoggedModelOutput.from_proto(m) for m in msg.models]
    _get_tracking_store().log_outputs(run_id=msg.run_id, models=outputs)
    return _wrap_response(LogOutputs.Response())
1607  
1608  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _set_experiment_tag():
    """Handler for the SetExperimentTag endpoint: sets one tag on an experiment."""
    msg = _get_request_message(
        SetExperimentTag(),
        schema={
            "experiment_id": [_assert_required, _assert_string],
            "key": [_assert_required, _assert_string],
            "value": [_assert_string],
        },
    )
    _get_tracking_store().set_experiment_tag(
        msg.experiment_id, ExperimentTag(msg.key, msg.value)
    )
    resp = Response(mimetype="application/json")
    resp.set_data(message_to_json(SetExperimentTag.Response()))
    return resp
1626  
1627  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _delete_experiment_tag():
    """Handler for the DeleteExperimentTag endpoint: removes one tag from an experiment."""
    msg = _get_request_message(
        DeleteExperimentTag(),
        schema={
            "experiment_id": [_assert_required, _assert_string],
            "key": [_assert_required, _assert_string],
        },
    )
    _get_tracking_store().delete_experiment_tag(msg.experiment_id, msg.key)
    resp = Response(mimetype="application/json")
    resp.set_data(message_to_json(DeleteExperimentTag.Response()))
    return resp
1643  
1644  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _set_tag():
    """Handler for the SetTag endpoint: sets one tag on a run."""
    msg = _get_request_message(
        SetTag(),
        schema={
            "run_id": [_assert_required, _assert_string],
            "key": [_assert_required, _assert_string],
            "value": [_assert_string],
        },
    )
    # Fall back to the deprecated run_uuid field for older clients.
    target_run = msg.run_id or msg.run_uuid
    _get_tracking_store().set_tag(target_run, RunTag(msg.key, msg.value))
    resp = Response(mimetype="application/json")
    resp.set_data(message_to_json(SetTag.Response()))
    return resp
1663  
1664  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _delete_tag():
    """Handler for the DeleteTag endpoint: removes one tag from a run."""
    msg = _get_request_message(
        DeleteTag(),
        schema={
            "run_id": [_assert_required, _assert_string],
            "key": [_assert_required, _assert_string],
        },
    )
    _get_tracking_store().delete_tag(msg.run_id, msg.key)
    resp = Response(mimetype="application/json")
    resp.set_data(message_to_json(DeleteTag.Response()))
    return resp
1680  
1681  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _get_run():
    """Handler for the GetRun endpoint: fetches a single run by id."""
    msg = _get_request_message(
        GetRun(), schema={"run_id": [_assert_required, _assert_string]}
    )
    resp = Response(mimetype="application/json")
    resp.set_data(message_to_json(get_run_impl(msg)))
    return resp
1692  
1693  
def get_run_impl(request_message):
    """Build a GetRun response proto for the run named in ``request_message``."""
    # Accept the deprecated run_uuid field from older clients.
    run_id = request_message.run_id or request_message.run_uuid
    reply = GetRun.Response()
    reply.run.MergeFrom(_get_tracking_store().get_run(run_id).to_proto())
    return reply
1699  
1700  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _search_runs():
    """Handler for the SearchRuns endpoint: filtered, ordered, paginated run search."""
    msg = _get_request_message(
        SearchRuns(),
        schema={
            "experiment_ids": [_assert_array],
            "filter": [_assert_string],
            "max_results": [
                _assert_intlike,
                lambda x: _assert_less_than_or_equal(int(x), 50000),
            ],
            "order_by": [_assert_array, _assert_item_type_string],
        },
    )
    resp = Response(mimetype="application/json")
    resp.set_data(message_to_json(search_runs_impl(msg)))
    return resp
1720  
1721  
def search_runs_impl(request_message):
    """Execute a SearchRuns request against the tracking store and build the response.

    When the auth plugin is installed and configured, the requested experiment ids are
    first narrowed to those the caller may read.
    """
    view = ViewType.ACTIVE_ONLY
    if request_message.HasField("run_view_type"):
        view = ViewType.from_proto(request_message.run_view_type)
    experiment_ids = list(request_message.experiment_ids)

    # NB: Local import to avoid circular dependency (auth imports from handlers)
    try:
        from mlflow.server import auth

        if auth.auth_config:
            experiment_ids = auth.filter_experiment_ids(experiment_ids)
    except ImportError:
        # Auth module not available (Flask-WTF not installed), skip filtering
        pass

    runs = _get_tracking_store().search_runs(
        experiment_ids=experiment_ids,
        filter_string=request_message.filter,
        run_view_type=view,
        max_results=request_message.max_results,
        order_by=request_message.order_by,
        page_token=request_message.page_token or None,
    )
    reply = SearchRuns.Response()
    reply.runs.extend(r.to_proto() for r in runs)
    if runs.token:
        reply.next_page_token = runs.token
    return reply
1754  
1755  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _list_artifacts():
    """Handler for the ListArtifacts endpoint: lists a run's artifacts under a path."""
    msg = _get_request_message(
        ListArtifacts(),
        schema={
            "run_id": [_assert_string, _assert_required],
            "path": [_assert_string],
            "page_token": [_assert_string],
        },
    )
    resp = Response(mimetype="application/json")
    resp.set_data(message_to_json(list_artifacts_impl(msg)))
    return resp
1771  
1772  
def list_artifacts_impl(request_message):
    """List a run's artifacts, going through the server's proxied artifact repository
    when the run's artifact root uses a proxy-servable scheme."""
    # Sanitize the requested subpath to prevent traversal outside the artifact root.
    path = (
        validate_path_is_safe(request_message.path)
        if request_message.HasField("path")
        else None
    )
    # Accept the deprecated run_uuid field from older clients.
    run = _get_tracking_store().get_run(request_message.run_id or request_message.run_uuid)

    if _is_servable_proxied_run_artifact_root(run.info.artifact_uri):
        files = _list_artifacts_for_proxied_run_artifact_root(
            proxied_artifact_root=run.info.artifact_uri,
            relative_path=path,
        )
    else:
        files = _get_artifact_repo(run).list_artifacts(path)

    reply = ListArtifacts.Response()
    reply.files.extend(f.to_proto() for f in files)
    reply.root_uri = run.info.artifact_uri
    return reply
1794  
1795  
def _list_artifacts_for_proxied_run_artifact_root(proxied_artifact_root, relative_path=None):
    """
    Lists artifacts from the specified ``relative_path`` within the specified proxied Run artifact
    root (i.e. a Run artifact root with scheme ``http``, ``https``, or ``mlflow-artifacts``).

    Args:
        proxied_artifact_root: The Run artifact root location (URI) with scheme ``http``,
                               ``https``, or ``mlflow-artifacts`` that can be resolved by the
                               MLflow server to a concrete storage location.
        relative_path: The relative path within the specified ``proxied_artifact_root`` under
                       which to list artifact contents. If ``None``, artifacts are listed from
                       the ``proxied_artifact_root`` directory.

    Returns:
        A list of ``FileInfo`` entities with paths expressed relative to the Run artifact root.

    Raises:
        MlflowException: If ``proxied_artifact_root`` does not use a proxy-servable scheme.
    """
    parsed_proxied_artifact_root = urllib.parse.urlparse(proxied_artifact_root)
    # Raise an explicit client error instead of `assert`: assertions are stripped when
    # Python runs with -O, and an AssertionError would surface as an opaque 500 rather
    # than a 400 with INVALID_PARAMETER_VALUE like the rest of this module's handlers.
    if parsed_proxied_artifact_root.scheme not in ("http", "https", "mlflow-artifacts"):
        raise MlflowException(
            message=(
                f"Cannot list artifacts for proxied run artifact root "
                f"'{proxied_artifact_root}': expected scheme 'http', 'https', or "
                f"'mlflow-artifacts', got '{parsed_proxied_artifact_root.scheme}'."
            ),
            error_code=INVALID_PARAMETER_VALUE,
        )

    artifact_destination_repo = _get_artifact_repo_mlflow_artifacts()
    # Resolve the proxied root (plus optional relative path) to a repo-relative path.
    artifact_destination_path = _get_proxied_run_artifact_destination_path(
        proxied_artifact_root=proxied_artifact_root,
        relative_path=relative_path,
    )
    artifact_destination_path = _get_workspace_scoped_repo_path_if_enabled(
        artifact_destination_path
    )

    # Re-express each listed entry relative to the Run artifact root so callers see
    # run-relative paths, not repo-internal ones.
    artifact_entities = []
    for file_info in artifact_destination_repo.list_artifacts(artifact_destination_path):
        basename = posixpath.basename(file_info.path)
        run_relative_artifact_path = (
            posixpath.join(relative_path, basename) if relative_path else basename
        )
        artifact_entities.append(
            FileInfo(run_relative_artifact_path, file_info.is_dir, file_info.file_size)
        )

    return artifact_entities
1832  
1833  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _get_metric_history():
    """Handler for the GetMetricHistory endpoint: returns the (paginated) history of a
    single metric key for one run."""
    request_message = _get_request_message(
        GetMetricHistory(),
        schema={
            "run_id": [_assert_string, _assert_required],
            "metric_key": [_assert_string, _assert_required],
            "page_token": [_assert_string],
        },
    )
    response_message = GetMetricHistory.Response()
    # Accept the deprecated run_uuid field from older clients.
    run_id = request_message.run_id or request_message.run_uuid

    # BUGFIX: protobuf scalar getters never return None, so the previous
    # `request_message.max_results is not None` check always passed and an *unset*
    # max_results reached the store as 0. Use HasField — the pattern the other
    # handlers in this file use for optional fields — to map "unset" to None.
    max_results = (
        request_message.max_results if request_message.HasField("max_results") else None
    )

    metric_entities = _get_tracking_store().get_metric_history(
        run_id,
        request_message.metric_key,
        max_results=max_results,
        page_token=request_message.page_token or None,
    )
    response_message.metrics.extend([m.to_proto() for m in metric_entities])

    # Set next_page_token if available
    if next_page_token := metric_entities.token:
        response_message.next_page_token = next_page_token

    response = Response(mimetype="application/json")
    response.set_data(message_to_json(response_message))
    return response
1865  
1866  
@catch_mlflow_exception
@_disable_if_artifacts_only
def get_metric_history_bulk_handler():
    """Handler for the bulk metric-history endpoint: returns the history of one metric
    key across up to 100 runs in a single call."""
    MAX_HISTORY_RESULTS = 25000
    MAX_RUN_IDS_PER_REQUEST = 100

    run_ids = request.args.to_dict(flat=False).get("run_id", [])
    if not run_ids:
        raise MlflowException(
            message="GetMetricHistoryBulk request must specify at least one run_id.",
            error_code=INVALID_PARAMETER_VALUE,
        )
    if len(run_ids) > MAX_RUN_IDS_PER_REQUEST:
        raise MlflowException(
            message=(
                f"GetMetricHistoryBulk request cannot specify more than {MAX_RUN_IDS_PER_REQUEST}"
                f" run_ids. Received {len(run_ids)} run_ids."
            ),
            error_code=INVALID_PARAMETER_VALUE,
        )

    metric_key = request.args.get("metric_key")
    if metric_key is None:
        raise MlflowException(
            message="GetMetricHistoryBulk request must specify a metric_key.",
            error_code=INVALID_PARAMETER_VALUE,
        )

    # Cap the client-requested page size at the server-side maximum.
    max_results = min(int(request.args.get("max_results", MAX_HISTORY_RESULTS)), MAX_HISTORY_RESULTS)

    store = _get_tracking_store()

    if hasattr(store, "get_metric_history_bulk"):
        # The store offers a native bulk implementation; use it directly.
        history = [
            m.to_dict()
            for m in store.get_metric_history_bulk(
                run_ids=run_ids,
                metric_key=metric_key,
                max_results=max_results,
            )
        ]
    else:
        # Fall back to one get_metric_history call per run, visiting runs in sorted
        # id order and ordering each run's metrics by (timestamp, step, value).
        history = []
        for run_id in sorted(run_ids):
            per_run = sorted(
                store.get_metric_history(
                    run_id=run_id,
                    metric_key=metric_key,
                    max_results=max_results,
                ),
                key=lambda m: (m.timestamp, m.step, m.value),
            )
            history.extend(
                {
                    "key": m.key,
                    "value": m.value,
                    "timestamp": m.timestamp,
                    "step": m.step,
                    "run_id": run_id,
                }
                for m in per_run
            )

    return {
        "metrics": history[:max_results],
    }
1937  
1938  
@catch_mlflow_exception
@_disable_if_artifacts_only
def get_metric_history_bulk_interval_handler():
    """Handler for the sampled bulk metric-history endpoint used by the UI charts."""
    msg = _get_request_message(
        GetMetricHistoryBulkInterval(),
        schema={
            "run_ids": [
                _assert_required,
                _assert_array,
                _assert_item_type_string,
                lambda x: _assert_less_than_or_equal(
                    len(x),
                    MAX_RUNS_GET_METRIC_HISTORY_BULK,
                    message=f"GetMetricHistoryBulkInterval request must specify at most "
                    f"{MAX_RUNS_GET_METRIC_HISTORY_BULK} run_ids. Received {len(x)} run_ids.",
                ),
            ],
            "metric_key": [_assert_required, _assert_string],
            "start_step": [_assert_intlike],
            "end_step": [_assert_intlike],
            "max_results": [
                _assert_intlike,
                lambda x: _assert_intlike_within_range(
                    int(x),
                    1,
                    MAX_RESULTS_PER_RUN,
                    message=f"max_results must be between 1 and {MAX_RESULTS_PER_RUN}.",
                ),
            ],
        },
    )
    resp = Response(mimetype="application/json")
    resp.set_data(message_to_json(get_metric_history_bulk_interval_impl(msg)))
    return resp
1974  
1975  
def get_metric_history_bulk_interval_impl(request_message):
    """Fetch sampled metric history for multiple runs over an optional step interval.

    Run ids and the metric key come from the already-parsed ``request_message``;
    ``max_results``, ``start_step`` and ``end_step`` are re-read from the raw Flask
    query args. ``start_step``/``end_step`` must be given together and form a valid
    (start <= end) interval, or an INVALID_PARAMETER_VALUE error is raised.
    """
    # NOTE(review): max_results/start_step/end_step are taken from request.args rather
    # than from request_message, even though the handler's schema validated the message
    # fields — presumably because unset proto scalars can't be told apart from 0 here;
    # confirm before consolidating the two sources.
    args = request.args
    run_ids = request_message.run_ids
    metric_key = request_message.metric_key
    max_results = int(args.get("max_results", MAX_RESULTS_PER_RUN))
    start_step = args.get("start_step")
    end_step = args.get("end_step")
    if start_step is not None and end_step is not None:
        start_step = int(start_step)
        end_step = int(end_step)
        if start_step > end_step:
            raise MlflowException.invalid_parameter_value(
                "end_step must be greater than start_step. "
                f"Found start_step={start_step} and end_step={end_step}."
            )
    elif start_step is not None or end_step is not None:
        # Providing only one endpoint of the interval is rejected outright.
        raise MlflowException.invalid_parameter_value(
            "If either start step or end step are specified, both must be specified."
        )

    store = _get_tracking_store()
    metrics_with_run_ids = store.get_metric_history_bulk_interval(
        run_ids=run_ids,
        metric_key=metric_key,
        max_results=max_results,
        start_step=start_step,
        end_step=end_step,
    )

    response_message = GetMetricHistoryBulkInterval.Response()
    response_message.metrics.extend([m.to_proto() for m in metrics_with_run_ids])
    return response_message
2008  
2009  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _search_datasets_handler():
    """Handler for the SearchDatasets endpoint: summarizes datasets used by experiments."""
    msg = _get_request_message(
        SearchDatasets(),
    )
    resp = Response(mimetype="application/json")
    resp.set_data(message_to_json(search_datasets_impl(msg)))
    return resp
2020  
2021  
def search_datasets_impl(request_message):
    """Return dataset summaries for 1..20 experiments, if the store supports it.

    Raises INVALID_PARAMETER_VALUE when no experiment ids are given or too many are;
    returns a not-implemented response for stores without dataset search.
    """
    MAX_EXPERIMENT_IDS_PER_REQUEST = 20
    _validate_content_type(request, ["application/json"])
    experiment_ids = request_message.experiment_ids or []
    if not experiment_ids:
        raise MlflowException(
            message="SearchDatasets request must specify at least one experiment_id.",
            error_code=INVALID_PARAMETER_VALUE,
        )
    if len(experiment_ids) > MAX_EXPERIMENT_IDS_PER_REQUEST:
        raise MlflowException(
            message=(
                f"SearchDatasets request cannot specify more than {MAX_EXPERIMENT_IDS_PER_REQUEST}"
                f" experiment_ids. Received {len(experiment_ids)} experiment_ids."
            ),
            error_code=INVALID_PARAMETER_VALUE,
        )

    store = _get_tracking_store()
    # Not every tracking store implements dataset search.
    if not hasattr(store, "_search_datasets"):
        return _not_implemented()

    reply = SearchDatasets.Response()
    reply.dataset_summaries.extend(
        summary.to_proto() for summary in store._search_datasets(experiment_ids)
    )
    return reply
2050  
2051  
2052  def _validate_gateway_path(method: str, gateway_path: str) -> None:
2053      if not gateway_path:
2054          raise MlflowException(
2055              message="Deployments proxy request must specify a gateway_path.",
2056              error_code=INVALID_PARAMETER_VALUE,
2057          )
2058      elif method == "GET":
2059          if gateway_path.strip("/") != "api/2.0/endpoints":
2060              raise MlflowException(
2061                  message=f"Invalid gateway_path: {gateway_path} for method: {method}",
2062                  error_code=INVALID_PARAMETER_VALUE,
2063              )
2064      elif method == "POST":
2065          # For POST, gateway_path must be in the form of "gateway/{name}/invocations"
2066          if not re.fullmatch(r"gateway/[^/]+/invocations", gateway_path.strip("/")):
2067              raise MlflowException(
2068                  message=f"Invalid gateway_path: {gateway_path} for method: {method}",
2069                  error_code=INVALID_PARAMETER_VALUE,
2070              )
2071  
2072  
@catch_mlflow_exception
def gateway_proxy_handler():
    """Proxy a UI request to the configured deployments/gateway server.

    With no target configured, responds as if an empty gateway were running.
    """
    target_uri = MLFLOW_DEPLOYMENTS_TARGET.get()
    if not target_uri:
        # Pretend an empty gateway service is running
        return {"endpoints": []}

    payload = request.args if request.method == "GET" else request.json
    gateway_path = payload.get("gateway_path")
    _validate_gateway_path(request.method, gateway_path)

    upstream = requests.request(
        request.method,
        f"{target_uri}/{gateway_path}",
        json=payload.get("json_data", None),
    )
    if upstream.status_code != 200:
        raise MlflowException(
            message=f"Deployments proxy request failed with error code {upstream.status_code}. "
            f"Error message: {upstream.text}",
            error_code=upstream.status_code,
        )
    return upstream.json()
2093  
2094  
@catch_mlflow_exception
@_disable_if_artifacts_only
def create_promptlab_run_handler():
    """Handler that creates a "promptlab" run from the JSON request body and returns
    the created run as a CreateRun response."""

    def _require(arg_name, arg):
        # Reject missing/empty required arguments with a client error.
        if not arg:
            raise MlflowException(
                message=f"CreatePromptlabRun request must specify {arg_name}.",
                error_code=INVALID_PARAMETER_VALUE,
            )

    def _to_params(raw):
        # Convert a list of {"key": ..., "value": ...} dicts into Param entities.
        return [Param(p.get("key"), p.get("value")) for p in raw]

    _validate_content_type(request, ["application/json"])

    args = request.json
    experiment_id = args.get("experiment_id")
    _require("experiment_id", experiment_id)
    run_name = args.get("run_name", None)
    tags = args.get("tags", [])
    prompt_template = args.get("prompt_template")
    _require("prompt_template", prompt_template)
    raw_prompt_parameters = args.get("prompt_parameters")
    _require("prompt_parameters", raw_prompt_parameters)
    prompt_parameters = _to_params(raw_prompt_parameters)
    model_route = args.get("model_route")
    _require("model_route", model_route)
    model_parameters = _to_params(args.get("model_parameters", []))
    model_input = args.get("model_input")
    _require("model_input", model_input)
    model_output = args.get("model_output", None)
    model_output_parameters = _to_params(args.get("model_output_parameters", []))
    mlflow_version = args.get("mlflow_version")
    _require("mlflow_version", mlflow_version)
    user_id = args.get("user_id", "unknown")

    # use current time if not provided
    start_time = args.get("start_time", int(time.time() * 1000))

    run = _create_promptlab_run_impl(
        _get_tracking_store(),
        experiment_id=experiment_id,
        run_name=run_name,
        tags=tags,
        prompt_template=prompt_template,
        prompt_parameters=prompt_parameters,
        model_route=model_route,
        model_parameters=model_parameters,
        model_input=model_input,
        model_output=model_output,
        model_output_parameters=model_output_parameters,
        mlflow_version=mlflow_version,
        user_id=user_id,
        start_time=start_time,
    )
    reply = CreateRun.Response()
    reply.run.MergeFrom(run.to_proto())
    resp = Response(mimetype="application/json")
    resp.set_data(message_to_json(reply))
    return resp
2162  
2163  
@catch_mlflow_exception
def upload_artifact_handler():
    """Handler that stores a small (max 10MB) request body as a run artifact at the
    given path, staging it on local disk before handing it to the artifact repo."""
    args = request.args
    run_uuid = args.get("run_uuid")
    if not run_uuid:
        raise MlflowException(
            message="Request must specify run_uuid.",
            error_code=INVALID_PARAMETER_VALUE,
        )
    path = args.get("path")
    if not path:
        raise MlflowException(
            message="Request must specify path.",
            error_code=INVALID_PARAMETER_VALUE,
        )
    # Reject traversal attempts before the path touches any filesystem.
    path = validate_path_is_safe(path)

    if request.content_length and request.content_length > 10 * 1024 * 1024:
        raise MlflowException(
            message="Artifact size is too large. Max size is 10MB.",
            error_code=INVALID_PARAMETER_VALUE,
        )

    data = request.data
    if not data:
        raise MlflowException(
            message="Request must specify data.",
            error_code=INVALID_PARAMETER_VALUE,
        )

    run = _get_tracking_store().get_run(run_uuid)
    artifact_dir = run.info.artifact_uri
    basename = posixpath.basename(path)
    dirname = posixpath.dirname(path)

    def _push_to_repo(local_file):
        if _is_servable_proxied_run_artifact_root(run.info.artifact_uri):
            repo = _get_artifact_repo_mlflow_artifacts()
            # Use posixpath.join since these are logical artifact paths (not local filesystem
            # paths) that should always use forward slashes regardless of the platform.
            dest = posixpath.join(run.info.experiment_id, run.info.run_id, "artifacts")
            if dirname:
                dest = posixpath.join(dest, dirname)
            dest = _get_workspace_scoped_repo_path_if_enabled(dest)
        else:
            repo = get_artifact_repository(artifact_dir)
            dest = dirname
        repo.log_artifact(local_file, dest)

    with tempfile.TemporaryDirectory() as tmpdir:
        # Recreate the requested sub-directory layout inside the staging dir.
        target_dir = os.path.join(tmpdir, dirname) if dirname else tmpdir
        os.makedirs(target_dir, exist_ok=True)
        local_path = os.path.join(target_dir, basename)
        with open(local_path, "wb") as fh:
            fh.write(data)
        _push_to_repo(local_path)

    return Response(mimetype="application/json")
2229  
2230  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _search_experiments():
    """Handler for the SearchExperiments endpoint: filtered, paginated experiment search."""
    msg = _get_request_message(
        SearchExperiments(),
        schema={
            "view_type": [_assert_intlike],
            "max_results": [_assert_intlike],
            "order_by": [_assert_array],
            "filter": [_assert_string],
            "page_token": [_assert_string],
        },
    )
    experiments = _get_tracking_store().search_experiments(
        view_type=msg.view_type,
        max_results=msg.max_results,
        order_by=msg.order_by,
        filter_string=msg.filter,
        page_token=msg.page_token or None,
    )
    reply = SearchExperiments.Response()
    reply.experiments.extend(e.to_proto() for e in experiments)
    if experiments.token:
        reply.next_page_token = experiments.token
    resp = Response(mimetype="application/json")
    resp.set_data(message_to_json(reply))
    return resp
2259  
2260  
@catch_mlflow_exception
def _get_artifact_repo(run):
    """Return an artifact repository rooted at the given run's artifact URI."""
    artifact_uri = run.info.artifact_uri
    return get_artifact_repository(artifact_uri)
2264  
2265  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _log_batch():
    """Handler for the LogBatch endpoint: logs metrics, params, and tags on a run
    in one request."""

    def _check_metrics(metrics):
        # Each metric entry must carry key, value, and timestamp.
        for i, m in enumerate(metrics):
            _assert_required(m.get("key"), path=f"metrics[{i}].key")
            _assert_required(m.get("value"), path=f"metrics[{i}].value")
            _assert_required(m.get("timestamp"), path=f"metrics[{i}].timestamp")

    def _check_params(params):
        for i, p in enumerate(params):
            _assert_required(p.get("key"), path=f"params[{i}].key")

    def _check_tags(tags):
        for i, t in enumerate(tags):
            _assert_required(t.get("key"), path=f"tags[{i}].key")

    # Enforce batch-level size limits on the raw JSON before proto parsing.
    _validate_batch_log_api_req(_get_request_json())
    msg = _get_request_message(
        LogBatch(),
        schema={
            "run_id": [_assert_string, _assert_required],
            "metrics": [_assert_array, _check_metrics],
            "params": [_assert_array, _check_params],
            "tags": [_assert_array, _check_tags],
        },
    )
    _get_tracking_store().log_batch(
        run_id=msg.run_id,
        metrics=[Metric.from_proto(m) for m in msg.metrics],
        params=[Param.from_proto(p) for p in msg.params],
        tags=[RunTag.from_proto(t) for t in msg.tags],
    )
    resp = Response(mimetype="application/json")
    resp.set_data(message_to_json(LogBatch.Response()))
    return resp
2303  
2304  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _log_model():
    """Handle the (legacy) LogModel endpoint.

    Parses the supplied ``model_json``, verifies its mandatory fields, and
    records the model against the run in the tracking store. Raises an
    ``MlflowException`` (INVALID_PARAMETER_VALUE) for malformed or incomplete
    model JSON.
    """
    request_message = _get_request_message(
        LogModel(),
        schema={
            "run_id": [_assert_string, _assert_required],
            "model_json": [_assert_string, _assert_required],
        },
    )
    try:
        model_dict = json.loads(request_message.model_json)
    except Exception:
        raise MlflowException(
            f"Malformed model info. \n {request_message.model_json} \n is not a valid JSON.",
            error_code=INVALID_PARAMETER_VALUE,
        )

    required_fields = {"artifact_path", "flavors", "utc_time_created", "run_id"}
    missing_fields = required_fields - model_dict.keys()
    if missing_fields:
        raise MlflowException(
            f"Model json is missing mandatory fields: {missing_fields}",
            error_code=INVALID_PARAMETER_VALUE,
        )
    _get_tracking_store().record_logged_model(
        run_id=request_message.run_id, mlflow_model=Model.from_dict(model_dict)
    )
    return _wrap_response(LogModel.Response())
2337  
2338  
def _wrap_response(response_message):
    """Serialize a protobuf message into a JSON-typed Flask response."""
    return Response(
        response=message_to_json(response_message), mimetype="application/json"
    )
2343  
2344  
2345  # Model Registry APIs
2346  
2347  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _create_registered_model():
    """Handle the CreateRegisteredModel endpoint.

    Creates the registered model (or prompt) in the registry store and emits
    the corresponding webhook event: prompts get a PROMPT/CREATED event with
    internal bookkeeping tags stripped; regular models get a
    REGISTERED_MODEL/CREATED event with all tags.
    """
    request_message = _get_request_message(
        CreateRegisteredModel(),
        schema={
            "name": [_assert_string, _assert_required],
            "tags": [_assert_array],
            "description": [_assert_string],
        },
    )
    store = _get_model_registry_store()
    registered_model = store.create_registered_model(
        name=request_message.name,
        tags=request_message.tags,
        description=request_message.description,
    )
    response_message = CreateRegisteredModel.Response(registered_model=registered_model.to_proto())

    if _is_prompt_request(request_message):
        # Prompt creation: hide the internal prompt-identification tags.
        internal_keys = {IS_PROMPT_TAG_KEY, PROMPT_TYPE_TAG_KEY}
        event = WebhookEvent(WebhookEntity.PROMPT, WebhookAction.CREATED)
        payload = PromptCreatedPayload(
            name=request_message.name,
            tags={
                t.key: t.value
                for t in request_message.tags
                if t.key not in internal_keys
            },
            description=request_message.description,
        )
    else:
        # Regular registered-model creation.
        event = WebhookEvent(WebhookEntity.REGISTERED_MODEL, WebhookAction.CREATED)
        payload = RegisteredModelCreatedPayload(
            name=request_message.name,
            tags={t.key: t.value for t in request_message.tags},
            description=request_message.description,
        )
    deliver_webhook(event=event, payload=payload, store=store)

    return _wrap_response(response_message)
2396  
2397  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _get_registered_model():
    """Fetch a registered model by name and return it as a JSON response."""
    request_message = _get_request_message(
        GetRegisteredModel(), schema={"name": [_assert_string, _assert_required]}
    )
    model = _get_model_registry_store().get_registered_model(name=request_message.name)
    return _wrap_response(GetRegisteredModel.Response(registered_model=model.to_proto()))
2407  
2408  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _update_registered_model():
    """Update a registered model's description and return the updated model."""
    request_message = _get_request_message(
        UpdateRegisteredModel(),
        schema={
            "name": [_assert_string, _assert_required],
            "description": [_assert_string],
        },
    )
    updated_model = _get_model_registry_store().update_registered_model(
        name=request_message.name, description=request_message.description
    )
    return _wrap_response(
        UpdateRegisteredModel.Response(registered_model=updated_model.to_proto())
    )
2426  
2427  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _rename_registered_model():
    """Rename a registered model and return the updated model."""
    request_message = _get_request_message(
        RenameRegisteredModel(),
        schema={
            "name": [_assert_string, _assert_required],
            "new_name": [_assert_string, _assert_required],
        },
    )
    renamed_model = _get_model_registry_store().rename_registered_model(
        name=request_message.name, new_name=request_message.new_name
    )
    return _wrap_response(
        RenameRegisteredModel.Response(registered_model=renamed_model.to_proto())
    )
2445  
2446  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _delete_registered_model():
    """Delete a registered model by name."""
    request_message = _get_request_message(
        DeleteRegisteredModel(), schema={"name": [_assert_string, _assert_required]}
    )
    _get_model_registry_store().delete_registered_model(name=request_message.name)
    return _wrap_response(DeleteRegisteredModel.Response())
2455  
2456  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _search_registered_models():
    """Search registered models with optional filter, ordering, and paging.

    ``max_results`` is capped at 1000 per request.
    """
    request_message = _get_request_message(
        SearchRegisteredModels(),
        schema={
            "filter": [_assert_string],
            "max_results": [
                _assert_intlike,
                lambda x: _assert_less_than_or_equal(int(x), 1000),
            ],
            "order_by": [_assert_array, _assert_item_type_string],
            "page_token": [_assert_string],
        },
    )
    result_page = _get_model_registry_store().search_registered_models(
        filter_string=request_message.filter,
        max_results=request_message.max_results,
        order_by=request_message.order_by,
        page_token=request_message.page_token or None,
    )
    response_message = SearchRegisteredModels.Response()
    response_message.registered_models.extend(rm.to_proto() for rm in result_page)
    if result_page.token:
        response_message.next_page_token = result_page.token
    return _wrap_response(response_message)
2484  
2485  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _get_latest_versions():
    """Return the latest model version for each of the requested stages."""
    request_message = _get_request_message(
        GetLatestVersions(),
        schema={
            "name": [_assert_string, _assert_required],
            "stages": [_assert_array, _assert_item_type_string],
        },
    )
    versions = _get_model_registry_store().get_latest_versions(
        name=request_message.name, stages=request_message.stages
    )
    response_message = GetLatestVersions.Response()
    response_message.model_versions.extend(mv.to_proto() for mv in versions)
    return _wrap_response(response_message)
2502  
2503  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _set_registered_model_tag():
    """Set a tag on a registered model.

    When the target model is a prompt, a PROMPT_TAG/SET webhook event is
    emitted; regular models emit no webhook for this operation.
    """
    request_message = _get_request_message(
        SetRegisteredModelTag(),
        schema={
            "name": [_assert_string, _assert_required],
            "key": [_assert_string, _assert_required],
            "value": [_assert_string],
        },
    )
    store = _get_model_registry_store()
    store.set_registered_model_tag(
        name=request_message.name,
        tag=RegisteredModelTag(key=request_message.key, value=request_message.value),
    )

    if _is_prompt(request_message.name):
        # Prompts get a dedicated webhook event for tag changes.
        deliver_webhook(
            event=WebhookEvent(WebhookEntity.PROMPT_TAG, WebhookAction.SET),
            payload=PromptTagSetPayload(
                name=request_message.name,
                key=request_message.key,
                value=request_message.value,
            ),
            store=store,
        )

    return _wrap_response(SetRegisteredModelTag.Response())
2532  
2533  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _delete_registered_model_tag():
    """Delete a tag from a registered model.

    When the target model is a prompt, a PROMPT_TAG/DELETED webhook event is
    emitted; regular models emit no webhook for this operation.
    """
    request_message = _get_request_message(
        DeleteRegisteredModelTag(),
        schema={
            "name": [_assert_string, _assert_required],
            "key": [_assert_string, _assert_required],
        },
    )
    store = _get_model_registry_store()
    store.delete_registered_model_tag(name=request_message.name, key=request_message.key)

    if _is_prompt(request_message.name):
        # Prompts get a dedicated webhook event for tag deletion.
        deliver_webhook(
            event=WebhookEvent(WebhookEntity.PROMPT_TAG, WebhookAction.DELETED),
            payload=PromptTagDeletedPayload(
                name=request_message.name,
                key=request_message.key,
            ),
            store=store,
        )

    return _wrap_response(DeleteRegisteredModelTag.Response())
2559  
2560  
def _validate_non_local_source_contains_relative_paths(source: str):
    """
    Validation check to ensure that sources that are provided that conform to the schemes:
    http, https, or mlflow-artifacts do not contain relative path designations that are intended
    to access local file system paths on the tracking server.

    Raises an ``MlflowException`` (INVALID_PARAMETER_VALUE) when ``source``
    contains path-traversal sequences or embedded NUL bytes; returns ``None``
    when the source is acceptable.

    Example paths that this validation function is intended to find and raise an Exception if
    passed:
    "mlflow-artifacts://host:port/../../../../"
    "http://host:port/api/2.0/mlflow-artifacts/artifacts/../../../../"
    "https://host:port/api/2.0/mlflow-artifacts/artifacts/../../../../"
    "/models/artifacts/../../../"
    "s3:/my_bucket/models/path/../../other/path"
    "file://path/to/../../../../some/where/you/should/not/be"
    "mlflow-artifacts://host:port/..%2f..%2f..%2f..%2f"
    "http://host:port/api/2.0/mlflow-artifacts/artifacts%00"
    """
    invalid_source_error_message = (
        f"Invalid model version source: '{source}'. If supplying a source as an http, https, "
        "local file path, ftp, objectstore, or mlflow-artifacts uri, an absolute path must be "
        "provided without relative path references present. "
        "Please provide an absolute path."
    )

    # Repeatedly percent-decode until a fixed point so that nested encodings
    # (e.g. "..%252f" -> "..%2f" -> "../") cannot hide traversal sequences.
    while (unquoted := urllib.parse.unquote_plus(source)) != source:
        source = unquoted
    # Normalize the URI's path component: collapse repeated slashes and drop
    # trailing ones so the resolve() comparison below is meaningful.
    source_path = re.sub(r"/+", "/", urllib.parse.urlparse(source).path.rstrip("/"))
    # Reject embedded NUL bytes and any literal ".." segment anywhere in the
    # full (decoded) source, not just in its path component.
    if "\x00" in source_path or any(p == ".." for p in source.split("/")):
        raise MlflowException(invalid_source_error_message, INVALID_PARAMETER_VALUE)
    # If resolving the path changes it, it contained relative references.
    resolved_source = pathlib.Path(source_path).resolve().as_posix()
    # NB: drive split is specifically for Windows since WindowsPath.resolve() will append the
    # drive path of the pwd to a given path. We don't care about the drive here, though.
    _, resolved_path = os.path.splitdrive(resolved_source)

    if resolved_path != source_path:
        raise MlflowException(invalid_source_error_message, INVALID_PARAMETER_VALUE)
2597  
2598  
def _validate_source_run(source: str, run_id: str) -> None:
    """Validate a model version ``source`` that may reference local storage.

    A local filesystem source is only permitted when ``run_id`` is provided and
    the source resolves to a path inside (or equal to) that run's artifact
    directory; otherwise an ``MlflowException`` (INVALID_PARAMETER_VALUE) is
    raised. Non-local sources are checked for path-traversal sequences instead.
    """
    if is_local_uri(source):
        if run_id:
            store = _get_tracking_store()
            run = store.get_run(run_id)
            source = pathlib.Path(local_file_uri_to_path(source)).resolve()
            if is_local_uri(run.info.artifact_uri):
                run_artifact_dir = pathlib.Path(
                    local_file_uri_to_path(run.info.artifact_uri)
                ).resolve()
                # Accept only when the source is the artifact dir itself or a
                # descendant of it.
                if run_artifact_dir in [source, *source.parents]:
                    return

        raise MlflowException(
            f"Invalid model version source: '{source}'. To use a local path as a model version "
            "source, the run_id request parameter has to be specified and the local path has to be "
            "contained within the artifact directory of the run specified by the run_id.",
            INVALID_PARAMETER_VALUE,
        )

    # Checks if relative paths are present in the source (a security threat). If any are present,
    # raises an Exception.
    _validate_non_local_source_contains_relative_paths(source)
2622  
2623  
def _validate_source_model(source: str, model_id: str) -> None:
    """Validate a model version ``source`` that may reference local storage.

    A local filesystem source is only permitted when ``model_id`` is provided
    and the source resolves to a path inside (or equal to) that logged model's
    artifact directory; otherwise an ``MlflowException``
    (INVALID_PARAMETER_VALUE) is raised. Non-local sources are checked for
    path-traversal sequences instead.
    """
    if is_local_uri(source):
        if model_id:
            store = _get_tracking_store()
            model = store.get_logged_model(model_id)
            source = pathlib.Path(local_file_uri_to_path(source)).resolve()
            if is_local_uri(model.artifact_location):
                model_artifact_dir = pathlib.Path(
                    local_file_uri_to_path(model.artifact_location)
                ).resolve()
                # Accept only when the source is the artifact dir itself or a
                # descendant of it.
                if model_artifact_dir in [source, *source.parents]:
                    return

        # NB: message previously said "the run specified by the model_id"
        # (copy-paste from _validate_source_run); the check is against the
        # logged model's artifact directory.
        raise MlflowException(
            f"Invalid model version source: '{source}'. To use a local path as a model version "
            "source, the model_id request parameter has to be specified and the local path has to "
            "be contained within the artifact directory of the model specified by the model_id.",
            INVALID_PARAMETER_VALUE,
        )

    # Checks if relative paths are present in the source (a security threat). If any are present,
    # raises an Exception.
    _validate_non_local_source_contains_relative_paths(source)
2647  
2648  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _create_model_version():
    """Handle the CreateModelVersion endpoint for both models and prompts.

    Validates the source URI (optionally against a server-configured regex and
    the local-path/traversal rules), creates the model version in the registry
    store, links it to a logged model when ``model_id`` is supplied, and emits
    a PROMPT_VERSION/CREATED or MODEL_VERSION/CREATED webhook event.
    """
    request_message = _get_request_message(
        CreateModelVersion(),
        schema={
            "name": [_assert_string, _assert_required],
            "source": [_assert_string, _assert_required],
            "run_id": [_assert_string],
            "tags": [_assert_array],
            "run_link": [_assert_string],
            "description": [_assert_string],
            "model_id": [_assert_string],
        },
    )

    # Optional admin-configured allowlist regex for version sources.
    if request_message.source and (
        regex := MLFLOW_CREATE_MODEL_VERSION_SOURCE_VALIDATION_REGEX.get()
    ):
        if not re.search(regex, request_message.source):
            raise MlflowException(
                f"Invalid model version source: '{request_message.source}'.",
                error_code=INVALID_PARAMETER_VALUE,
            )

    is_prompt = _is_prompt_request(request_message)
    if is_prompt:
        # Prompt sources must not point to local filesystem paths.
        # Block file:// URIs and absolute paths (e.g. /etc/passwd) but allow
        # the legitimate schemeless placeholder sources used internally
        # (e.g. "prompt-template", "dummy-source").
        source = request_message.source
        parsed = urllib.parse.urlparse(source)
        if parsed.scheme == "file" or (parsed.scheme == "" and source.startswith("/")):
            raise MlflowException(
                f"Invalid prompt source: '{source}'. "
                "Local source paths are not allowed for prompts.",
                INVALID_PARAMETER_VALUE,
            )
        # Only validate traversal for sources with a URL scheme (http, https, etc.)
        if parsed.scheme:
            _validate_non_local_source_contains_relative_paths(source)
    else:
        # Non-prompt versions: validate the source against the logged model's
        # artifact dir when model_id is given, else against the run's.
        if request_message.model_id:
            _validate_source_model(request_message.source, request_message.model_id)
        else:
            _validate_source_run(request_message.source, request_message.run_id)

    store = _get_model_registry_store()
    model_version = store.create_model_version(
        name=request_message.name,
        source=request_message.source,
        run_id=request_message.run_id,
        run_link=request_message.run_link,
        tags=request_message.tags,
        description=request_message.description,
        model_id=request_message.model_id,
    )
    # Mirror the registration onto the logged model in the tracking store so
    # the model shows which registry versions were created from it.
    if not is_prompt and request_message.model_id:
        tracking_store = _get_tracking_store()
        tracking_store.set_model_versions_tags(
            name=request_message.name,
            version=model_version.version,
            model_id=request_message.model_id,
        )
    response_message = CreateModelVersion.Response(model_version=model_version.to_proto())

    if is_prompt:
        # Convert tags to dict and extract template text efficiently
        tags_dict = {t.key: t.value for t in request_message.tags}
        template_text = tags_dict.pop(PROMPT_TEXT_TAG_KEY, None)
        # Remove internal prompt identification and type tags
        tags_dict.pop(IS_PROMPT_TAG_KEY, None)
        tags_dict.pop(PROMPT_TYPE_TAG_KEY, None)

        # Send prompt version creation webhook
        deliver_webhook(
            event=WebhookEvent(WebhookEntity.PROMPT_VERSION, WebhookAction.CREATED),
            payload=PromptVersionCreatedPayload(
                name=request_message.name,
                version=str(model_version.version),
                template=template_text,
                tags=tags_dict,
                description=request_message.description or None,
            ),
            store=store,
        )
    else:
        # Send regular model version creation webhook
        deliver_webhook(
            event=WebhookEvent(WebhookEntity.MODEL_VERSION, WebhookAction.CREATED),
            payload=ModelVersionCreatedPayload(
                name=request_message.name,
                version=str(model_version.version),
                source=request_message.source,
                run_id=request_message.run_id or None,
                tags={t.key: t.value for t in request_message.tags},
                description=request_message.description or None,
            ),
            store=store,
        )

    return _wrap_response(response_message)
2752  
2753  
def _is_prompt_request(request_message):
    """Return True when the request's tags mark the entity as a prompt."""
    return IS_PROMPT_TAG_KEY in (tag.key for tag in request_message.tags)
2756  
2757  
def _is_prompt(name: str) -> bool:
    """Return True when the registered model with this name is a prompt."""
    registered_model = _get_model_registry_store().get_registered_model(name=name)
    return registered_model._is_prompt()
2761  
2762  
@catch_mlflow_exception
@_disable_if_artifacts_only
def get_model_version_artifact_handler():
    """Serve a single artifact file belonging to a model version.

    Query parameters: ``name`` and ``version`` identify the model version;
    ``path`` (required) is the artifact path relative to the version's
    artifact root. The path is sanitized before use to prevent path-traversal.
    """
    name = request.args.get("name")
    version = request.args.get("version")
    path = request.args["path"]
    path = validate_path_is_safe(path)
    artifact_uri = _get_model_registry_store().get_model_version_download_uri(name, version)
    if _is_servable_proxied_run_artifact_root(artifact_uri):
        # The artifact root is proxied through this tracking server: resolve
        # the requested path against the server's own artifact repository
        # rather than treating the URI as a remote location.
        artifact_repo = _get_artifact_repo_mlflow_artifacts()
        artifact_path = _get_proxied_run_artifact_destination_path(
            proxied_artifact_root=artifact_uri,
            relative_path=path,
        )
        artifact_path = _get_workspace_scoped_repo_path_if_enabled(artifact_path)
    else:
        artifact_repo = get_artifact_repository(artifact_uri)
        artifact_path = path

    return _send_artifact(artifact_repo, artifact_path)
2783  
2784  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _get_model_version():
    """Fetch a single model version by name and version number."""
    request_message = _get_request_message(
        GetModelVersion(),
        schema={
            "name": [_assert_string, _assert_required],
            "version": [_assert_string, _assert_required],
        },
    )
    model_version = _get_model_registry_store().get_model_version(
        name=request_message.name, version=request_message.version
    )
    return _wrap_response(GetModelVersion.Response(model_version=model_version.to_proto()))
2801  
2802  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _update_model_version():
    """Update a model version's description.

    ``HasField`` distinguishes an omitted description (left unchanged via
    ``None``) from one explicitly set to the empty string.
    """
    request_message = _get_request_message(
        UpdateModelVersion(),
        schema={
            "name": [_assert_string, _assert_required],
            "version": [_assert_string, _assert_required],
            "description": [_assert_string],
        },
    )
    description = (
        request_message.description if request_message.HasField("description") else None
    )
    model_version = _get_model_registry_store().update_model_version(
        name=request_message.name,
        version=request_message.version,
        description=description,
    )
    return _wrap_response(UpdateModelVersion.Response(model_version=model_version.to_proto()))
2823  
2824  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _transition_stage():
    """Transition a model version to a new stage, optionally archiving others."""
    request_message = _get_request_message(
        TransitionModelVersionStage(),
        schema={
            "name": [_assert_string, _assert_required],
            "version": [_assert_string, _assert_required],
            "stage": [_assert_string, _assert_required],
            "archive_existing_versions": [_assert_bool],
        },
    )
    updated_version = _get_model_registry_store().transition_model_version_stage(
        name=request_message.name,
        version=request_message.version,
        stage=request_message.stage,
        archive_existing_versions=request_message.archive_existing_versions,
    )
    return _wrap_response(
        TransitionModelVersionStage.Response(model_version=updated_version.to_proto())
    )
2846  
2847  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _delete_model_version():
    """Delete a single model version identified by name and version."""
    request_message = _get_request_message(
        DeleteModelVersion(),
        schema={
            "name": [_assert_string, _assert_required],
            "version": [_assert_string, _assert_required],
        },
    )
    _get_model_registry_store().delete_model_version(
        name=request_message.name, version=request_message.version
    )
    return _wrap_response(DeleteModelVersion.Response())
2862  
2863  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _get_model_version_download_uri():
    """Return the artifact download URI for a model version.

    NB: previously this handler passed no validation schema to
    ``_get_request_message``; ``name`` and ``version`` are now validated as
    required strings, consistent with the sibling model-version handlers
    (e.g. ``_get_model_version``), so malformed requests fail fast with
    INVALID_PARAMETER_VALUE instead of surfacing a store-level error.
    """
    request_message = _get_request_message(
        GetModelVersionDownloadUri(),
        schema={
            "name": [_assert_string, _assert_required],
            "version": [_assert_string, _assert_required],
        },
    )
    download_uri = _get_model_registry_store().get_model_version_download_uri(
        name=request_message.name, version=request_message.version
    )
    return _wrap_response(GetModelVersionDownloadUri.Response(artifact_uri=download_uri))
2873  
2874  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _search_model_versions():
    """Search model versions: request validation here, store query in the impl.

    ``max_results`` is capped at 200,000 per request.
    """
    request_message = _get_request_message(
        SearchModelVersions(),
        schema={
            "filter": [_assert_string],
            "max_results": [
                _assert_intlike,
                lambda x: _assert_less_than_or_equal(int(x), 200_000),
            ],
            "order_by": [_assert_array, _assert_item_type_string],
            "page_token": [_assert_string],
        },
    )
    return _wrap_response(search_model_versions_impl(request_message))
2892  
2893  
def search_model_versions_impl(request_message):
    """Run the model-version search against the registry store.

    Returns a ``SearchModelVersions.Response`` proto containing the matching
    versions and, when more results exist, a next-page token.
    """
    result_page = _get_model_registry_store().search_model_versions(
        filter_string=request_message.filter,
        max_results=request_message.max_results,
        order_by=request_message.order_by,
        page_token=request_message.page_token or None,
    )
    response_message = SearchModelVersions.Response()
    response_message.model_versions.extend(mv.to_proto() for mv in result_page)
    if result_page.token:
        response_message.next_page_token = result_page.token
    return response_message
2907  
2908  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _set_model_version_tag():
    """Set a tag on a model version and emit the matching webhook event.

    Prompt versions fire PROMPT_VERSION_TAG/SET; regular model versions fire
    MODEL_VERSION_TAG/SET.
    """
    request_message = _get_request_message(
        SetModelVersionTag(),
        schema={
            "name": [_assert_string, _assert_required],
            "version": [_assert_string, _assert_required],
            "key": [_assert_string, _assert_required],
            "value": [_assert_string],
        },
    )
    store = _get_model_registry_store()
    store.set_model_version_tag(
        name=request_message.name,
        version=request_message.version,
        tag=ModelVersionTag(key=request_message.key, value=request_message.value),
    )

    if _is_prompt(request_message.name):
        event = WebhookEvent(WebhookEntity.PROMPT_VERSION_TAG, WebhookAction.SET)
        payload = PromptVersionTagSetPayload(
            name=request_message.name,
            version=request_message.version,
            key=request_message.key,
            value=request_message.value,
        )
    else:
        event = WebhookEvent(WebhookEntity.MODEL_VERSION_TAG, WebhookAction.SET)
        payload = ModelVersionTagSetPayload(
            name=request_message.name,
            version=request_message.version,
            key=request_message.key,
            value=request_message.value,
        )
    deliver_webhook(event=event, payload=payload, store=store)

    return _wrap_response(SetModelVersionTag.Response())
2951  
2952  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _delete_model_version_tag():
    """Delete a tag from a model version and emit the matching webhook event.

    Prompt versions fire PROMPT_VERSION_TAG/DELETED; regular model versions
    fire MODEL_VERSION_TAG/DELETED.
    """
    request_message = _get_request_message(
        DeleteModelVersionTag(),
        schema={
            "name": [_assert_string, _assert_required],
            "version": [_assert_string, _assert_required],
            "key": [_assert_string, _assert_required],
        },
    )
    store = _get_model_registry_store()
    store.delete_model_version_tag(
        name=request_message.name,
        version=request_message.version,
        key=request_message.key,
    )

    if _is_prompt(request_message.name):
        event = WebhookEvent(WebhookEntity.PROMPT_VERSION_TAG, WebhookAction.DELETED)
        payload = PromptVersionTagDeletedPayload(
            name=request_message.name,
            version=request_message.version,
            key=request_message.key,
        )
    else:
        event = WebhookEvent(WebhookEntity.MODEL_VERSION_TAG, WebhookAction.DELETED)
        payload = ModelVersionTagDeletedPayload(
            name=request_message.name,
            version=request_message.version,
            key=request_message.key,
        )
    deliver_webhook(event=event, payload=payload, store=store)

    return _wrap_response(DeleteModelVersionTag.Response())
2995  
2996  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _set_registered_model_alias():
    """Point an alias at a model version and emit the matching webhook event.

    Prompts fire PROMPT_ALIAS/CREATED; regular models fire
    MODEL_VERSION_ALIAS/CREATED.
    """
    request_message = _get_request_message(
        SetRegisteredModelAlias(),
        schema={
            "name": [_assert_string, _assert_required],
            "alias": [_assert_string, _assert_required],
            "version": [_assert_string, _assert_required],
        },
    )
    store = _get_model_registry_store()
    store.set_registered_model_alias(
        name=request_message.name,
        alias=request_message.alias,
        version=request_message.version,
    )

    if _is_prompt(request_message.name):
        event = WebhookEvent(WebhookEntity.PROMPT_ALIAS, WebhookAction.CREATED)
        payload = PromptAliasCreatedPayload(
            name=request_message.name,
            alias=request_message.alias,
            version=request_message.version,
        )
    else:
        event = WebhookEvent(WebhookEntity.MODEL_VERSION_ALIAS, WebhookAction.CREATED)
        payload = ModelVersionAliasCreatedPayload(
            name=request_message.name,
            alias=request_message.alias,
            version=request_message.version,
        )
    deliver_webhook(event=event, payload=payload, store=store)

    return _wrap_response(SetRegisteredModelAlias.Response())
3039  
3040  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _delete_registered_model_alias():
    """
    Handler for ``DeleteRegisteredModelAlias``: removes an alias from a registered
    model and emits the matching "alias deleted" webhook.
    """
    request_message = _get_request_message(
        DeleteRegisteredModelAlias(),
        schema={
            "name": [_assert_string, _assert_required],
            "alias": [_assert_string, _assert_required],
        },
    )
    name = request_message.name
    alias = request_message.alias

    registry_store = _get_model_registry_store()
    registry_store.delete_registered_model_alias(name=name, alias=alias)

    # Prompts and regular model versions use distinct webhook entities/payloads.
    if _is_prompt(name):
        entity, payload_cls = WebhookEntity.PROMPT_ALIAS, PromptAliasDeletedPayload
    else:
        entity, payload_cls = WebhookEntity.MODEL_VERSION_ALIAS, ModelVersionAliasDeletedPayload
    deliver_webhook(
        event=WebhookEvent(entity, WebhookAction.DELETED),
        payload=payload_cls(name=name, alias=alias),
        store=registry_store,
    )

    return _wrap_response(DeleteRegisteredModelAlias.Response())
3076  
3077  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _get_model_version_by_alias():
    """Handler for ``GetModelVersionByAlias``: resolves an alias to its model version."""
    request_message = _get_request_message(
        GetModelVersionByAlias(),
        schema={
            "name": [_assert_string, _assert_required],
            "alias": [_assert_string, _assert_required],
        },
    )
    version = _get_model_registry_store().get_model_version_by_alias(
        name=request_message.name, alias=request_message.alias
    )
    return _wrap_response(GetModelVersionByAlias.Response(model_version=version.to_proto()))
3094  
3095  
3096  # Webhook APIs
@catch_mlflow_exception
@_disable_if_artifacts_only
def _create_webhook():
    """Handler for ``CreateWebhook``: registers a new webhook in the registry store."""
    request_message = _get_request_message(
        CreateWebhook(),
        schema={
            "name": [_assert_string, _assert_required],
            "url": [_assert_string, _assert_required],
            "events": [_assert_array, _assert_required],
            "description": [_assert_string],
            "secret": [_assert_string],
            "status": [_assert_string],
        },
    )

    events = [WebhookEvent.from_proto(e) for e in request_message.events]
    # Empty proto strings are normalized to None so the store treats them as unset.
    status = WebhookStatus.from_proto(request_message.status) if request_message.status else None
    webhook = _get_model_registry_store().create_webhook(
        name=request_message.name,
        url=request_message.url,
        events=events,
        description=request_message.description or None,
        secret=request_message.secret or None,
        status=status,
    )
    return _wrap_response(CreateWebhook.Response(webhook=webhook.to_proto()))
3122  
3123  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _list_webhooks():
    """Handler for ``ListWebhooks``: returns one page of registered webhooks."""
    request_message = _get_request_message(
        ListWebhooks(),
        schema={
            "max_results": [_assert_intlike],
            "page_token": [_assert_string],
        },
    )
    page = _get_model_registry_store().list_webhooks(
        max_results=request_message.max_results,
        page_token=request_message.page_token or None,
    )
    return _wrap_response(
        ListWebhooks.Response(
            webhooks=[hook.to_proto() for hook in page],
            next_page_token=page.token,
        )
    )
3143  
3144  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _get_webhook(webhook_id: str):
    """Handler for ``GetWebhook``: fetches a single webhook by its id."""
    store = _get_model_registry_store()
    webhook = store.get_webhook(webhook_id=webhook_id)
    return _wrap_response(GetWebhook.Response(webhook=webhook.to_proto()))
3151  
3152  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _update_webhook(webhook_id: str):
    """Handler for ``UpdateWebhook``: applies a partial update to an existing webhook."""
    request_message = _get_request_message(
        UpdateWebhook(),
        schema={
            "name": [_assert_string],
            "description": [_assert_string],
            "url": [_assert_string],
            "events": [_assert_array],
            "secret": [_assert_string],
            "status": [_assert_string],
        },
    )

    # Unset proto fields come through as empty values; map them to None so the
    # store leaves the corresponding webhook attributes untouched.
    events = None
    if request_message.events:
        events = [WebhookEvent.from_proto(e) for e in request_message.events]
    status = None
    if request_message.status:
        status = WebhookStatus.from_proto(request_message.status)

    webhook = _get_model_registry_store().update_webhook(
        webhook_id=webhook_id,
        name=request_message.name or None,
        description=request_message.description or None,
        url=request_message.url or None,
        events=events,
        secret=request_message.secret or None,
        status=status,
    )
    return _wrap_response(UpdateWebhook.Response(webhook=webhook.to_proto()))
3182  
3183  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _delete_webhook(webhook_id: str):
    """Handler for ``DeleteWebhook``: removes a webhook by its id."""
    _get_model_registry_store().delete_webhook(webhook_id=webhook_id)
    return _wrap_response(DeleteWebhook.Response())
3190  
3191  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _test_webhook(webhook_id: str):
    """Handler for ``TestWebhook``: performs a test delivery for the given webhook."""
    request_message = _get_request_message(TestWebhook())
    # The event to simulate is optional in the request.
    event = None
    if request_message.HasField("event"):
        event = WebhookEvent.from_proto(request_message.event)
    webhook = _get_model_registry_store().get_webhook(webhook_id=webhook_id)
    result = test_webhook(webhook=webhook, event=event)
    return _wrap_response(TestWebhook.Response(result=result.to_proto()))
3206  
3207  
3208  # MLflow Artifacts APIs
3209  
3210  
def _get_workspace_scoped_repo_path_if_enabled(artifact_path: str | None) -> str | None:
    """
    Normalize artifact paths for proxied (served) artifacts so they remain workspace-isolated.

    When ``mlflow-artifacts`` proxying is enabled and workspaces are on, every path under the HTTP
    artifact endpoint must be rooted at ``workspaces/<workspace>/...``. Direct artifact repositories
    (e.g., S3, GCS, local URIs) already encode their own isolation, so they bypass this logic by
    calling the underlying store directly. Only the proxied repos need to be rewritten/validated
    here.

    Returns:
        The workspace-scoped path. May return the original path in the following cases:
        - Workspaces are disabled (returns ``artifact_path`` unchanged).
        - Default workspace with no path (returns ``artifact_path`` unchanged to preserve legacy
          root behavior, where artifacts live at the root rather than under ``workspaces/default``).
        For non-default workspaces, always returns a string (``workspaces/<workspace>/...``).

    Raises:
        MlflowException: if no workspace is active on the request, if a 'workspaces/'-prefixed
            path omits the workspace name, or if it names a different workspace than the request.
    """
    if not MLFLOW_ENABLE_WORKSPACES.get():
        return artifact_path

    workspace = workspace_context.get_request_workspace()
    if not workspace:
        raise MlflowException.invalid_parameter_value(
            "Active workspace is required for artifact operations. "
            "Ensure X-MLFLOW-WORKSPACE is set or call mlflow.set_workspace()."
        )

    # Treat the incoming path as relative: strip any leading slashes before prefixing.
    normalized = artifact_path.lstrip("/") if artifact_path else ""
    base = posixpath.join("workspaces", workspace)

    if not normalized:
        # For the default workspace, preserve the legacy root behavior (no prefix),
        # so root operations continue to see the existing layout.
        return base if workspace != DEFAULT_WORKSPACE_NAME else artifact_path

    if workspace == DEFAULT_WORKSPACE_NAME and not normalized.startswith("workspaces/"):
        # Legacy default-workspace artifacts never had the workspace prefix; allow them to be served
        # without rewriting as long as the path isn't trying to opt into the reserved namespace.
        return artifact_path

    # The path explicitly targets the reserved 'workspaces/...' namespace: verify the
    # named workspace matches the request's workspace before passing it through as-is.
    leading_segments = normalized.split("/", 2)
    if leading_segments and leading_segments[0] == "workspaces":
        if len(leading_segments) == 1 or not leading_segments[1]:
            raise MlflowException.invalid_parameter_value(
                "Artifact paths prefixed with 'workspaces/' must include a workspace name."
            )
        requested_workspace = leading_segments[1]
        if requested_workspace != workspace:
            raise MlflowException.invalid_parameter_value(
                f"Artifact path targets workspace '{requested_workspace}' "
                f"but the workspace specified in the request is '{workspace}'."
            )
        return normalized

    # Plain relative path: scope it under the active workspace.
    return posixpath.join(base, normalized)
3266  
3267  
@catch_mlflow_exception
@_disable_unless_serve_artifacts
def _download_artifact(artifact_path):
    """
    A request handler for `GET /mlflow-artifacts/artifacts/<artifact_path>` to download an artifact
    from `artifact_path` (a relative path from the root artifact directory).
    """
    artifact_path = validate_path_is_safe(artifact_path)
    artifact_path = _get_workspace_scoped_repo_path_if_enabled(artifact_path)
    # NB: deliberately not a `with` block -- the temporary directory must outlive this
    # function, since the generator below streams the file after this handler returns.
    tmp_dir = tempfile.TemporaryDirectory()
    artifact_repo = _get_artifact_repo_mlflow_artifacts()
    dst = artifact_repo.download_artifacts(artifact_path, tmp_dir.name)

    # Ref: https://stackoverflow.com/a/24613980/6943581
    file_handle = open(dst, "rb")  # noqa: SIM115

    def stream_and_remove_file():
        # Yield the downloaded file in fixed-size chunks; release the handle and the
        # temporary directory only after the final chunk has been produced.
        while chunk := file_handle.read(ARTIFACT_STREAM_CHUNK_SIZE):
            yield chunk
        file_handle.close()
        tmp_dir.cleanup()

    file_sender_response = current_app.response_class(stream_and_remove_file())

    return _response_with_file_attachment_headers(artifact_path, file_sender_response)
3293  
3294  
@catch_mlflow_exception
@_disable_unless_serve_artifacts
def _upload_artifact(artifact_path):
    """
    A request handler for `PUT /mlflow-artifacts/artifacts/<artifact_path>` to upload an artifact
    to `artifact_path` (a relative path from the root artifact directory).
    """
    artifact_path = validate_path_is_safe(artifact_path)
    artifact_path = _get_workspace_scoped_repo_path_if_enabled(artifact_path)
    parent_dir, filename = posixpath.split(artifact_path)
    with tempfile.TemporaryDirectory() as staging_dir:
        staged_file = os.path.join(staging_dir, filename)
        # Spool the request body to disk in chunks to avoid buffering it in memory.
        with open(staged_file, "wb") as out:
            for chunk in iter(lambda: request.stream.read(ARTIFACT_STREAM_CHUNK_SIZE), b""):
                out.write(chunk)

        repo = _get_artifact_repo_mlflow_artifacts()
        repo.log_artifact(staged_file, artifact_path=parent_dir or None)

    return _wrap_response(UploadArtifact.Response())
3315  
3316  
@catch_mlflow_exception
@_disable_unless_serve_artifacts
def _list_artifacts_mlflow_artifacts():
    """
    A request handler for `GET /mlflow-artifacts/artifacts?path=<value>` to list artifacts in `path`
    (a relative path from the root artifact directory).
    """
    request_message = _get_request_message(ListArtifactsMlflowArtifacts())
    path = None
    if request_message.HasField("path"):
        path = validate_path_is_safe(request_message.path)
    path = _get_workspace_scoped_repo_path_if_enabled(path)
    repo = _get_artifact_repo_mlflow_artifacts()
    # Only the basename of each entry is reported back to the client.
    files = [
        FileInfo(posixpath.basename(info.path), info.is_dir, info.file_size).to_proto()
        for info in repo.list_artifacts(path)
    ]
    response_message = ListArtifacts.Response()
    response_message.files.extend(files)
    response = Response(mimetype="application/json")
    response.set_data(message_to_json(response_message))
    return response
3338  
3339  
@catch_mlflow_exception
@_disable_unless_serve_artifacts
def _delete_artifact_mlflow_artifacts(artifact_path):
    """
    A request handler for `DELETE /mlflow-artifacts/artifacts?path=<value>` to delete artifacts in
    `path` (a relative path from the root artifact directory).
    """
    artifact_path = validate_path_is_safe(artifact_path)
    artifact_path = _get_workspace_scoped_repo_path_if_enabled(artifact_path)
    # The request body carries no fields we need; parse it for validation only.
    _get_request_message(DeleteArtifact())
    _get_artifact_repo_mlflow_artifacts().delete_artifacts(artifact_path)
    response = Response(mimetype="application/json")
    response.set_data(message_to_json(DeleteArtifact.Response()))
    return response
3356  
3357  
3358  def _get_graphql_auth_middleware():
3359      """
3360      Get GraphQL authorization middleware if basic-auth is enabled.
3361  
3362      Returns:
3363          A list of middleware instances if auth is enabled, empty list otherwise.
3364      """
3365      try:
3366          from mlflow.server.auth import get_graphql_authorization_middleware
3367  
3368          return get_graphql_authorization_middleware()
3369      except Exception:
3370          # Auth not configured or other error
3371          return []
3372  
3373  
@catch_mlflow_exception
def _graphql():
    """Handler for GraphQL requests: parses, safety-checks, and executes a query."""
    from graphql import parse

    from mlflow.server.graphql.graphql_no_batching import check_query_safety
    from mlflow.server.graphql.graphql_schema_extensions import schema

    # Pull the query, variables, and operationName out of the JSON request body.
    request_json = _get_request_json()
    query = request_json.get("query")
    variables = request_json.get("variables")
    operation_name = request_json.get("operationName")

    safety_result = check_query_safety(parse(query))
    if safety_result:
        # Unsafe queries short-circuit with the checker's result.
        result = safety_result
    else:
        # Execute the GraphQL query using the Graphene schema, attaching auth
        # middleware when basic-auth is enabled.
        result = schema.execute(
            query,
            variables=variables,
            operation_name=operation_name,
            middleware=_get_graphql_auth_middleware(),
        )

    errors = [error.message for error in result.errors] if result.errors else None
    return jsonify({"data": result.data, "errors": errors})
3410  
3411  
def _validate_support_multipart_upload(artifact_repo):
    """Raise unless `artifact_repo` implements the multipart-upload mixin."""
    if isinstance(artifact_repo, MultipartUploadMixin):
        return
    raise _UnsupportedMultipartUploadException()
3415  
3416  
def _validate_support_multipart_download(artifact_repo):
    """Raise unless `artifact_repo` implements the multipart-download mixin."""
    if isinstance(artifact_repo, MultipartDownloadMixin):
        return
    raise _UnsupportedMultipartDownloadException()
3420  
3421  
def _validate_support_presigned_upload(artifact_repo):
    """Raise unless `artifact_repo` implements the presigned-upload mixin."""
    if isinstance(artifact_repo, PresignedUploadMixin):
        return
    raise _UnsupportedPresignedUploadException()
3425  
3426  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _create_presigned_upload_url():
    """
    Handler for POST /api/2.0/mlflow/artifacts/presigned-upload-url.
    Generates a presigned URL for uploading an artifact directly to cloud storage.

    Client reference: https://github.com/aws/sagemaker-mlflow
    """
    request_message = _get_request_message(
        CreatePresignedUploadUrl(),
        schema={
            "run_id": [_assert_required, _assert_string],
            "path": [_assert_required, _assert_string],
            "expiration": [_assert_intlike],
        },
    )
    path = validate_path_is_safe(request_message.path)
    # Default URL lifetime is 900 seconds when the caller does not specify one.
    expiration = request_message.expiration if request_message.HasField("expiration") else 900

    run = _get_tracking_store().get_run(request_message.run_id)
    scheme = urllib.parse.urlparse(run.info.artifact_uri).scheme
    if scheme in ("http", "https", "mlflow-artifacts"):
        # Proxied artifact storage cannot hand out direct cloud-storage URLs.
        raise MlflowException(
            "Presigned upload is not supported for runs with proxied artifact storage "
            f"(artifact URI scheme: {scheme}). "
            "This endpoint requires a run with a direct cloud storage artifact URI.",
            error_code=INVALID_PARAMETER_VALUE,
        )
    artifact_repo = _get_artifact_repo(run)
    _validate_support_presigned_upload(artifact_repo)

    presigned = artifact_repo.create_presigned_upload_url(path, expiration=expiration)
    resp = Response(mimetype="application/json")
    resp.set_data(message_to_json(presigned.to_proto()))
    return resp
3466  
3467  
@catch_mlflow_exception
@_disable_unless_serve_artifacts
def _create_multipart_upload_artifact(artifact_path):
    """
    A request handler for `POST /mlflow-artifacts/mpu/create` to create a multipart upload
    to `artifact_path` (a relative path from the root artifact directory).
    """
    artifact_path = validate_path_is_safe(artifact_path)
    artifact_path = _get_workspace_scoped_repo_path_if_enabled(artifact_path)

    request_message = _get_request_message(
        CreateMultipartUpload(),
        schema={
            "path": [_assert_required, _assert_string],
            "num_parts": [_assert_intlike],
        },
    )

    repo = _get_artifact_repo_mlflow_artifacts()
    _validate_support_multipart_upload(repo)

    create_response = repo.create_multipart_upload(
        request_message.path,
        request_message.num_parts,
        artifact_path,
    )
    response = Response(mimetype="application/json")
    response.set_data(message_to_json(create_response.to_proto()))
    return response
3500  
3501  
@catch_mlflow_exception
@_disable_unless_serve_artifacts
def _complete_multipart_upload_artifact(artifact_path):
    """
    A request handler for `POST /mlflow-artifacts/mpu/complete` to complete a multipart upload
    to `artifact_path` (a relative path from the root artifact directory).
    """
    artifact_path = validate_path_is_safe(artifact_path)
    artifact_path = _get_workspace_scoped_repo_path_if_enabled(artifact_path)

    request_message = _get_request_message(
        CompleteMultipartUpload(),
        schema={
            "path": [_assert_required, _assert_string],
            "upload_id": [_assert_string],
            "parts": [_assert_required],
        },
    )
    parts = [MultipartUploadPart.from_proto(part) for part in request_message.parts]

    repo = _get_artifact_repo_mlflow_artifacts()
    _validate_support_multipart_upload(repo)

    repo.complete_multipart_upload(
        request_message.path,
        request_message.upload_id,
        parts,
        artifact_path,
    )
    return _wrap_response(CompleteMultipartUpload.Response())
3534  
3535  
@catch_mlflow_exception
@_disable_unless_serve_artifacts
def _abort_multipart_upload_artifact(artifact_path):
    """
    A request handler for `POST /mlflow-artifacts/mpu/abort` to abort a multipart upload
    to `artifact_path` (a relative path from the root artifact directory).
    """
    artifact_path = validate_path_is_safe(artifact_path)
    artifact_path = _get_workspace_scoped_repo_path_if_enabled(artifact_path)

    request_message = _get_request_message(
        AbortMultipartUpload(),
        schema={
            "path": [_assert_required, _assert_string],
            "upload_id": [_assert_string],
        },
    )

    repo = _get_artifact_repo_mlflow_artifacts()
    _validate_support_multipart_upload(repo)

    repo.abort_multipart_upload(
        request_message.path,
        request_message.upload_id,
        artifact_path,
    )
    return _wrap_response(AbortMultipartUpload.Response())
3565  
3566  
@catch_mlflow_exception
@_disable_unless_serve_artifacts
def _get_presigned_download_url(artifact_path):
    """
    A request handler for `GET /mlflow-artifacts/presigned/<artifact_path>` to get
    a presigned URL for downloading an artifact directly from cloud storage.
    """
    artifact_path = validate_path_is_safe(artifact_path)

    repo = _get_artifact_repo_mlflow_artifacts()
    _validate_support_multipart_download(repo)

    # URL lifetime is controlled by a server-side environment setting.
    presigned = repo.get_download_presigned_url(
        artifact_path, expiration=MLFLOW_PRESIGNED_DOWNLOAD_URL_TTL_SECONDS.get()
    )
    response = Response(mimetype="application/json")
    response.set_data(json.dumps(presigned.to_dict()))
    return response
3586  
3587  
3588  # MLflow Tracing APIs
3589  
3590  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _start_trace_v3():
    """
    A request handler for `POST /mlflow/traces` to create a new TraceInfo record in tracking store.
    """
    request_message = _get_request_message(
        StartTraceV3(),
        schema={"trace": [_assert_required]},
    )
    stored_info = _get_tracking_store().start_trace(
        TraceInfo.from_proto(request_message.trace.trace_info)
    )
    return _wrap_response(
        StartTraceV3.Response(trace=ProtoTrace(trace_info=stored_info.to_proto()))
    )
3605  
3606  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _get_trace_info_v3(trace_id):
    """
    A request handler for `GET /mlflow/traces/{trace_id}` to retrieve
    an existing TraceInfo record from tracking store.
    """
    info = _get_tracking_store().get_trace_info(trace_id)
    return _wrap_response(GetTraceInfoV3.Response(trace=ProtoTrace(trace_info=info.to_proto())))
3617  
3618  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _batch_get_traces() -> Response:
    """
    A request handler for `GET /mlflow/traces/batchGet` to retrieve
    a batch of complete traces with spans for given trace ids.
    """
    request_message = _get_request_message(
        BatchGetTraces(),
        schema={"trace_ids": [_assert_array, _assert_required, _assert_item_type_string]},
    )
    response_message = BatchGetTraces.Response()
    for trace in _get_tracking_store().batch_get_traces(request_message.trace_ids, None):
        response_message.traces.append(trace.to_proto())
    return _wrap_response(response_message)
3634  
3635  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _batch_get_trace_infos() -> Response:
    """Handler for batch retrieval of TraceInfo records for the given trace ids."""
    request_message = _get_request_message(
        BatchGetTraceInfos(),
        schema={"trace_ids": [_assert_array, _assert_required, _assert_item_type_string]},
    )
    infos = _get_tracking_store().batch_get_trace_infos(request_message.trace_ids)
    response_message = BatchGetTraceInfos.Response()
    response_message.trace_infos.extend(info.to_proto() for info in infos)
    return _wrap_response(response_message)
3647  
3648  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _get_trace() -> Response:
    """
    A request handler for `GET /mlflow/traces/get` to get a trace with spans for given trace id.
    """
    request_message = _get_request_message(
        GetTrace(),
        schema={
            "trace_id": [_assert_string, _assert_required],
            "allow_partial": [_assert_bool],
        },
    )
    trace = _get_tracking_store().get_trace(
        request_message.trace_id, allow_partial=request_message.allow_partial
    )
    return _wrap_response(GetTrace.Response(trace=trace.to_proto()))
3667  
3668  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _search_traces_v3():
    """
    A request handler for `GET /mlflow/traces` to search for TraceInfo records in tracking store.
    """
    request_message = _get_request_message(
        SearchTracesV3(),
        schema={
            "locations": [_assert_array, _assert_required],
            "filter": [_assert_string],
            "max_results": [
                _assert_intlike,
                lambda x: _assert_less_than_or_equal(int(x), 500),
            ],
            "order_by": [_assert_array, _assert_item_type_string],
            "page_token": [_assert_string],
        },
    )
    # Only experiment-typed locations are forwarded to the store-level search.
    experiment_ids = []
    for location in request_message.locations:
        if location.HasField("mlflow_experiment"):
            experiment_ids.append(location.mlflow_experiment.experiment_id)

    traces, next_token = _get_tracking_store().search_traces(
        locations=experiment_ids,
        filter_string=request_message.filter,
        max_results=request_message.max_results,
        order_by=request_message.order_by,
        page_token=request_message.page_token or None,
    )
    response_message = SearchTracesV3.Response()
    response_message.traces.extend(trace.to_proto() for trace in traces)
    if next_token:
        response_message.next_page_token = next_token
    return _wrap_response(response_message)
3706  
3707  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _delete_traces():
    """
    A request handler for `POST /mlflow/traces/delete-traces` to delete TraceInfo records
    from tracking store.
    """
    request_message = _get_request_message(
        DeleteTraces(),
        schema={
            "experiment_id": [_assert_string, _assert_required],
            "max_timestamp_millis": [_assert_intlike],
            "max_traces": [_assert_intlike],
            "request_ids": [_assert_array, _assert_item_type_string],
        },
    )

    # NB: Proto accessors return the type's default (e.g. 0) for unset optional fields,
    #   but for these fields null and 0 mean opposite things: None is 'no limit /
    #   delete nothing' while 0 would be 'delete all'. Use HasField to map unset
    #   fields to None explicitly.
    def _optional(field):
        return getattr(request_message, field) if request_message.HasField(field) else None

    traces_deleted = _get_tracking_store().delete_traces(
        experiment_id=request_message.experiment_id,
        max_timestamp_millis=_optional("max_timestamp_millis"),
        max_traces=_optional("max_traces"),
        trace_ids=request_message.request_ids,
    )
    return _wrap_response(DeleteTraces.Response(traces_deleted=traces_deleted))
3743  
3744  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _calculate_trace_filter_correlation():
    """
    A request handler for `POST /mlflow/traces/calculate-filter-correlation` to calculate
    NPMI correlation between two trace filter conditions.
    """
    request_message = _get_request_message(
        CalculateTraceFilterCorrelation(),
        schema={
            "experiment_ids": [_assert_array, _assert_required, _assert_item_type_string],
            "filter_string1": [_assert_string, _assert_required],
            "filter_string2": [_assert_string, _assert_required],
            "base_filter": [_assert_string],
        },
    )

    # `base_filter` is optional; pass None when the field is unset in the request.
    base_filter = None
    if request_message.HasField("base_filter"):
        base_filter = request_message.base_filter

    result = _get_tracking_store().calculate_trace_filter_correlation(
        experiment_ids=request_message.experiment_ids,
        filter_string1=request_message.filter_string1,
        filter_string2=request_message.filter_string2,
        base_filter=base_filter,
    )
    return _wrap_response(result.to_proto())
3772  
3773  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _set_trace_tag(request_id):
    """
    Handle `PATCH /mlflow/traces/{request_id}/tags`: set a tag on a TraceInfo record.
    """
    req = _get_request_message(
        SetTraceTag(),
        schema={
            "key": [_assert_string, _assert_required],
            "value": [_assert_string],
        },
    )
    store = _get_tracking_store()
    store.set_trace_tag(request_id, req.key, req.value)
    return _wrap_response(SetTraceTag.Response())
3789  
3790  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _set_trace_tag_v3(trace_id):
    """
    Handle `PATCH /mlflow/traces/{trace_id}/tags`: set a tag on a TraceInfo record.
    Same semantics as `_set_trace_tag`, but keyed by trace_id instead of request_id.
    """
    req = _get_request_message(
        SetTraceTagV3(),
        schema={
            "key": [_assert_string, _assert_required],
            "value": [_assert_string],
        },
    )
    store = _get_tracking_store()
    store.set_trace_tag(trace_id, req.key, req.value)
    return _wrap_response(SetTraceTagV3.Response())
3807  
3808  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _delete_trace_tag(request_id):
    """
    Handle `DELETE /mlflow/traces/{request_id}/tags`: remove a tag from a TraceInfo
    record.
    """
    req = _get_request_message(
        DeleteTraceTag(),
        schema={
            "key": [_assert_string, _assert_required],
        },
    )
    store = _get_tracking_store()
    store.delete_trace_tag(request_id, req.key)
    return _wrap_response(DeleteTraceTag.Response())
3824  
3825  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _delete_trace_tag_v3(trace_id):
    """
    Handle `DELETE /mlflow/traces/{trace_id}/tags`: remove a tag from a TraceInfo
    record. Same semantics as `_delete_trace_tag`, but keyed by trace_id.
    """
    req = _get_request_message(
        DeleteTraceTagV3(),
        schema={
            "key": [_assert_string, _assert_required],
        },
    )
    store = _get_tracking_store()
    store.delete_trace_tag(trace_id, req.key)
    return _wrap_response(DeleteTraceTagV3.Response())
3842  
3843  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _link_traces_to_run():
    """
    Handle `POST /mlflow/traces/link-to-run`: associate a batch of traces with a run.
    """
    req = _get_request_message(
        LinkTracesToRun(),
        schema={
            "trace_ids": [_assert_array, _assert_required, _assert_item_type_string],
            "run_id": [_assert_string, _assert_required],
        },
    )
    store = _get_tracking_store()
    store.link_traces_to_run(trace_ids=req.trace_ids, run_id=req.run_id)
    return _wrap_response(LinkTracesToRun.Response())
3862  
3863  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _link_prompts_to_trace():
    """
    Handle `POST /mlflow/traces/link-prompts`: link prompt versions to a trace.
    """
    from mlflow.entities.model_registry import PromptVersion

    req = _get_request_message(
        LinkPromptsToTrace(),
        schema={
            "trace_id": [_assert_string, _assert_required],
            "prompt_versions": [_assert_array, _assert_required],
        },
    )

    # Name and version are sufficient for linking, so build lightweight
    # PromptVersion objects instead of loading the stored prompt templates.
    versions = [
        PromptVersion(name=ref.name, version=int(ref.version), template="")
        for ref in req.prompt_versions
    ]

    _get_tracking_store().link_prompts_to_trace(
        trace_id=req.trace_id,
        prompt_versions=versions,
    )
    return _wrap_response(LinkPromptsToTrace.Response())
3892  
3893  
def _fetch_trace_data_from_store(
    store: AbstractTrackingStore, request_id: str
) -> dict[str, Any] | None:
    """
    Load trace data for `request_id` from the tracking store.

    Returns the trace data as a dict, or None to signal that the caller should
    fall back to the artifact repository.
    """
    try:
        # Partial traces are allowed so the frontend can render in-progress traces.
        return store.get_trace(request_id, allow_partial=True).data.to_dict()
    except MlflowTracingException:
        return None
    except MlflowNotImplementedException:
        # Store doesn't implement get_trace; fall through to batch_get_traces.
        pass

    try:
        traces = store.batch_get_traces([request_id], None)
        if len(traces) == 1:
            return traces[0].data.to_dict()
        raise MlflowException(
            f"Trace with id={request_id} not found.",
            error_code=RESOURCE_DOES_NOT_EXIST,
        )
    except (MlflowTracingException, MlflowNotImplementedException):
        # Store doesn't support batch get, or the trace data isn't there; let the
        # caller fall back to the artifact repository.
        return None
3921  
3922  
@catch_mlflow_exception
@_disable_if_artifacts_only
def get_trace_artifact_handler() -> Response:
    """
    Serve trace content as a downloadable file.

    Query parameters:
        request_id: the trace id (required).
        path: optional attachment path; when given, the named attachment's raw
            bytes are returned instead of the trace data JSON.
    """
    request_id = request.args.get("request_id")
    path = request.args.get("path")

    if not request_id:
        raise MlflowException(
            'Request must include the "request_id" query parameter.',
            error_code=BAD_REQUEST,
        )

    store = _get_tracking_store()

    if path:
        # Attachment branch: sanitize the user-supplied path before using it.
        path = validate_path_is_safe(path)
        trace_info = store.get_trace_info(request_id)
        if trace_info is None:
            raise MlflowException(
                f"Trace with ID '{request_id}' not found.",
                error_code=RESOURCE_DOES_NOT_EXIST,
            )
        repo = _get_trace_artifact_repo(trace_info)
        try:
            content_bytes = repo.download_trace_attachment(path)
        except MlflowException:
            # MlflowExceptions already carry a client-appropriate error code.
            raise
        except Exception:
            # Log the traceback server-side, surface a generic internal error.
            _logger.warning(
                "Failed to download attachment '%s' for trace '%s'",
                path,
                request_id,
                exc_info=True,
            )
            raise MlflowException(
                f"Failed to download attachment '{path}' for trace '{request_id}'.",
                error_code=INTERNAL_ERROR,
            )
        buf = io.BytesIO(content_bytes)
        file_sender_response = send_file(
            buf,
            mimetype="application/octet-stream",
            as_attachment=True,
            download_name=path,
        )
        return _response_with_file_attachment_headers(path, file_sender_response)

    # Trace-data branch: prefer the tracking store; on None, fall back to the
    # artifact repository resolved from the trace info.
    trace_data = _fetch_trace_data_from_store(store, request_id)
    if trace_data is None:
        trace_info = store.get_trace_info(request_id)
        trace_data = _get_trace_artifact_repo(trace_info).download_trace_data()

    # Write data to a BytesIO buffer instead of needing to save a temp file
    buf = io.BytesIO()
    buf.write(json.dumps(trace_data).encode())
    buf.seek(0)

    file_sender_response = send_file(
        buf,
        mimetype="application/octet-stream",
        as_attachment=True,
        download_name=TRACE_DATA_FILE_NAME,
    )
    return _response_with_file_attachment_headers(TRACE_DATA_FILE_NAME, file_sender_response)
3987  
3988  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _query_trace_metrics() -> Response:
    """
    Handle the trace-metrics query endpoint: aggregate a named trace metric across
    experiments, optionally grouped, filtered, and windowed in time.
    """
    req = _get_request_message(
        QueryTraceMetrics(),
        schema={
            "experiment_ids": [_assert_array, _assert_required, _assert_item_type_string],
            "view_type": [_assert_required],
            "metric_name": [_assert_string, _assert_required],
            "aggregations": [_assert_array, _assert_required],
            "dimensions": [_assert_array, _assert_item_type_string],
            "filters": [_assert_array, _assert_item_type_string],
            "time_interval_seconds": [_assert_intlike],
            "start_time_ms": [_assert_intlike],
            "end_time_ms": [_assert_intlike],
            "max_results": [_assert_intlike],
            "page_token": [_assert_string],
        },
    )

    def optional(field):
        # Distinguish "field not provided" from the proto default value.
        return getattr(req, field) if req.HasField(field) else None

    max_results = optional("max_results")
    if max_results is None:
        max_results = MAX_RESULTS_QUERY_TRACE_METRICS

    result = _get_tracking_store().query_trace_metrics(
        experiment_ids=req.experiment_ids,
        view_type=MetricViewType.from_proto(req.view_type),
        metric_name=req.metric_name,
        aggregations=[MetricAggregation.from_proto(agg) for agg in req.aggregations],
        dimensions=req.dimensions or None,
        filters=req.filters or None,
        time_interval_seconds=optional("time_interval_seconds"),
        start_time_ms=optional("start_time_ms"),
        end_time_ms=optional("end_time_ms"),
        max_results=max_results,
        page_token=req.page_token or None,
    )

    response = QueryTraceMetrics.Response()
    response.data_points.extend([dp.to_proto() for dp in result])
    if result.token:
        response.next_page_token = result.token
    return _wrap_response(response)
4041  
4042  
4043  # Assessments API handlers
@catch_mlflow_exception
@_disable_if_artifacts_only
def _create_assessment(trace_id):
    """
    Handle `POST /mlflow/traces/{assessment.trace_id}/assessments`: create a new
    assessment on a trace.
    """
    req = _get_request_message(
        CreateAssessment(),
        schema={
            "assessment": [_assert_required],
        },
    )

    assessment = Assessment.from_proto(req.assessment)
    # The trace id from the URL path is authoritative over the payload's value.
    assessment.trace_id = trace_id
    created = _get_tracking_store().create_assessment(assessment)

    return _wrap_response(CreateAssessment.Response(assessment=created.to_proto()))
4064  
4065  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _get_assessment(trace_id, assessment_id):
    """
    Handle `GET /mlflow/traces/{trace_id}/assessments/{assessment_id}`: fetch one
    assessment.
    """
    found = _get_tracking_store().get_assessment(trace_id, assessment_id)
    return _wrap_response(GetAssessmentRequest.Response(assessment=found.to_proto()))
4077  
4078  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _update_assessment(trace_id, assessment_id):
    """
    Handle `PATCH /mlflow/traces/{trace_id}/assessments/{assessment_id}`: apply the
    fields named in the update mask to an existing assessment.
    """
    req = _get_request_message(
        UpdateAssessment(),
        schema={
            "assessment": [_assert_required],
            "update_mask": [_assert_required],
        },
    )

    proto = req.assessment

    # Map each supported update-mask path to the kwarg it populates. Lambdas keep
    # the extraction lazy, so conversions only run for paths actually in the mask.
    extractors = {
        "assessment_name": lambda: ("name", proto.assessment_name),
        "expectation": lambda: ("expectation", Expectation.from_proto(proto)),
        "feedback": lambda: ("feedback", Feedback.from_proto(proto)),
        "rationale": lambda: ("rationale", proto.rationale),
        "metadata": lambda: ("metadata", dict(proto.metadata)),
        "valid": lambda: ("valid", proto.valid),
    }

    kwargs = {}
    for path in req.update_mask.paths:
        extract = extractors.get(path)
        if extract is not None:  # unrecognized paths are silently ignored
            key, value = extract()
            kwargs[key] = value

    updated = _get_tracking_store().update_assessment(
        trace_id=trace_id, assessment_id=assessment_id, **kwargs
    )

    return _wrap_response(UpdateAssessment.Response(assessment=updated.to_proto()))
4119  
4120  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _delete_assessment(trace_id, assessment_id):
    """
    Handle `DELETE /mlflow/traces/{trace_id}/assessments/{assessment_id}`: remove an
    assessment from a trace.
    """
    _get_tracking_store().delete_assessment(trace_id, assessment_id)
    return _wrap_response(DeleteAssessment.Response())
4132  
4133  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _create_issue():
    """
    Handle `POST /mlflow/issues`: create a new issue in the tracking store.
    """
    req = _get_request_message(
        CreateIssue(),
        schema={
            "name": [_assert_required, _assert_string],
            "description": [_assert_required, _assert_string],
            "experiment_id": [_assert_required, _assert_string],
        },
    )

    kwargs = {
        "experiment_id": req.experiment_id,
        "name": req.name,
        "description": req.description,
        # Optional fields: normalize proto defaults (empty string/list) to None.
        "source_run_id": req.source_run_id or None,
        "root_causes": list(req.root_causes) or None,
        "categories": list(req.categories) or None,
        "created_by": req.created_by or None,
    }

    # Enum fields are forwarded only when explicitly set on the request.
    if req.HasField("status"):
        kwargs["status"] = IssueStatus(req.status)
    if req.HasField("severity"):
        kwargs["severity"] = IssueSeverity(req.severity)

    issue = _get_tracking_store().create_issue(**kwargs)

    return _wrap_response(CreateIssue.Response(issue=issue.to_proto()))
4169  
4170  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _update_issue(issue_id):
    """
    Handle `PATCH /mlflow/issues/{issue_id}`: update an existing issue.
    """
    req = _get_request_message(
        UpdateIssue(),
        schema={
            "issue_id": [_assert_required],
        },
    )

    # Enum fields are converted only when explicitly set; empty strings become None.
    updated = _get_tracking_store().update_issue(
        issue_id=issue_id,
        status=IssueStatus(req.status) if req.HasField("status") else None,
        name=req.name or None,
        description=req.description or None,
        severity=IssueSeverity(req.severity) if req.HasField("severity") else None,
    )

    return _wrap_response(UpdateIssue.Response(issue=updated.to_proto()))
4199  
4200  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _get_issue(issue_id):
    """
    Handle `GET /mlflow/issues/{issue_id}`: fetch a single issue.
    """
    found = _get_tracking_store().get_issue(issue_id)
    return _wrap_response(GetIssue.Response(issue=found.to_proto()))
4211  
4212  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _search_issues():
    """
    Handle `POST /mlflow/issues/search`: search for issues, returning a paginated
    list.
    """
    req = _get_request_message(SearchIssues())

    kwargs = {
        "experiment_id": req.experiment_id or None,
        "filter_string": req.filter_string or None,
        "page_token": req.page_token or None,
    }
    # These fields have meaningful proto defaults, so forward them only when set.
    for field in ("max_results", "include_trace_count"):
        if req.HasField(field):
            kwargs[field] = getattr(req, field)

    issues = _get_tracking_store().search_issues(**kwargs)

    return _wrap_response(
        SearchIssues.Response(
            issues=[issue.to_proto() for issue in issues],
            next_page_token=issues.token or "",
        )
    )
4241  
4242  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _invoke_issue_detection_handler():
    """
    Invoke issue detection on traces asynchronously.

    This is a UI-only AJAX endpoint for running issue detection from the frontend.
    It creates an MLflow run up front (so the frontend gets a run_id immediately),
    submits a background job that performs the detection, and returns both the job
    id and run id as JSON for polling.
    """
    from mlflow.genai.discovery.job import _fetch_provider_credentials, invoke_issue_detection_job
    from mlflow.server.jobs import submit_job

    _validate_content_type(request, ["application/json"])

    request_json = _get_validated_flask_request_json(
        schema={
            "experiment_id": [_assert_required, _assert_string],
            "trace_ids": [_assert_required, _assert_array],
            "categories": [_assert_required, _assert_array],
            "provider": [_assert_required, _assert_string],
            "model": [_assert_string],
            "secret_id": [_assert_string],
            "endpoint_name": [_assert_string],
        }
    )

    experiment_id = request_json.get("experiment_id")
    trace_ids = request_json.get("trace_ids", [])
    categories = request_json.get("categories", [])
    provider = request_json.get("provider")
    model = request_json.get("model")
    secret_id = request_json.get("secret_id")
    endpoint_name = request_json.get("endpoint_name")

    # Either a gateway endpoint or an explicit provider+model pair must be given.
    if not endpoint_name and not (provider and model):
        raise MlflowException(
            "Either 'endpoint_name' or both 'provider' and 'model' must be provided"
        )

    # Fetch credentials required for executing the job
    if secret_id:
        store = _get_tracking_store()
        credentials = _fetch_provider_credentials(store, provider, secret_id)
    else:
        credentials = None

    # Create the run upfront so we can return run_id immediately
    model_name = f"gateway:/{endpoint_name}" if endpoint_name else f"{provider}:/{model}"
    tags = {
        MLFLOW_RUN_TYPE: MLFLOW_RUN_TYPE_ISSUE_DETECTION,
        "categories": ",".join(categories),
        "model": model_name,
        "total_traces": len(trace_ids),
    }
    if endpoint_name:
        tags["endpoint_name"] = endpoint_name
    run = mlflow.start_run(
        experiment_id=experiment_id,
        tags=tags,
    )
    run_id = run.info.run_id

    # NOTE(review): if submit_job raises, the run started above is left open with no
    # terminal status — confirm whether cleanup (failing the run) is desired here.
    job = submit_job(
        function=invoke_issue_detection_job,
        params={
            "experiment_id": experiment_id,
            "trace_ids": trace_ids,
            "categories": categories,
            "run_id": run_id,
            "model": model_name,
        },
        extra_envs=credentials,
    )
    # Tag the run with job ID for later retrieval
    mlflow.set_tag(MLFLOW_ISSUE_DETECTION_JOB_ID, job.job_id)
    # The run is detached in RUNNING state; presumably the background job finishes
    # it when detection completes — verify against invoke_issue_detection_job.
    mlflow.end_run(RunStatus.to_string(RunStatus.RUNNING))

    return jsonify({"job_id": job.job_id, "run_id": run_id})
4320  
4321  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _get_job(job_id):
    """Return a server job's status, parsed result, and status details as JSON."""
    from mlflow.server.jobs import get_job

    job = get_job(job_id)
    payload = {
        "status": str(job.status),
        "result": job.parsed_result,
        "status_details": job.status_details,
    }
    return jsonify(payload)
4333  
4334  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _cancel_job(job_id):
    """Cancel a server job and return its resulting status and parsed result."""
    from mlflow.server.jobs import cancel_job

    job = cancel_job(job_id)
    payload = {
        "status": str(job.status),
        "result": job.parsed_result,
    }
    return jsonify(payload)
4345  
4346  
4347  # Deprecated MLflow Tracing APIs. Kept for backward compatibility but do not use.
4348  
4349  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _deprecated_start_trace_v2():
    """
    Handle `POST /mlflow/traces` (deprecated v2 API): create a new TraceInfo record
    in the tracking store.
    """
    req = _get_request_message(
        StartTrace(),
        schema={
            "experiment_id": [_assert_string],
            "timestamp_ms": [_assert_intlike],
            "request_metadata": [_assert_map_key_present],
            "tags": [_assert_map_key_present],
        },
    )

    trace_info = _get_tracking_store().deprecated_start_trace_v2(
        experiment_id=req.experiment_id,
        timestamp_ms=req.timestamp_ms,
        # Repeated key/value proto entries are flattened into plain dicts.
        request_metadata={entry.key: entry.value for entry in req.request_metadata},
        tags={entry.key: entry.value for entry in req.tags},
    )
    return _wrap_response(StartTrace.Response(trace_info=trace_info.to_proto()))
4376  
4377  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _deprecated_end_trace_v2(request_id):
    """
    Handle `PATCH /mlflow/traces/{request_id}` (deprecated v2 API): mark an existing
    TraceInfo record completed in the tracking store.
    """
    req = _get_request_message(
        EndTrace(),
        schema={
            "timestamp_ms": [_assert_intlike],
            "status": [_assert_string],
            "request_metadata": [_assert_map_key_present],
            "tags": [_assert_map_key_present],
        },
    )

    trace_info = _get_tracking_store().deprecated_end_trace_v2(
        request_id=request_id,
        timestamp_ms=req.timestamp_ms,
        status=TraceStatus.from_proto(req.status),
        # Repeated key/value proto entries are flattened into plain dicts.
        request_metadata={entry.key: entry.value for entry in req.request_metadata},
        tags={entry.key: entry.value for entry in req.tags},
    )

    # Stores may return the newer TraceInfo type; downgrade it to the v2 shape.
    if isinstance(trace_info, TraceInfo):
        trace_info = TraceInfoV2.from_v3(trace_info)

    return _wrap_response(EndTrace.Response(trace_info=trace_info.to_proto()))
4410  
4411  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _deprecated_get_trace_info_v2(request_id):
    """
    Handle `GET /mlflow/traces/{request_id}/info` (deprecated v2 API): retrieve an
    existing TraceInfo record from the tracking store.
    """
    v3_info = _get_tracking_store().get_trace_info(request_id)
    # The deprecated endpoint responds with the v2 TraceInfo shape.
    v2_info = TraceInfoV2.from_v3(v3_info)
    return _wrap_response(GetTraceInfo.Response(trace_info=v2_info.to_proto()))
4423  
4424  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _deprecated_search_traces_v2():
    """
    Handle `GET /mlflow/traces` (deprecated v2 API): search for TraceInfo records in
    the tracking store.
    """
    req = _get_request_message(
        SearchTraces(),
        schema={
            "experiment_ids": [
                _assert_array,
                _assert_item_type_string,
                _assert_required,
            ],
            "filter": [_assert_string],
            "max_results": [
                _assert_intlike,
                lambda x: _assert_less_than_or_equal(int(x), 500),
            ],
            "order_by": [_assert_array, _assert_item_type_string],
            "page_token": [_assert_string],
        },
    )

    trace_infos, next_token = _get_tracking_store().search_traces(
        experiment_ids=req.experiment_ids,
        filter_string=req.filter,
        max_results=req.max_results,
        order_by=req.order_by,
        page_token=req.page_token or None,
    )

    response = SearchTraces.Response()
    # The deprecated endpoint responds with v2-shaped TraceInfo protos.
    response.traces.extend([TraceInfoV2.from_v3(t).to_proto() for t in trace_infos])
    if next_token:
        response.next_page_token = next_token
    return _wrap_response(response)
4462  
4463  
4464  # Logged Models APIs
4465  
4466  
@catch_mlflow_exception
@_disable_if_artifacts_only
def get_logged_model_artifact_handler(model_id: str):
    """
    Serve a single artifact file belonging to a logged model, identified by the
    "artifact_file_path" query parameter.
    """
    relative_path = request.args.get("artifact_file_path")
    if not relative_path:
        raise MlflowException(
            'Request must include the "artifact_file_path" query parameter.',
            error_code=BAD_REQUEST,
        )
    # Reject path-traversal attempts in the user-supplied path.
    validate_path_is_safe(relative_path)

    logged_model: LoggedModel = _get_tracking_store().get_logged_model(model_id)
    root = logged_model.artifact_location
    if _is_servable_proxied_run_artifact_root(root):
        # Proxied artifact roots are served via the tracking server's own
        # artifact repository rather than resolved directly.
        repo = _get_artifact_repo_mlflow_artifacts()
        resolved_path = _get_proxied_run_artifact_destination_path(
            proxied_artifact_root=root,
            relative_path=relative_path,
        )
        resolved_path = _get_workspace_scoped_repo_path_if_enabled(resolved_path)
    else:
        repo = get_artifact_repository(root)
        resolved_path = relative_path

    return _send_artifact(repo, resolved_path)
4491  
4492  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _create_logged_model():
    """
    Create a new logged model in the tracking store from the request payload.
    """
    req = _get_request_message(
        CreateLoggedModel(),
        schema={
            "experiment_id": [_assert_string, _assert_required],
            "name": [_assert_string],
            "model_type": [_assert_string],
            "source_run_id": [_assert_string],
            "params": [_assert_array],
            "tags": [_assert_array],
        },
    )

    # Empty repeated fields are normalized to None rather than empty lists.
    params = None
    if req.params:
        params = [LoggedModelParameter.from_proto(p) for p in req.params]
    tags = None
    if req.tags:
        tags = [LoggedModelTag(key=t.key, value=t.value) for t in req.tags]

    model = _get_tracking_store().create_logged_model(
        experiment_id=req.experiment_id,
        name=req.name or None,
        model_type=req.model_type,
        source_run_id=req.source_run_id,
        params=params,
        tags=tags,
    )
    return _wrap_response(CreateLoggedModel.Response(model=model.to_proto()))
4526  
4527  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _log_logged_model_params(model_id: str):
    """
    Log a batch of parameters against an existing logged model.
    """
    req = _get_request_message(
        LogLoggedModelParamsRequest(),
        schema={
            "model_id": [_assert_string, _assert_required],
            "params": [_assert_array],
        },
    )
    params = [LoggedModelParameter.from_proto(p) for p in req.params] if req.params else []
    _get_tracking_store().log_logged_model_params(model_id, params)
    return _wrap_response(LogLoggedModelParamsRequest.Response())
4545  
4546  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _get_logged_model(model_id: str):
    """
    Fetch a logged model by id; the "allow_deleted=true" query parameter also
    permits returning deleted models.
    """
    include_deleted = request.args.get("allow_deleted", "false").lower() == "true"
    model = _get_tracking_store().get_logged_model(model_id, allow_deleted=include_deleted)
    return _wrap_response(GetLoggedModel.Response(model=model.to_proto()))
4554  
4555  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _finalize_logged_model(model_id: str):
    """Finalize a logged model by setting its status.

    The status arrives as an int in the request body and is converted to a
    ``LoggedModelStatus`` before being handed to the store.
    """
    req = _get_request_message(
        FinalizeLoggedModel(),
        schema={
            "model_id": [_assert_string, _assert_required],
            "status": [_assert_intlike, _assert_required],
        },
    )
    final_status = LoggedModelStatus.from_int(req.status)
    updated = _get_tracking_store().finalize_logged_model(req.model_id, final_status)
    return _wrap_response(FinalizeLoggedModel.Response(model=updated.to_proto()))
4571  
4572  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _delete_logged_model(model_id: str):
    """Delete the logged model identified by ``model_id``."""
    store = _get_tracking_store()
    store.delete_logged_model(model_id)
    return _wrap_response(DeleteLoggedModel.Response())
4578  
4579  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _set_logged_model_tags(model_id: str):
    """Set tags on a logged model from the request's repeated ``tags`` field."""
    req = _get_request_message(
        SetLoggedModelTags(),
        schema={"tags": [_assert_array]},
    )
    entity_tags = []
    for proto_tag in req.tags:
        entity_tags.append(LoggedModelTag(key=proto_tag.key, value=proto_tag.value))
    _get_tracking_store().set_logged_model_tags(model_id, entity_tags)
    return _wrap_response(SetLoggedModelTags.Response())
4590  
4591  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _delete_logged_model_tag(model_id: str, tag_key: str):
    """Delete a single tag (``tag_key``) from the given logged model."""
    store = _get_tracking_store()
    store.delete_logged_model_tag(model_id, tag_key)
    return _wrap_response(DeleteLoggedModelTag.Response())
4597  
4598  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _search_logged_models():
    """Search logged models across one or more experiments.

    Accepts a filter string, optional per-dataset scoping, ordering, and
    pagination. Optional proto fields (empty string / zero) are normalized
    to ``None`` before being passed to the tracking store.
    """
    request_message = _get_request_message(
        SearchLoggedModels(),
        schema={
            "experiment_ids": [
                _assert_array,
                _assert_item_type_string,
                _assert_required,
            ],
            "filter": [_assert_string],
            "datasets": [_assert_array],
            "max_results": [_assert_intlike],
            "order_by": [_assert_array],
            "page_token": [_assert_string],
        },
    )
    models = _get_tracking_store().search_logged_models(
        # Convert `RepeatedScalarContainer` objects (experiment_ids and order_by) to `list`
        # to avoid serialization issues
        experiment_ids=list(request_message.experiment_ids),
        filter_string=request_message.filter or None,
        # Normalize each dataset proto entry to a plain dict; an empty digest
        # string is treated as "not specified" and becomes None.
        datasets=(
            [
                {
                    "dataset_name": d.dataset_name,
                    "dataset_digest": d.dataset_digest or None,
                }
                for d in request_message.datasets
            ]
            if request_message.datasets
            else None
        ),
        max_results=request_message.max_results or None,
        # Each order_by clause carries a field name, sort direction, and an
        # optional dataset scope (empty strings normalized to None).
        order_by=(
            [
                {
                    "field_name": ob.field_name,
                    "ascending": ob.ascending,
                    "dataset_name": ob.dataset_name or None,
                    "dataset_digest": ob.dataset_digest or None,
                }
                for ob in request_message.order_by
            ]
            if request_message.order_by
            else None
        ),
        page_token=request_message.page_token or None,
    )
    response_message = SearchLoggedModels.Response()
    response_message.models.extend([e.to_proto() for e in models])
    # `models` is a paged result; only emit the continuation token when the
    # store indicates more results are available.
    if models.token:
        response_message.next_page_token = models.token
    return _wrap_response(response_message)
4654  
4655  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _list_logged_model_artifacts(model_id: str):
    """List artifacts of a logged model, optionally under a subdirectory."""
    req = _get_request_message(
        ListLoggedModelArtifacts(),
        schema={"artifact_directory_path": [_assert_string]},
    )
    # Only sanitize the directory path when the client actually supplied one;
    # validate_path_is_safe guards against path-traversal style input.
    directory = None
    if req.HasField("artifact_directory_path"):
        directory = validate_path_is_safe(req.artifact_directory_path)
    return _list_logged_model_artifacts_impl(model_id, directory)
4669  
4670  
def _list_logged_model_artifacts_impl(
    model_id: str, artifact_directory_path: str | None
) -> Response:
    """Build the ``ListLoggedModelArtifacts`` response for ``model_id``."""
    logged_model: LoggedModel = _get_tracking_store().get_logged_model(model_id)
    root = logged_model.artifact_location
    # Proxied artifact roots are served through the tracking server itself;
    # all other roots are listed directly via an artifact repository.
    if _is_servable_proxied_run_artifact_root(root):
        artifacts = _list_artifacts_for_proxied_run_artifact_root(
            proxied_artifact_root=root,
            relative_path=artifact_directory_path,
        )
    else:
        repo = get_artifact_repository(root)
        artifacts = repo.list_artifacts(artifact_directory_path)

    response = ListLoggedModelArtifacts.Response()
    response.files.extend([a.to_proto() for a in artifacts])
    response.root_uri = root
    return _wrap_response(response)
4689  
4690  
4691  # =============================================================================
4692  # Scorer Management Handlers
4693  # =============================================================================
4694  
4695  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _register_scorer():
    """Register a new scorer version for an experiment.

    Decorator-based scorers are rejected: their serialized form carries a
    ``call_source`` field that is executed via ``exec()`` during
    deserialization and must not be accepted from remote callers.

    Returns:
        JSON response describing the newly created scorer version.
    """
    request_message = _get_request_message(
        RegisterScorer(),
        schema={
            "experiment_id": [_assert_required, _assert_string],
            "name": [_assert_required, _assert_string],
            "serialized_scorer": [_assert_required, _assert_string],
        },
    )
    # Decorator scorers contain a `call_source` field that is executed via exec() during
    # deserialization. The Python client blocks this via `_check_can_be_registered()`, but
    # that check is client-side only and can be bypassed by calling the REST API directly.
    # Enforce the same restriction here in the server handler so it applies regardless of
    # how the request arrives.
    try:
        serialized_data = json.loads(request_message.serialized_scorer)
    except json.JSONDecodeError as e:
        raise MlflowException.invalid_parameter_value("serialized_scorer must be valid JSON") from e
    # json.loads accepts any top-level JSON value; a non-object payload (list,
    # string, number, ...) would previously crash with AttributeError on `.get`
    # below and surface as an opaque 500. Reject it as an invalid parameter.
    if not isinstance(serialized_data, dict):
        raise MlflowException.invalid_parameter_value(
            "serialized_scorer must be a JSON object"
        )
    if serialized_data.get("call_source") is not None:
        raise MlflowException.invalid_parameter_value(
            DECORATOR_SCORER_REGISTRATION_NOT_SUPPORTED_ERROR
        )
    scorer_version = _get_tracking_store().register_scorer(
        request_message.experiment_id,
        request_message.name,
        request_message.serialized_scorer,
    )
    # Flatten the scorer-version entity into the flat response proto fields.
    response_message = RegisterScorer.Response()
    response_message.version = scorer_version.scorer_version
    response_message.scorer_id = scorer_version.scorer_id
    response_message.experiment_id = scorer_version.experiment_id
    response_message.name = scorer_version.scorer_name
    response_message.serialized_scorer = scorer_version._serialized_scorer
    response_message.creation_time = scorer_version.creation_time
    response = Response(mimetype="application/json")
    response.set_data(message_to_json(response_message))
    return response
4735  
4736  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _list_scorers():
    """List all scorers registered under an experiment."""
    req = _get_request_message(
        ListScorers(),
        schema={"experiment_id": [_assert_required, _assert_string]},
    )
    scorers = _get_tracking_store().list_scorers(req.experiment_id)
    message = ListScorers.Response()
    for scorer in scorers:
        message.scorers.append(scorer.to_proto())
    response = Response(mimetype="application/json")
    response.set_data(message_to_json(message))
    return response
4750  
4751  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _list_scorer_versions():
    """List every version of a named scorer within an experiment."""
    req = _get_request_message(
        ListScorerVersions(),
        schema={
            "experiment_id": [_assert_required, _assert_string],
            "name": [_assert_required, _assert_string],
        },
    )
    versions = _get_tracking_store().list_scorer_versions(req.experiment_id, req.name)
    message = ListScorerVersions.Response()
    for version in versions:
        message.scorers.append(version.to_proto())
    response = Response(mimetype="application/json")
    response.set_data(message_to_json(message))
    return response
4770  
4771  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _get_scorer():
    """Fetch a scorer version; with no explicit version, the store chooses."""
    req = _get_request_message(
        GetScorer(),
        schema={
            "experiment_id": [_assert_required, _assert_string],
            "name": [_assert_required, _assert_string],
            "version": [_assert_intlike],
        },
    )
    # `version` is optional in the proto; distinguish "absent" from 0 via HasField.
    version = req.version if req.HasField("version") else None
    scorer_version = _get_tracking_store().get_scorer(req.experiment_id, req.name, version)
    message = GetScorer.Response()
    message.scorer.CopyFrom(scorer_version.to_proto())
    response = Response(mimetype="application/json")
    response.set_data(message_to_json(message))
    return response
4793  
4794  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _delete_scorer():
    """Delete a scorer (or a specific version of it when ``version`` is set)."""
    req = _get_request_message(
        DeleteScorer(),
        schema={
            "experiment_id": [_assert_required, _assert_string],
            "name": [_assert_required, _assert_string],
            "version": [_assert_intlike],
        },
    )
    # `version` is optional in the proto; distinguish "absent" from 0 via HasField.
    version = req.version if req.HasField("version") else None
    _get_tracking_store().delete_scorer(req.experiment_id, req.name, version)
    response = Response(mimetype="application/json")
    response.set_data(message_to_json(DeleteScorer.Response()))
    return response
4815  
4816  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _get_online_scoring_configs():
    """
    Get online scoring configurations for a list of scorer IDs.

    Request Body (JSON):
        scorer_ids: List of scorer IDs to fetch configurations for.

    Returns:
        JSON response containing a list of configurations.
    """
    payload = _get_validated_flask_request_json(
        flask_request=request,
        schema={
            "scorer_ids": [_assert_required, _assert_array, _assert_item_type_string],
        },
    )
    configs = _get_tracking_store().get_online_scoring_configs(payload["scorer_ids"])
    body = json.dumps({"configs": [config.to_dict() for config in configs]})
    response = Response(mimetype="application/json")
    response.set_data(body)
    return response
4842  
4843  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _upsert_online_scoring_config():
    """
    Update the online scoring configuration for a registered scorer.

    Request Body (JSON):
        experiment_id: The ID of the Experiment containing the scorer.
        name: The scorer name.
        sample_rate: The sampling rate (0.0 to 1.0).
        filter_string: Optional filter string for trace selection.

    Returns:
        JSON response containing the updated configuration.

    Raises:
        MlflowException: INVALID_PARAMETER_VALUE when ``filter_string`` is not
            a string/None or ``sample_rate`` is not numeric.
    """
    request_json = _get_validated_flask_request_json(
        flask_request=request,
        schema={
            "experiment_id": [_assert_required, _assert_string],
            "name": [_assert_required, _assert_string],
            "sample_rate": [_assert_required],
            "filter_string": [],
        },
    )

    filter_string = request_json.get("filter_string")
    if filter_string is not None and not isinstance(filter_string, str):
        raise MlflowException(
            f"Invalid value {filter_string!r} for parameter 'filter_string' supplied: "
            f"Value was of type '{type(filter_string).__name__}'. "
            "Expected type 'str' or None.",
            error_code=INVALID_PARAMETER_VALUE,
        )

    # The schema only asserts that `sample_rate` is present. A non-numeric
    # value (e.g. "abc", a list, or None) previously escaped as an unhandled
    # ValueError/TypeError from float() and surfaced as a 500; convert it
    # explicitly and report an invalid-parameter error instead.
    raw_sample_rate = request_json["sample_rate"]
    try:
        sample_rate = float(raw_sample_rate)
    except (TypeError, ValueError) as e:
        raise MlflowException(
            f"Invalid value {raw_sample_rate!r} for parameter 'sample_rate' supplied: "
            "Expected a numeric value.",
            error_code=INVALID_PARAMETER_VALUE,
        ) from e

    config = _get_tracking_store().upsert_online_scoring_config(
        experiment_id=request_json["experiment_id"],
        scorer_name=request_json["name"],
        sample_rate=sample_rate,
        filter_string=filter_string,
    )

    response = Response(mimetype="application/json")
    response.set_data(json.dumps({"config": config.to_dict()}))
    return response
4888  
4889  
4890  # =============================================================================
4891  # Secrets Management Handlers
4892  # =============================================================================
4893  
4894  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _create_gateway_secret():
    """Create a gateway secret and return its stored representation."""
    req = _get_request_message(
        CreateGatewaySecret(),
        schema={
            "secret_name": [_assert_required, _assert_string],
            "secret_value": [_assert_secret_value],
            "provider": [_assert_string],
            "created_by": [_assert_string],
        },
    )
    # Proto maps are never None; an empty map means the field was omitted.
    auth_config = dict(req.auth_config) or None

    store = _get_tracking_store()
    secret = store.create_gateway_secret(
        secret_name=req.secret_name,
        secret_value=dict(req.secret_value),
        provider=req.provider or None,
        auth_config=auth_config,
        created_by=req.created_by or None,
    )
    message = CreateGatewaySecret.Response()
    message.secret.CopyFrom(secret.to_proto())
    return _wrap_response(message)
4920  
4921  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _get_gateway_secret_info():
    """Fetch metadata for a gateway secret by its ID."""
    req = _get_request_message(
        GetGatewaySecretInfo(),
        schema={
            "secret_id": [_assert_required, _assert_string],
        },
    )
    info = _get_tracking_store().get_secret_info(req.secret_id)
    message = GetGatewaySecretInfo.Response()
    message.secret.CopyFrom(info.to_proto())
    return _wrap_response(message)
4935  
4936  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _update_gateway_secret():
    """Update a gateway secret's value and/or auth config."""
    req = _get_request_message(
        UpdateGatewaySecret(),
        schema={
            "secret_id": [_assert_required, _assert_string],
            "updated_by": [_assert_string],
        },
    )
    # Proto maps are never None; empty maps mean "no change" for the
    # corresponding field and are normalized to None.
    auth_config = dict(req.auth_config) or None
    secret_value = dict(req.secret_value) or None

    updated = _get_tracking_store().update_gateway_secret(
        secret_id=req.secret_id,
        secret_value=secret_value,
        auth_config=auth_config,
        updated_by=req.updated_by or None,
    )
    message = UpdateGatewaySecret.Response()
    message.secret.CopyFrom(updated.to_proto())
    return _wrap_response(message)
4962  
4963  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _delete_gateway_secret():
    """Delete a gateway secret by its ID."""
    req = _get_request_message(
        DeleteGatewaySecret(),
        schema={
            "secret_id": [_assert_required, _assert_string],
        },
    )
    store = _get_tracking_store()
    store.delete_gateway_secret(req.secret_id)
    return _wrap_response(DeleteGatewaySecret.Response())
4976  
4977  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _list_gateway_secrets():
    """List gateway secret infos, optionally filtered by provider."""
    req = _get_request_message(
        ListGatewaySecretInfos(),
        schema={
            "provider": [_assert_string],
        },
    )
    infos = _get_tracking_store().list_secret_infos(provider=req.provider or None)
    message = ListGatewaySecretInfos.Response()
    for info in infos:
        message.secrets.append(info.to_proto())
    return _wrap_response(message)
4993  
4994  
4995  # =============================================================================
4996  # Endpoints Management Handlers
4997  # =============================================================================
4998  
4999  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _create_gateway_endpoint():
    """Create a gateway endpoint with one or more model configs.

    Optional proto fields (routing strategy, fallback config, experiment id,
    usage tracking) are detected via ``HasField`` so that an unset field is
    distinguishable from a zero/empty value.
    """
    request_message = _get_request_message(
        CreateGatewayEndpoint(),
        schema={
            "name": [_assert_required, _assert_string],
            "created_by": [_assert_string],
            "model_configs": [_assert_required],
            "routing_strategy": [_assert_string],
        },
    )
    # Endpoint names are restricted to a safe character set.
    if request_message.name and not is_valid_endpoint_name(request_message.name):
        raise MlflowException.invalid_parameter_value(
            f"Invalid endpoint name '{request_message.name}'. "
            "Name can only contain letters, numbers, underscores, hyphens, and dots."
        )
    # Convert proto fallback_config to entity FallbackConfig
    fallback_config = None
    if request_message.HasField("fallback_config"):
        # Inner fields are themselves optional; unset ones map to None.
        fallback_config = FallbackConfig(
            strategy=FallbackStrategy.from_proto(request_message.fallback_config.strategy)
            if request_message.fallback_config.HasField("strategy")
            else None,
            max_attempts=request_message.fallback_config.max_attempts
            if request_message.fallback_config.HasField("max_attempts")
            else None,
        )

    model_configs = [
        GatewayEndpointModelConfig.from_proto(config) for config in request_message.model_configs
    ]

    # Determine experiment_id and usage_tracking
    experiment_id = (
        request_message.experiment_id if request_message.HasField("experiment_id") else None
    )
    # NOTE: usage tracking defaults to True when the client did not set it
    # (unlike the update handler, where an unset field means "no change").
    usage_tracking = (
        request_message.usage_tracking if request_message.HasField("usage_tracking") else True
    )

    endpoint = _get_tracking_store().create_gateway_endpoint(
        name=request_message.name or None,
        model_configs=model_configs,
        created_by=request_message.created_by or None,
        routing_strategy=RoutingStrategyEntity.from_proto(request_message.routing_strategy)
        if request_message.HasField("routing_strategy")
        else None,
        fallback_config=fallback_config,
        experiment_id=experiment_id,
        usage_tracking=usage_tracking,
    )
    response_message = CreateGatewayEndpoint.Response()
    response_message.endpoint.CopyFrom(endpoint.to_proto())
    return _wrap_response(response_message)
5055  
5056  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _get_gateway_endpoint():
    """Fetch a gateway endpoint, addressed by ID or by name."""
    req = _get_request_message(
        GetGatewayEndpoint(),
        schema={
            "endpoint_id": [_assert_string],
            "name": [_assert_string],
        },
    )
    # Empty proto strings are normalized to None; the store resolves
    # whichever identifier was provided.
    endpoint = _get_tracking_store().get_gateway_endpoint(
        endpoint_id=req.endpoint_id or None,
        name=req.name or None,
    )
    message = GetGatewayEndpoint.Response()
    message.endpoint.CopyFrom(endpoint.to_proto())
    return _wrap_response(message)
5074  
5075  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _update_gateway_endpoint():
    """Update an existing gateway endpoint.

    Optional fields are detected via ``HasField`` / emptiness and passed to
    the store as ``None`` when unset — presumably meaning "leave unchanged"
    (TODO confirm against the store implementation).
    """
    request_message = _get_request_message(
        UpdateGatewayEndpoint(),
        schema={
            "endpoint_id": [_assert_required, _assert_string],
            "name": [_assert_string],
            "updated_by": [_assert_string],
            "routing_strategy": [_assert_string],
        },
    )
    # Endpoint names are restricted to a safe character set.
    if request_message.name and not is_valid_endpoint_name(request_message.name):
        raise MlflowException.invalid_parameter_value(
            f"Invalid endpoint name '{request_message.name}'. "
            "Name can only contain letters, numbers, underscores, hyphens, and dots."
        )
    # Convert proto fallback_config to entity FallbackConfig
    fallback_config = None
    if request_message.HasField("fallback_config"):
        # Inner fields are themselves optional; unset ones map to None.
        fallback_config = FallbackConfig(
            strategy=FallbackStrategy.from_proto(request_message.fallback_config.strategy)
            if request_message.fallback_config.HasField("strategy")
            else None,
            max_attempts=request_message.fallback_config.max_attempts
            if request_message.fallback_config.HasField("max_attempts")
            else None,
        )

    # Convert proto model_configs to entity GatewayEndpointModelConfig list
    model_configs = None
    if request_message.model_configs:
        model_configs = [
            GatewayEndpointModelConfig.from_proto(config)
            for config in request_message.model_configs
        ]

    # Determine experiment_id and usage_tracking
    experiment_id = (
        request_message.experiment_id if request_message.HasField("experiment_id") else None
    )
    # NOTE: unset usage_tracking maps to None here (no change), whereas the
    # create handler defaults it to True.
    usage_tracking = (
        request_message.usage_tracking if request_message.HasField("usage_tracking") else None
    )

    endpoint = _get_tracking_store().update_gateway_endpoint(
        endpoint_id=request_message.endpoint_id,
        name=request_message.name or None,
        model_configs=model_configs,
        updated_by=request_message.updated_by or None,
        routing_strategy=RoutingStrategyEntity.from_proto(request_message.routing_strategy)
        if request_message.HasField("routing_strategy")
        else None,
        fallback_config=fallback_config,
        experiment_id=experiment_id,
        usage_tracking=usage_tracking,
    )
    response_message = UpdateGatewayEndpoint.Response()
    response_message.endpoint.CopyFrom(endpoint.to_proto())
    return _wrap_response(response_message)
5136  
5137  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _delete_gateway_endpoint():
    """Delete a gateway endpoint by its ID."""
    req = _get_request_message(
        DeleteGatewayEndpoint(),
        schema={
            "endpoint_id": [_assert_required, _assert_string],
        },
    )
    store = _get_tracking_store()
    store.delete_gateway_endpoint(req.endpoint_id)
    return _wrap_response(DeleteGatewayEndpoint.Response())
5150  
5151  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _list_gateway_endpoints():
    """List gateway endpoints, optionally filtered by provider."""
    req = _get_request_message(
        ListGatewayEndpoints(),
        schema={
            "provider": [_assert_string],
        },
    )
    endpoints = _get_tracking_store().list_gateway_endpoints(provider=req.provider or None)
    message = ListGatewayEndpoints.Response()
    for endpoint in endpoints:
        message.endpoints.append(endpoint.to_proto())
    return _wrap_response(message)
5167  
5168  
5169  # =============================================================================
5170  # Model Definitions Management Handlers
5171  # =============================================================================
5172  
5173  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _create_gateway_model_definition():
    """Create a gateway model definition bound to a secret and provider."""
    req = _get_request_message(
        CreateGatewayModelDefinition(),
        schema={
            "name": [_assert_required, _assert_string],
            "secret_id": [_assert_required, _assert_string],
            "provider": [_assert_required, _assert_string],
            "model_name": [_assert_required, _assert_string],
            "created_by": [_assert_string],
        },
    )
    store = _get_tracking_store()
    definition = store.create_gateway_model_definition(
        name=req.name,
        secret_id=req.secret_id,
        provider=req.provider,
        model_name=req.model_name,
        created_by=req.created_by or None,
    )
    message = CreateGatewayModelDefinition.Response()
    message.model_definition.CopyFrom(definition.to_proto())
    return _wrap_response(message)
5197  
5198  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _get_gateway_model_definition():
    """Fetch a gateway model definition by its ID."""
    req = _get_request_message(
        GetGatewayModelDefinition(),
        schema={
            "model_definition_id": [_assert_required, _assert_string],
        },
    )
    definition = _get_tracking_store().get_gateway_model_definition(req.model_definition_id)
    message = GetGatewayModelDefinition.Response()
    message.model_definition.CopyFrom(definition.to_proto())
    return _wrap_response(message)
5214  
5215  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _list_gateway_model_definitions():
    """List gateway model definitions, optionally filtered by provider and/or secret."""
    req = _get_request_message(
        ListGatewayModelDefinitions(),
        schema={
            "provider": [_assert_string],
            "secret_id": [_assert_string],
        },
    )
    definitions = _get_tracking_store().list_gateway_model_definitions(
        provider=req.provider or None,
        secret_id=req.secret_id or None,
    )
    message = ListGatewayModelDefinitions.Response()
    for definition in definitions:
        message.model_definitions.append(definition.to_proto())
    return _wrap_response(message)
5233  
5234  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _update_gateway_model_definition():
    """Update a gateway model definition; empty fields are passed as None."""
    req = _get_request_message(
        UpdateGatewayModelDefinition(),
        schema={
            "model_definition_id": [_assert_required, _assert_string],
            "name": [_assert_string],
            "secret_id": [_assert_string],
            "model_name": [_assert_string],
            "updated_by": [_assert_string],
            "provider": [_assert_string],
        },
    )
    store = _get_tracking_store()
    definition = store.update_gateway_model_definition(
        model_definition_id=req.model_definition_id,
        name=req.name or None,
        secret_id=req.secret_id or None,
        model_name=req.model_name or None,
        updated_by=req.updated_by or None,
        provider=req.provider or None,
    )
    message = UpdateGatewayModelDefinition.Response()
    message.model_definition.CopyFrom(definition.to_proto())
    return _wrap_response(message)
5260  
5261  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _delete_gateway_model_definition():
    """Delete a gateway model definition by its ID."""
    req = _get_request_message(
        DeleteGatewayModelDefinition(),
        schema={
            "model_definition_id": [_assert_required, _assert_string],
        },
    )
    store = _get_tracking_store()
    store.delete_gateway_model_definition(req.model_definition_id)
    return _wrap_response(DeleteGatewayModelDefinition.Response())
5274  
5275  
5276  # =============================================================================
5277  # Endpoint Model Mappings Management Handlers
5278  # =============================================================================
5279  
5280  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _attach_model_to_gateway_endpoint():
    """Attach a model config to a gateway endpoint, returning the mapping."""
    req = _get_request_message(
        AttachModelToGatewayEndpoint(),
        schema={
            "endpoint_id": [_assert_required, _assert_string],
            "model_config": [_assert_required],
            "created_by": [_assert_string],
        },
    )
    config_entity = GatewayEndpointModelConfig.from_proto(req.model_config)
    mapping = _get_tracking_store().attach_model_to_endpoint(
        endpoint_id=req.endpoint_id,
        model_config=config_entity,
        created_by=req.created_by or None,
    )
    message = AttachModelToGatewayEndpoint.Response()
    message.mapping.CopyFrom(mapping.to_proto())
    return _wrap_response(message)
5303  
5304  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _detach_model_from_gateway_endpoint():
    """Detach a model definition from a gateway endpoint."""
    req = _get_request_message(
        DetachModelFromGatewayEndpoint(),
        schema={
            "endpoint_id": [_assert_required, _assert_string],
            "model_definition_id": [_assert_required, _assert_string],
        },
    )
    store = _get_tracking_store()
    store.detach_model_from_endpoint(
        endpoint_id=req.endpoint_id,
        model_definition_id=req.model_definition_id,
    )
    return _wrap_response(DetachModelFromGatewayEndpoint.Response())
5321  
5322  
5323  # =============================================================================
5324  # Endpoint Bindings Management Handlers
5325  # =============================================================================
5326  
5327  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _create_gateway_endpoint_binding():
    """Bind a gateway endpoint to a resource (e.g. an experiment or model).

    The ``resource_type`` string is parsed into a ``GatewayResourceType`` enum
    before it reaches the store.
    """
    req = _get_request_message(
        CreateGatewayEndpointBinding(),
        schema={
            "endpoint_id": [_assert_required, _assert_string],
            "resource_type": [_assert_required, _assert_string],
            "resource_id": [_assert_required, _assert_string],
            "created_by": [_assert_string],
        },
    )
    binding = _get_tracking_store().create_endpoint_binding(
        endpoint_id=req.endpoint_id,
        resource_type=GatewayResourceType(req.resource_type),
        resource_id=req.resource_id,
        created_by=req.created_by or None,
    )
    response = CreateGatewayEndpointBinding.Response()
    response.binding.CopyFrom(binding.to_proto())
    return _wrap_response(response)
5349  
5350  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _delete_gateway_endpoint_binding():
    """Remove a binding between a gateway endpoint and a resource."""
    req = _get_request_message(
        DeleteGatewayEndpointBinding(),
        schema={
            "endpoint_id": [_assert_required, _assert_string],
            "resource_type": [_assert_required, _assert_string],
            "resource_id": [_assert_required, _assert_string],
        },
    )
    # NOTE(review): resource_type is forwarded as a raw string here, while the
    # create handler converts it to GatewayResourceType — confirm the store
    # accepts both representations.
    _get_tracking_store().delete_endpoint_binding(
        endpoint_id=req.endpoint_id,
        resource_type=req.resource_type,
        resource_id=req.resource_id,
    )
    return _wrap_response(DeleteGatewayEndpointBinding.Response())
5369  
5370  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _list_gateway_endpoint_bindings():
    """List endpoint bindings, optionally filtered by endpoint/resource fields.

    Every filter is optional; empty proto strings are treated as "no filter".
    """
    req = _get_request_message(
        ListGatewayEndpointBindings(),
        schema={
            "endpoint_id": [_assert_string],
            "resource_type": [_assert_string],
            "resource_id": [_assert_string],
        },
    )
    bindings = _get_tracking_store().list_endpoint_bindings(
        endpoint_id=req.endpoint_id or None,
        resource_type=req.resource_type or None,
        resource_id=req.resource_id or None,
    )
    response = ListGatewayEndpointBindings.Response()
    response.bindings.extend(binding.to_proto() for binding in bindings)
    return _wrap_response(response)
5390  
5391  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _set_gateway_endpoint_tag():
    """Set (create or overwrite) a tag on a gateway endpoint.

    Returns an empty ``SetGatewayEndpointTag.Response`` serialized as JSON.
    """
    request_message = _get_request_message(
        SetGatewayEndpointTag(),
        schema={
            "endpoint_id": [_assert_required, _assert_string],
            "key": [_assert_required, _assert_string],
            "value": [_assert_string],
        },
    )
    tag = GatewayEndpointTag(request_message.key, request_message.value)
    _get_tracking_store().set_gateway_endpoint_tag(request_message.endpoint_id, tag)
    # Use the shared _wrap_response helper (JSON Response) for consistency with
    # the other handlers in this module instead of building the Response by hand.
    return _wrap_response(SetGatewayEndpointTag.Response())
5409  
5410  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _delete_gateway_endpoint_tag():
    """Delete a tag from a gateway endpoint by key.

    Returns an empty ``DeleteGatewayEndpointTag.Response`` serialized as JSON.
    """
    request_message = _get_request_message(
        DeleteGatewayEndpointTag(),
        schema={
            "endpoint_id": [_assert_required, _assert_string],
            "key": [_assert_required, _assert_string],
        },
    )
    _get_tracking_store().delete_gateway_endpoint_tag(
        request_message.endpoint_id, request_message.key
    )
    # Use the shared _wrap_response helper (JSON Response) for consistency with
    # the other handlers in this module instead of building the Response by hand.
    return _wrap_response(DeleteGatewayEndpointTag.Response())
5428  
5429  
5430  # =============================================================================
5431  # Budget Policy Management Handlers
5432  # =============================================================================
5433  
5434  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _create_budget_policy():
    """Create a gateway budget policy after validating all enum-typed fields.

    Invalidates the budget tracker cache and refreshes cached policies so the
    new policy takes effect immediately.
    """
    req = _get_request_message(
        CreateGatewayBudgetPolicy(),
        schema={
            "budget_unit": [_assert_required],
            "budget_amount": [_assert_required],
            "duration": [_assert_required],
            "target_scope": [_assert_required],
            "budget_action": [_assert_required],
            "created_by": [_assert_string],
        },
    )

    def _parse_enum(enum_cls, raw, field_name):
        # from_proto returns None for proto values with no entity counterpart.
        parsed = enum_cls.from_proto(raw)
        if parsed is None:
            raise MlflowException(
                message=f"Invalid {field_name}: {raw}",
                error_code=INVALID_PARAMETER_VALUE,
            )
        return parsed

    budget_unit = _parse_enum(BudgetUnit, req.budget_unit, "budget_unit")
    duration_unit = _parse_enum(BudgetDurationUnit, req.duration.unit, "duration.unit")
    if req.duration.value <= 0:
        raise MlflowException(
            message=f"duration.value must be a positive integer, got {req.duration.value}",
            error_code=INVALID_PARAMETER_VALUE,
        )
    target_scope = _parse_enum(BudgetTargetScope, req.target_scope, "target_scope")
    budget_action = _parse_enum(BudgetAction, req.budget_action, "budget_action")

    store = _get_tracking_store()
    policy = store.create_budget_policy(
        budget_unit=budget_unit,
        budget_amount=req.budget_amount,
        duration=BudgetDuration(unit=duration_unit, value=req.duration.value),
        target_scope=target_scope,
        budget_action=budget_action,
        created_by=req.created_by or None,
    )
    get_budget_tracker().invalidate()
    maybe_refresh_budget_policies(store)
    response = CreateGatewayBudgetPolicy.Response()
    response.budget_policy.CopyFrom(policy.to_proto())
    return _wrap_response(response)
5493  
5494  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _get_budget_policy():
    """Fetch a single budget policy by its ID."""
    req = _get_request_message(
        GetGatewayBudgetPolicy(),
        schema={
            "budget_policy_id": [_assert_required, _assert_string],
        },
    )
    policy = _get_tracking_store().get_budget_policy(
        budget_policy_id=req.budget_policy_id,
    )
    response = GetGatewayBudgetPolicy.Response()
    response.budget_policy.CopyFrom(policy.to_proto())
    return _wrap_response(response)
5510  
5511  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _update_budget_policy():
    """Apply a partial update to a budget policy.

    Each optional field is only forwarded to the store when present in the
    request (via proto ``HasField``); absent fields are passed as None so the
    store leaves them unchanged. Invalidates and refreshes cached budget state.
    """
    req = _get_request_message(
        UpdateGatewayBudgetPolicy(),
        schema={
            "budget_policy_id": [_assert_required, _assert_string],
            "updated_by": [_assert_string],
        },
    )

    def _parse_enum(enum_cls, raw, field_name):
        # from_proto returns None for proto values with no entity counterpart.
        parsed = enum_cls.from_proto(raw)
        if parsed is None:
            raise MlflowException(
                message=f"Invalid {field_name}: {raw}",
                error_code=INVALID_PARAMETER_VALUE,
            )
        return parsed

    budget_unit = (
        _parse_enum(BudgetUnit, req.budget_unit, "budget_unit")
        if req.HasField("budget_unit")
        else None
    )

    duration = None
    if req.HasField("duration"):
        duration_unit = _parse_enum(BudgetDurationUnit, req.duration.unit, "duration.unit")
        if req.duration.value <= 0:
            raise MlflowException(
                message=f"duration.value must be a positive integer, got {req.duration.value}",
                error_code=INVALID_PARAMETER_VALUE,
            )
        duration = BudgetDuration(unit=duration_unit, value=req.duration.value)

    target_scope = (
        _parse_enum(BudgetTargetScope, req.target_scope, "target_scope")
        if req.HasField("target_scope")
        else None
    )
    budget_action = (
        _parse_enum(BudgetAction, req.budget_action, "budget_action")
        if req.HasField("budget_action")
        else None
    )
    budget_amount = req.budget_amount if req.HasField("budget_amount") else None

    store = _get_tracking_store()
    policy = store.update_budget_policy(
        budget_policy_id=req.budget_policy_id,
        budget_unit=budget_unit,
        budget_amount=budget_amount,
        duration=duration,
        target_scope=target_scope,
        budget_action=budget_action,
        updated_by=req.updated_by or None,
    )
    get_budget_tracker().invalidate()
    maybe_refresh_budget_policies(store)
    response = UpdateGatewayBudgetPolicy.Response()
    response.budget_policy.CopyFrom(policy.to_proto())
    return _wrap_response(response)
5578  
5579  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _delete_budget_policy():
    """Delete a budget policy and refresh cached budget tracking state."""
    req = _get_request_message(
        DeleteGatewayBudgetPolicy(),
        schema={
            "budget_policy_id": [_assert_required, _assert_string],
        },
    )
    store = _get_tracking_store()
    store.delete_budget_policy(req.budget_policy_id)
    # Drop cached windows and re-sync policies so the deletion is visible at once.
    get_budget_tracker().invalidate()
    maybe_refresh_budget_policies(store)
    return _wrap_response(DeleteGatewayBudgetPolicy.Response())
5595  
5596  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _list_budget_policies():
    """List budget policies with pagination (max_results / page_token)."""
    req = _get_request_message(
        ListGatewayBudgetPolicies(),
        schema={
            "max_results": [_assert_intlike],
            "page_token": [_assert_string],
        },
    )
    policies = _get_tracking_store().list_budget_policies(
        max_results=req.max_results or SEARCH_MAX_RESULTS_DEFAULT,
        page_token=req.page_token or None,
    )
    response = ListGatewayBudgetPolicies.Response()
    response.budget_policies.extend(policy.to_proto() for policy in policies)
    # The store returns a paged list; propagate its continuation token if any.
    if policies.token:
        response.next_page_token = policies.token
    return _wrap_response(response)
5616  
5617  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _list_budget_windows():
    """Report every active budget tracking window with its current spend."""
    _get_request_message(ListGatewayBudgetWindows())
    store = _get_tracking_store()
    maybe_refresh_budget_policies(store)
    response = ListGatewayBudgetWindows.Response()
    for window in get_budget_tracker().get_all_windows():
        response.windows.append(
            ListGatewayBudgetWindows.BudgetWindow(
                budget_policy_id=window.policy.budget_policy_id,
                # datetime boundaries are reported as epoch milliseconds.
                window_start_ms=int(window.window_start.timestamp() * 1000),
                window_end_ms=int(window.window_end.timestamp() * 1000),
                current_spend=window.cumulative_spend,
            )
        )
    return _wrap_response(response)
5635  
5636  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _create_gateway_guardrail():
    """Create a guardrail definition backed by a registered scorer version."""
    req = _get_request_message(
        CreateGatewayGuardrail(),
        schema={
            "name": [_assert_required, _assert_string],
            "scorer_id": [_assert_required, _assert_string],
            "scorer_version": [_assert_required, _assert_intlike],
            "stage": [_assert_required],
            "action": [_assert_required],
            "action_endpoint_id": [_assert_string],
        },
    )
    from mlflow.entities.gateway_guardrail import GuardrailAction, GuardrailStage

    def _required_enum(parsed, field_name, raw):
        # from_proto yields None for unrecognized proto enum values.
        if parsed is None:
            raise MlflowException(
                message=f"Invalid {field_name}: {raw}",
                error_code=INVALID_PARAMETER_VALUE,
            )
        return parsed

    stage = _required_enum(GuardrailStage.from_proto(req.stage), "stage", req.stage)
    action = _required_enum(GuardrailAction.from_proto(req.action), "action", req.action)
    guardrail = _get_tracking_store().create_gateway_guardrail(
        name=req.name,
        scorer_id=req.scorer_id,
        scorer_version=req.scorer_version,
        stage=stage,
        action=action,
        action_endpoint_id=req.action_endpoint_id or None,
        created_by=_get_user(),
    )
    response = CreateGatewayGuardrail.Response()
    response.guardrail.CopyFrom(guardrail.to_proto())
    return _wrap_response(response)
5677  
5678  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _get_gateway_guardrail():
    """Fetch a single guardrail by its ID."""
    req = _get_request_message(
        GetGatewayGuardrail(),
        schema={"guardrail_id": [_assert_required, _assert_string]},
    )
    guardrail = _get_tracking_store().get_gateway_guardrail(
        guardrail_id=req.guardrail_id,
    )
    response = GetGatewayGuardrail.Response()
    response.guardrail.CopyFrom(guardrail.to_proto())
    return _wrap_response(response)
5692  
5693  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _delete_gateway_guardrail():
    """Delete a guardrail by its ID and return an empty response."""
    req = _get_request_message(
        DeleteGatewayGuardrail(),
        schema={"guardrail_id": [_assert_required, _assert_string]},
    )
    _get_tracking_store().delete_gateway_guardrail(req.guardrail_id)
    response = DeleteGatewayGuardrail.Response()
    return _wrap_response(response)
5703  
5704  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _list_gateway_guardrails():
    """List guardrails with pagination (max_results / page_token)."""
    req = _get_request_message(
        ListGatewayGuardrails(),
        schema={
            "max_results": [_assert_intlike],
            "page_token": [_assert_string],
        },
    )
    guardrails = _get_tracking_store().list_gateway_guardrails(
        max_results=req.max_results or SEARCH_MAX_RESULTS_DEFAULT,
        page_token=req.page_token or None,
    )
    response = ListGatewayGuardrails.Response()
    response.guardrails.extend(guardrail.to_proto() for guardrail in guardrails)
    # Propagate the store's continuation token when more pages exist.
    if guardrails.token:
        response.next_page_token = guardrails.token
    return _wrap_response(response)
5724  
5725  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _add_guardrail_to_endpoint():
    """Attach a guardrail to an endpoint, optionally at an explicit position."""
    req = _get_request_message(
        AddGuardrailToEndpoint(),
        schema={
            "endpoint_id": [_assert_required, _assert_string],
            "guardrail_id": [_assert_required, _assert_string],
            "execution_order": [_assert_intlike],
        },
    )
    # Only forward execution_order when the client explicitly set it, so the
    # store can distinguish "unset" from a provided value.
    execution_order = req.execution_order if req.HasField("execution_order") else None
    config = _get_tracking_store().add_guardrail_to_endpoint(
        endpoint_id=req.endpoint_id,
        guardrail_id=req.guardrail_id,
        execution_order=execution_order,
        created_by=_get_user(),
    )
    response = AddGuardrailToEndpoint.Response()
    response.config.CopyFrom(config.to_proto())
    return _wrap_response(response)
5748  
5749  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _remove_guardrail_from_endpoint():
    """Detach a guardrail from an endpoint and return an empty response."""
    req = _get_request_message(
        RemoveGuardrailFromEndpoint(),
        schema={
            "endpoint_id": [_assert_required, _assert_string],
            "guardrail_id": [_assert_required, _assert_string],
        },
    )
    _get_tracking_store().remove_guardrail_from_endpoint(
        endpoint_id=req.endpoint_id,
        guardrail_id=req.guardrail_id,
    )
    response = RemoveGuardrailFromEndpoint.Response()
    return _wrap_response(response)
5765  
5766  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _list_endpoint_guardrail_configs():
    """List guardrail configurations attached to a given endpoint."""
    req = _get_request_message(
        ListEndpointGuardrailConfigs(),
        schema={"endpoint_id": [_assert_required, _assert_string]},
    )
    configs = _get_tracking_store().list_endpoint_guardrail_configs(
        endpoint_id=req.endpoint_id,
    )
    response = ListEndpointGuardrailConfigs.Response()
    response.configs.extend(config.to_proto() for config in configs)
    return _wrap_response(response)
5780  
5781  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _update_endpoint_guardrail_config():
    """Update the configuration of a guardrail attached to an endpoint.

    ``execution_order`` is only forwarded when explicitly present in the
    request (proto ``HasField``), so the store can distinguish "unset" from a
    provided value.
    """
    request_message = _get_request_message(
        UpdateEndpointGuardrailConfig(),
        schema={
            "endpoint_id": [_assert_required, _assert_string],
            "guardrail_id": [_assert_required, _assert_string],
        },
    )
    kwargs = {
        "endpoint_id": request_message.endpoint_id,
        "guardrail_id": request_message.guardrail_id,
    }
    if request_message.HasField("execution_order"):
        kwargs["execution_order"] = request_message.execution_order
    config = _get_tracking_store().update_endpoint_guardrail_config(**kwargs)
    response_message = UpdateEndpointGuardrailConfig.Response()
    response_message.config.CopyFrom(config.to_proto())
    return _wrap_response(response_message)
5801  
5802  
@catch_mlflow_exception
def _get_server_info():
    """Return basic server configuration: tracking-store type and whether
    workspaces are enabled."""
    # Imported lazily so these store modules only load when this endpoint is hit.
    from mlflow.store.tracking.file_store import FileStore
    from mlflow.store.tracking.sqlalchemy_store import SqlAlchemyStore

    store = _get_tracking_store()
    store_type = None
    if isinstance(store, FileStore):
        store_type = "FileStore"
    elif isinstance(store, SqlAlchemyStore):
        store_type = "SqlStore"
    return jsonify({
        "store_type": store_type,
        "workspaces_enabled": MLFLOW_ENABLE_WORKSPACES.get(),
    })
5820  
5821  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _list_supported_providers():
    """Return the sorted JSON list of gateway providers this server supports.

    Raises:
        MlflowException: if the provider registry cannot be imported.
    """
    try:
        providers = get_all_providers()
    except ImportError as e:
        # Chain the original error so the root cause is preserved in tracebacks.
        raise MlflowException(str(e), error_code=INVALID_PARAMETER_VALUE) from e
    return jsonify({"providers": sorted(providers)})
5830  
5831  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _list_supported_models():
    """Return the JSON list of gateway-supported models, optionally filtered
    by the ``provider`` query parameter.

    Raises:
        MlflowException: if the model registry helpers cannot be imported.
    """
    # Reading the query parameter cannot raise ImportError; keep the try minimal.
    provider_filter = request.args.get("provider")
    try:
        models = get_models(provider=provider_filter)
    except ImportError as e:
        # Chain the original error so the root cause is preserved in tracebacks.
        raise MlflowException(str(e), error_code=INVALID_PARAMETER_VALUE) from e
    return jsonify({"models": models})
5841  
5842  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _get_provider_config():
    """Return the JSON configuration schema for the provider named in the
    ``provider`` query parameter.

    Raises:
        MlflowException: if the provider is unknown (ValueError) or the
            gateway helpers cannot be imported (ImportError).
    """
    provider = request.args.get("provider")
    try:
        config = get_provider_config_response(provider)
    except (ImportError, ValueError) as e:
        # Chain the original error so the root cause is preserved in tracebacks.
        raise MlflowException(str(e), error_code=INVALID_PARAMETER_VALUE) from e
    return jsonify(config)
5852  
5853  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _get_secrets_config():
    """Report that secrets storage is available and whether the default KEK
    passphrase is still in use."""
    # The default passphrase is in effect whenever the env var is unset or empty.
    custom_passphrase = os.environ.get(CRYPTO_KEK_PASSPHRASE_ENV_VAR)
    return jsonify({
        "secrets_available": True,
        "using_default_passphrase": not custom_passphrase,
    })
5862  
5863  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _invoke_scorer_handler():
    """
    Invoke a scorer on traces asynchronously.

    This is a UI-only AJAX endpoint for invoking scorers from the frontend.
    """
    _validate_content_type(request, ["application/json"])

    payload = request.json
    experiment_id = payload.get("experiment_id")
    serialized_scorer = payload.get("serialized_scorer")
    trace_ids = payload.get("trace_ids", [])
    log_assessments = payload.get("log_assessments", False)

    # Reject requests missing any of the required inputs up front.
    if not experiment_id:
        raise MlflowException(
            "Missing required parameter: experiment_id",
            error_code=INVALID_PARAMETER_VALUE,
        )
    if not serialized_scorer:
        raise MlflowException(
            "Missing required parameter: serialized_scorer",
            error_code=INVALID_PARAMETER_VALUE,
        )
    if not trace_ids:
        raise MlflowException(
            "Please select at least one trace to evaluate.",
            error_code=INVALID_PARAMETER_VALUE,
        )

    from mlflow.genai.scorers.base import Scorer
    from mlflow.genai.scorers.job import get_trace_batches_for_scorer, invoke_scorer_job
    from mlflow.server.jobs import submit_job

    scorer = Scorer.model_validate_json(serialized_scorer)

    tracking_store = _get_tracking_store()
    batches = get_trace_batches_for_scorer(trace_ids, scorer, tracking_store)

    # Extract the authenticated username so that job subprocesses can make
    # gateway requests authorised as the original user (not the admin).
    username = request.authorization.username if request.authorization else None

    submitted_jobs = []
    for batch in batches:
        job = submit_job(
            function=invoke_scorer_job,
            params={
                "experiment_id": experiment_id,
                "serialized_scorer": serialized_scorer,
                "trace_ids": batch,
                "log_assessments": log_assessments,
                "username": username,
            },
        )
        submitted_jobs.append({"job_id": job.job_id, "trace_ids": batch})

    return jsonify({"jobs": submitted_jobs})
5924  
5925  
def _get_rest_path(base_path, version=2):
    """Return the REST API route for *base_path* under ``/api/<version>.0``."""
    route = "/api/{}.0{}".format(version, base_path)
    return _add_static_prefix(route)
5928  
5929  
def _get_ajax_path(base_path, version=2):
    """Return the AJAX route for *base_path* under ``/ajax-api/<version>.0``."""
    route = "/ajax-api/{}.0{}".format(version, base_path)
    return _add_static_prefix(route)
5932  
5933  
def _add_static_prefix(route: str) -> str:
    """Prepend the configured static URL prefix (if any) to *route*."""
    prefix = os.environ.get(STATIC_PREFIX_ENV_VAR)
    if not prefix:
        return route
    # Drop trailing slashes so the prefix joins cleanly with the route.
    return prefix.rstrip("/") + route
5938  
5939  
def _get_paths(base_path, version=2):
    """
    A service endpoints base path is typically something like /mlflow/experiment.
    We should register paths like /api/2.0/mlflow/experiment and
    /ajax-api/2.0/mlflow/experiment in the Flask router.
    """
    flask_path = _convert_path_parameter_to_flask_format(base_path)
    return [
        _get_rest_path(flask_path, version),
        _get_ajax_path(flask_path, version),
    ]
5948  
5949  
5950  def _convert_path_parameter_to_flask_format(path):
5951      """
5952      Converts path parameter format to Flask compatible format.
5953  
5954      Some protobuf endpoint paths contain parameters like /mlflow/trace/{request_id}.
5955      This can be interpreted correctly by gRPC framework like Armeria, but Flask does
5956      not understand it. Instead, we need to specify it with a different format,
5957      like /mlflow/trace/<request_id>.
5958      """
5959      # Handle simple parameters like {trace_id}
5960      path = re.sub(r"{(\w+)}", r"<\1>", path)
5961  
5962      # Handle Databricks-specific syntax like {assessment.trace_id} -> <trace_id>
5963      # This is needed because Databricks can extract trace_id from request body,
5964      # but Flask needs it in the URL path
5965      return re.sub(r"{assessment\.trace_id}", r"<trace_id>", path)
5966  
5967  
def get_handler(request_class):
    """Look up the Flask handler registered for a protobuf request type.

    Args:
        request_class: The type of protobuf message

    Returns:
        The registered handler, or ``_not_implemented`` when none exists.
    """
    try:
        return HANDLERS[request_class]
    except KeyError:
        return _not_implemented
5974  
5975  
def get_service_endpoints(service, get_handler):
    """Build (path, handler, methods) route tuples for every RPC endpoint
    declared on *service*'s protobuf descriptor."""
    routes = []
    for method_descriptor in service.DESCRIPTOR.methods:
        rpc_options = method_descriptor.GetOptions().Extensions[databricks_pb2.rpc]
        for endpoint in rpc_options.endpoints:
            # Each endpoint is registered at both /api and /ajax-api paths.
            for http_path in _get_paths(endpoint.path, version=endpoint.since.major):
                routes.append((
                    http_path,
                    get_handler(service().GetRequestClass(method_descriptor)),
                    [endpoint.method],
                ))
    return routes
5985  
5986  
def get_endpoints(get_handler=get_handler):
    """
    Returns:
        List of tuples (path, handler, methods)
    """
    endpoints = []
    endpoints += get_service_endpoints(MlflowService, get_handler)
    endpoints += get_internal_online_scoring_endpoints()
    endpoints += get_service_endpoints(ModelRegistryService, get_handler)
    endpoints += get_service_endpoints(MlflowArtifactsService, get_handler)
    endpoints += get_service_endpoints(WebhookService, get_handler)
    endpoints.append((_add_static_prefix("/graphql"), _graphql, ["GET", "POST"]))
    # NB: Register via _get_paths() so that server-info is reachable at both
    # <static-prefix>/api/3.0/mlflow/server-info (for the Python client)
    # and <static-prefix>/ajax-api/3.0/mlflow/server-info (for the frontend).
    endpoints += [
        (path, _get_server_info, ["GET"])
        for path in _get_paths("/mlflow/server-info", version=3)
    ]
    endpoints += get_gateway_endpoints()
    endpoints += get_demo_endpoints()
    endpoints += get_issues_detection_endpoints()
    endpoints += get_job_endpoints()
    return endpoints
6011  
6012  
def get_gateway_endpoints():
    """Returns endpoint tuples for gateway provider/model discovery APIs and scorer invocation."""
    # (base path, handler, HTTP methods) — all registered under /ajax-api/3.0.
    route_specs = [
        ("/mlflow/gateway/supported-providers", _list_supported_providers, ["GET"]),
        ("/mlflow/gateway/supported-models", _list_supported_models, ["GET"]),
        ("/mlflow/gateway/provider-config", _get_provider_config, ["GET"]),
        ("/mlflow/gateway/secrets/config", _get_secrets_config, ["GET"]),
        ("/mlflow/scorer/invoke", _invoke_scorer_handler, ["POST"]),
    ]
    return [
        (_get_ajax_path(base_path, version=3), handler, methods)
        for base_path, handler, methods in route_specs
    ]
6042  
6043  
def get_issues_detection_endpoints():
    """Returns the endpoint tuple for the issue-detection invocation AJAX API."""
    invoke_path = _get_ajax_path("/mlflow/issues/invoke", version=3)
    return [(invoke_path, _invoke_issue_detection_handler, ["POST"])]
6052  
6053  
def get_job_endpoints():
    """Returns endpoint tuples for the job cancellation and job status AJAX APIs."""
    return [
        (_get_ajax_path("/mlflow/jobs/cancel/<job_id>", version=3), _cancel_job, ["PATCH"]),
        (_get_ajax_path("/mlflow/jobs/<job_id>", version=3), _get_job, ["GET"]),
    ]
6067  
6068  
# Demo APIs

# Serialize demo generation so concurrent requests (e.g. FastAPI running Flask
# handlers in a thread pool) cannot race on the process-wide MLFLOW_WORKSPACE
# env var that WorkspaceContext temporarily sets during generate_all_demos.
_demo_generate_lock: threading.Lock = threading.Lock()
6075  
6076  
def get_demo_endpoints():
    """Returns endpoint tuples for demo data generation and deletion APIs."""
    return [
        (_get_ajax_path("/mlflow/demo/generate", version=3), _generate_demo, ["POST"]),
        (_get_ajax_path("/mlflow/demo/delete", version=3), _delete_demo, ["POST"]),
    ]
6091  
6092  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _generate_demo():
    """Generate demo data for registered demo generators.

    Accepts an optional JSON body with a ``features`` list to generate only specific
    features (e.g. ``{"features": ["traces", "prompts"]}``). When omitted, all features
    are generated.

    Returns:
        JSON with ``status`` ("exists" or "created"), the demo ``experiment_id``,
        the list of ``features_generated``, and a ``navigation_url``.
    """
    from mlflow.demo import generate_all_demos
    from mlflow.demo.base import DEMO_EXPERIMENT_NAME
    from mlflow.demo.registry import demo_registry

    request_json = request.get_json(silent=True) or {}
    features = request_json.get("features")

    store = _get_tracking_store()

    # Hold the lock around BOTH the existence check and generation: checking
    # outside the lock (check-then-act) would let two concurrent requests each
    # observe "not generated" and then generate back to back.
    with _demo_generate_lock:
        experiment = store.get_experiment_by_name(DEMO_EXPERIMENT_NAME)

        generator_names = demo_registry.list_generators()
        if features is not None:
            generator_names = [n for n in generator_names if n in features]

        # `all()` over an empty sequence is vacuously True, so require at least
        # one matching generator before reporting everything already exists.
        all_exist = False
        if experiment and experiment.lifecycle_stage == "active" and generator_names:
            all_exist = all(
                demo_registry.get(name)().is_generated() for name in generator_names
            )

        if experiment and all_exist:
            return jsonify({
                "status": "exists",
                "experiment_id": experiment.experiment_id,
                "features_generated": [],
                "navigation_url": f"/experiments/{experiment.experiment_id}",
            })

        results = generate_all_demos(features=features)

    # Re-fetch: the experiment may have just been created by generation.
    experiment = store.get_experiment_by_name(DEMO_EXPERIMENT_NAME)
    experiment_id = experiment.experiment_id if experiment else None
    navigation_url = f"/experiments/{experiment_id}" if experiment_id else "/experiments"

    return jsonify({
        "status": "created",
        "experiment_id": experiment_id,
        "features_generated": [r.feature for r in results],
        "navigation_url": navigation_url,
    })
6141  
6142  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _delete_demo():
    """Delete demo data for all registered demo generators.

    Performs a full hard delete of the demo experiment and all associated data,
    equivalent to what `mlflow gc` would do. This ensures the demo data is
    completely removed rather than just soft-deleted.
    """
    from mlflow.demo.base import DEMO_EXPERIMENT_NAME
    from mlflow.demo.registry import demo_registry

    deleted_features = []
    for generator_name in demo_registry.list_generators():
        generator = demo_registry.get(generator_name)()
        if not generator._data_exists():
            continue
        generator.delete_demo()
        deleted_features.append(generator_name)

    store = _get_tracking_store()
    demo_experiment = store.get_experiment_by_name(DEMO_EXPERIMENT_NAME)
    if demo_experiment and demo_experiment.lifecycle_stage == "active":
        store.delete_experiment(demo_experiment.experiment_id)

    return jsonify({
        "status": "deleted",
        "features_deleted": deleted_features,
    })
6171  
6172  
def get_internal_online_scoring_endpoints():
    """Returns endpoint definitions for internal (non public) online scoring APIs."""
    specs = [
        ("/mlflow/scorers/online-configs", _get_online_scoring_configs, ["GET"]),
        ("/mlflow/scorers/online-config", _upsert_online_scoring_config, ["PUT"]),
    ]
    endpoints = []
    for route, handler, methods in specs:
        # Each API is exposed on both the AJAX prefix and the REST prefix.
        endpoints.append((_get_ajax_path(route, version=3), handler, methods))
        endpoints.append((_get_rest_path(route, version=3), handler, methods))
    return endpoints
6197  
6198  
6199  # Evaluation Dataset APIs
6200  
6201  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _create_dataset_handler():
    """Create an evaluation dataset from the request (name, experiment_ids, tags)."""
    request_message = _get_request_message(
        CreateDataset(),
        schema={
            "name": [_assert_required, _assert_string],
            "experiment_ids": [_assert_array],
            "tags": [_assert_string],
        },
    )

    # Tags arrive as a JSON-encoded string; decode only when present.
    tags = None
    if hasattr(request_message, "tags") and request_message.tags:
        tags = json.loads(request_message.tags)

    # Empty repeated field becomes None for the store API.
    experiment_ids = list(request_message.experiment_ids) or None

    dataset = _get_tracking_store().create_dataset(
        name=request_message.name,
        experiment_ids=experiment_ids,
        tags=tags,
    )

    response_message = CreateDataset.Response()
    response_message.dataset.CopyFrom(dataset.to_proto())
    return _wrap_response(response_message)
6229  
6230  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _get_dataset_handler(dataset_id):
    """Fetch a single evaluation dataset by its ID."""
    fetched = _get_tracking_store().get_dataset(dataset_id)
    response_message = GetDataset.Response()
    response_message.dataset.CopyFrom(fetched.to_proto())
    return _wrap_response(response_message)
6239  
6240  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _delete_dataset_handler(dataset_id):
    """Delete the evaluation dataset identified by ``dataset_id``."""
    _get_tracking_store().delete_dataset(dataset_id)
    return _wrap_response(DeleteDataset.Response())
6248  
6249  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _search_evaluation_datasets_handler():
    """Search evaluation datasets with optional filtering, ordering, and pagination."""
    request_message = _get_request_message(
        SearchEvaluationDatasets(),
        schema={
            "experiment_ids": [_assert_array],
            "filter_string": [_assert_string],
            "max_results": [_assert_intlike],
            "order_by": [_assert_array],
            "page_token": [_assert_string],
        },
    )

    # Translate unset/empty proto fields into the store API's None defaults.
    experiment_ids = list(request_message.experiment_ids) or None
    order_by = list(request_message.order_by) or None

    datasets = _get_tracking_store().search_datasets(
        experiment_ids=experiment_ids,
        filter_string=request_message.filter_string or None,
        max_results=request_message.max_results or None,
        order_by=order_by,
        page_token=request_message.page_token or None,
    )

    response_message = SearchEvaluationDatasets.Response()
    response_message.datasets.extend(d.to_proto() for d in datasets)
    if datasets.token:
        response_message.next_page_token = datasets.token

    return _wrap_response(response_message)
6280  
6281  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _set_dataset_tags_handler(dataset_id):
    """Set tags on an evaluation dataset; tags arrive as a JSON-encoded object."""
    request_message = _get_request_message(
        SetDatasetTags(),
        schema={
            "tags": [_assert_required, _assert_string],
        },
    )

    _get_tracking_store().set_dataset_tags(
        dataset_id=dataset_id,
        tags=json.loads(request_message.tags),
    )

    return _wrap_response(SetDatasetTags.Response())
6301  
6302  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _delete_dataset_tag_handler(dataset_id, key):
    """Remove tag ``key`` from the evaluation dataset ``dataset_id``."""
    _get_tracking_store().delete_dataset_tag(dataset_id=dataset_id, key=key)
    return _wrap_response(DeleteDatasetTag.Response())
6313  
6314  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _upsert_dataset_records_handler(dataset_id):
    """Insert or update dataset records; records arrive as a JSON-encoded list."""
    request_message = _get_request_message(
        UpsertDatasetRecords(),
        schema={
            "records": [_assert_required, _assert_string],
        },
    )

    upsert_result = _get_tracking_store().upsert_dataset_records(
        dataset_id=dataset_id,
        records=json.loads(request_message.records),
    )

    response_message = UpsertDatasetRecords.Response()
    response_message.inserted_count = upsert_result["inserted"]
    response_message.updated_count = upsert_result["updated"]
    return _wrap_response(response_message)
6337  
6338  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _get_dataset_experiment_ids_handler(dataset_id):
    """
    Get experiment IDs associated with an evaluation dataset.

    Decorated like the sibling dataset handlers so store errors are converted
    into proper HTTP error responses (instead of propagating as 500s) and the
    endpoint is disabled in artifacts-only server mode.
    """
    experiment_ids = _get_tracking_store().get_dataset_experiment_ids(dataset_id=dataset_id)

    response_message = GetDatasetExperimentIds.Response()
    response_message.experiment_ids.extend(experiment_ids)

    return _wrap_response(response_message)
6349  
6350  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _add_dataset_to_experiments_handler(dataset_id):
    """Associate an evaluation dataset with additional experiments."""
    request_message = _get_request_message(
        AddDatasetToExperiments(),
        schema={
            "experiment_ids": [_assert_array],
        },
    )

    updated_dataset = _get_tracking_store().add_dataset_to_experiments(
        dataset_id=dataset_id,
        experiment_ids=request_message.experiment_ids,
    )

    response_message = AddDatasetToExperiments.Response()
    response_message.dataset.CopyFrom(updated_dataset.to_proto())
    return _wrap_response(response_message)
6369  
6370  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _remove_dataset_from_experiments_handler(dataset_id):
    """Dissociate an evaluation dataset from the given experiments."""
    request_message = _get_request_message(
        RemoveDatasetFromExperiments(),
        schema={
            "experiment_ids": [_assert_array],
        },
    )

    updated_dataset = _get_tracking_store().remove_dataset_from_experiments(
        dataset_id=dataset_id,
        experiment_ids=request_message.experiment_ids,
    )

    response_message = RemoveDatasetFromExperiments.Response()
    response_message.dataset.CopyFrom(updated_dataset.to_proto())
    return _wrap_response(response_message)
6389  
6390  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _get_dataset_records_handler(dataset_id):
    """Return a page of records for an evaluation dataset.

    Records are serialized as a JSON string in the response. Pagination is
    driven by ``max_results`` (default 1000) and ``page_token``. Decorated like
    the sibling dataset handlers for consistent error handling and
    artifacts-only gating.
    """
    request_message = _get_request_message(
        GetDatasetRecords(),
        schema={
            "max_results": [_assert_intlike],
            "page_token": [_assert_string],
        },
    )

    max_results = request_message.max_results or 1000
    page_token = request_message.page_token or None

    # Use the pagination-aware method.
    # NOTE(review): this relies on the store's private `_load_dataset_records`
    # API — confirm whether a public pagination method exists.
    records, next_page_token = _get_tracking_store()._load_dataset_records(
        dataset_id, max_results=max_results, page_token=page_token
    )

    response_message = GetDatasetRecords.Response()

    response_message.records = json.dumps([record.to_dict() for record in records])

    if next_page_token:
        response_message.next_page_token = next_page_token

    return _wrap_response(response_message)
6417  
6418  
6419  @catch_mlflow_exception
6420  @_disable_if_artifacts_only
6421  def _delete_dataset_records_handler(dataset_id):
6422      request_message = _get_request_message(
6423          DeleteDatasetRecords(),
6424          schema={
6425              "dataset_record_ids": [_assert_array],
6426          },
6427      )
6428  
6429      deleted_count = _get_tracking_store().delete_dataset_records(
6430          dataset_id=dataset_id,
6431          dataset_record_ids=list(request_message.dataset_record_ids),
6432      )
6433  
6434      response_message = DeleteDatasetRecords.Response()
6435      response_message.deleted_count = deleted_count
6436      return _wrap_response(response_message)
6437  
6438  
# Cache for telemetry config with 3 hour TTL (10800 seconds).
# maxsize=1 because only a single "config" entry is ever stored.
_telemetry_config_cache = TTLCache(maxsize=1, ttl=10800)
6441  
6442  
def _get_or_fetch_ui_telemetry_config():
    """Return the UI telemetry config, fetching and caching it on a miss or expiry."""
    cached = _telemetry_config_cache.get("config")
    if cached is not None:
        return cached
    fetched = fetch_ui_telemetry_config()
    _telemetry_config_cache["config"] = fetched
    return fetched
6448  
6449  
@catch_mlflow_exception
def get_ui_telemetry_handler():
    """
    GET handler for /telemetry endpoint.
    Returns the telemetry client configuration by fetching it directly.
    """
    if is_telemetry_disabled():
        return jsonify(FALLBACK_UI_CONFIG)

    config = _get_or_fetch_ui_telemetry_config()

    # UI telemetry is off when either the UI-specific flag or the overall
    # telemetry flag is set; missing keys default to disabled.
    ui_disabled = config.get("disable_ui_telemetry", True)
    overall_disabled = config.get("disable_telemetry", True)

    return jsonify(
        {
            "disable_ui_telemetry": ui_disabled or overall_disabled,
            "disable_ui_events": config.get("disable_ui_events", []),
            "ui_rollout_percentage": config.get("ui_rollout_percentage", 0),
        }
    )
6471  
6472  
@catch_mlflow_exception
def post_ui_telemetry_handler():
    """
    POST handler for /telemetry endpoint.
    Accepts telemetry records and adds them to the telemetry client.
    """
    try:
        if is_telemetry_disabled():
            return jsonify({"status": "disabled"})

        events = request.json.get("records", [])
        if not events:
            return jsonify({"status": "success"})

        client = get_telemetry_client()
        if client is None:
            return jsonify({"status": "disabled"})

        # Consult the TTL-cached config rather than the telemetry client's own
        # config: the client fetches its config only once, so it would never
        # observe telemetry being disabled after server startup.
        config = _get_or_fetch_ui_telemetry_config()

        # If the refreshed config disables telemetry (or is missing keys),
        # tell the UI to stop sending records.
        if config.get("disable_ui_telemetry", True) or config.get("disable_telemetry", True):
            return jsonify({"status": "disabled"})

        server_installation_id = get_or_create_installation_id()
        client.add_records(
            [
                Record(
                    event_name=event["event_name"],
                    timestamp_ns=event["timestamp_ns"],
                    params=event["params"],
                    status=Status.SUCCESS,
                    installation_id=event["installation_id"],
                    session_id=event["session_id"],
                    server_installation_id=server_installation_id,
                    duration_ms=0,
                )
                for event in events
            ]
        )

        return jsonify({"status": "success"})
    except Exception as e:
        _logger.debug(f"Failed to process UI telemetry records: {e}")
        # Unexpected errors usually mean malformed data. Returning success would
        # keep the UI sending; returning an error would make it retry. Reporting
        # "disabled" is the safest way to stop the traffic.
        return jsonify({"status": "disabled"})
6526  
6527  
6528  def _parse_prompt_uri(prompt_uri: str) -> tuple[str, str]:
6529      """
6530      Parse a prompt URI to extract the prompt name and version.
6531  
6532      Args:
6533          prompt_uri: Prompt URI in the format "prompts:/prompt_name/version"
6534  
6535      Returns:
6536          A tuple of (prompt_name, version). Returns empty strings if parsing fails.
6537      """
6538      try:
6539          # Format: "prompts:/prompt_name/version"
6540          if prompt_uri.startswith("prompts:/"):
6541              parts = prompt_uri.replace("prompts:/", "").split("/")
6542              if len(parts) >= 2:
6543                  return parts[0], parts[1]
6544      except Exception:
6545          pass
6546      return "", ""
6547  
6548  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _create_prompt_optimization_job():
    """Create and submit a background prompt-optimization job.

    Creates the MLflow run up front so its run_id is immediately available to
    the caller, logs the optimization config as run parameters, links the
    evaluation dataset for lineage (when provided), then submits the job.

    Raises:
        MlflowException: If ``source_prompt_uri`` or ``experiment_id`` is
            missing/empty, or if ``optimizer_config_json`` is not valid JSON.
    """
    # These imports must be local to avoid circular import with mlflow.server.jobs
    from mlflow.genai.datasets import get_dataset as get_genai_dataset
    from mlflow.genai.optimize.job import OptimizerType, optimize_prompts_job
    from mlflow.server.jobs import submit_job

    request_message = _get_request_message(
        CreatePromptOptimizationJob(),
        schema={
            "experiment_id": [_assert_string],
            "source_prompt_uri": [_assert_string, _assert_required],
            "config": [_assert_required],
            "tags": [_assert_array],
        },
    )

    # Re-validate beyond _assert_required to reject an explicitly empty string.
    prompt_uri = request_message.source_prompt_uri or ""
    if not prompt_uri:
        raise MlflowException(
            "source_prompt_uri is required for optimization job",
            error_code=INVALID_PARAMETER_VALUE,
        )

    config = request_message.config
    dataset_id = config.dataset_id or ""

    scorers = list(config.scorers) if config.scorers else []

    optimizer_type = OptimizerType.from_proto(config.optimizer_type)

    experiment_id = (request_message.experiment_id or "").strip()
    if not experiment_id:
        raise MlflowException(
            "experiment_id is required for optimization job",
            error_code=INVALID_PARAMETER_VALUE,
        )

    # Parse optimizer_config_json to dict for the job function
    # Validate before creating run to avoid creating unused runs on validation failure
    optimizer_config = None
    if config.optimizer_config_json:
        try:
            optimizer_config = json.loads(config.optimizer_config_json)
        except json.JSONDecodeError as e:
            raise MlflowException(
                f"Invalid JSON in optimizer_config_json: {e}",
                error_code=INVALID_PARAMETER_VALUE,
            )

    # Create MLflow run upfront so run_id is immediately available
    # The job will resume this run when it starts executing
    tracking_store = _get_tracking_store()
    start_time = int(time.time() * 1000)

    # Parse prompt name and version from URI for more descriptive run name
    prompt_name, prompt_version = _parse_prompt_uri(prompt_uri)
    run_name = f"optimize_prompt_{optimizer_type}_{prompt_name}_{prompt_version}_{start_time}"

    run = tracking_store.create_run(
        experiment_id=experiment_id,
        user_id=_get_user(),
        start_time=start_time,
        tags=[],
        run_name=run_name,
    )
    run_id = run.info.run_id

    # Log optimization config as run parameters
    # NOTE(review): `optimizer_type` here is an OptimizerType enum value, not a
    # plain string — confirm Param stringifies it the way readers expect.
    params_to_log = [
        Param("source_prompt_uri", prompt_uri),
        Param("optimizer_type", optimizer_type),
        Param("dataset_id", dataset_id),
        Param("scorer_names", json.dumps(scorers)),
    ]
    if config.optimizer_config_json:
        params_to_log.append(Param("optimizer_config_json", config.optimizer_config_json))
    tracking_store.log_batch(run_id=run_id, metrics=[], params=params_to_log, tags=[])

    # Link the evaluation dataset to the run for lineage tracking (if dataset_id is provided)
    if dataset_id:
        dataset = get_genai_dataset(dataset_id=dataset_id)
        dataset_input = DatasetInput(
            dataset=dataset._to_mlflow_entity(),
            tags=[InputTag(key="mlflow.data.context", value="optimization")],
        )
        tracking_store.log_inputs(run_id=run_id, datasets=[dataset_input])

    # These params are serialized by the job store and later read back by
    # _build_prompt_optimization_job_from_entity — keep the keys in sync.
    params = {
        "run_id": run_id,
        "experiment_id": experiment_id,
        "prompt_uri": prompt_uri,
        "dataset_id": dataset_id,
        "optimizer_type": optimizer_type,
        "optimizer_config": optimizer_config,
        "scorer_names": scorers,
    }

    job_entity = submit_job(optimize_prompts_job, params)

    # Build the response proto by hand; status starts as PENDING because the
    # job was only just submitted.
    response_message = CreatePromptOptimizationJob.Response()
    optimization_job = PromptOptimizationJobProto()
    optimization_job.job_id = job_entity.job_id
    optimization_job.run_id = run_id
    optimization_job.state.status = JobStatus.JOB_STATUS_PENDING
    optimization_job.creation_timestamp_ms = job_entity.creation_time
    optimization_job.experiment_id = experiment_id
    optimization_job.config.CopyFrom(config)
    optimization_job.source_prompt_uri = prompt_uri

    # Echo the caller-supplied tags back on the response proto.
    for tag in request_message.tags:
        job_tag = optimization_job.tags.add()
        job_tag.key = tag.key
        job_tag.value = tag.value

    response_message.job.CopyFrom(optimization_job)
    return _wrap_response(response_message)
6667  
6668  
def _build_prompt_optimization_job_from_entity(job_entity):
    """Convert a generic job-store entity into a PromptOptimizationJob proto.

    The job's submission parameters are stored as a JSON string on the entity
    (see the ``params`` dict in _create_prompt_optimization_job). Each field is
    populated defensively — missing or malformed values are silently skipped —
    presumably so older/partial job records still render; TODO confirm.
    """
    from mlflow.genai.optimize.job import OptimizerType

    optimization_job = PromptOptimizationJobProto()
    optimization_job.job_id = job_entity.job_id
    optimization_job.state.status = job_entity.status.to_proto()
    optimization_job.creation_timestamp_ms = job_entity.creation_time

    params = json.loads(job_entity.params)
    if "experiment_id" in params:
        optimization_job.experiment_id = params["experiment_id"]
    if "prompt_uri" in params:
        optimization_job.source_prompt_uri = params["prompt_uri"]

    if run_id := params.get("run_id"):
        optimization_job.run_id = run_id

    # Populate config from job params
    config = optimization_job.config
    if "optimizer_type" in params:
        try:
            # Unknown/legacy optimizer types are ignored rather than failing
            # the whole conversion.
            optimizer_type = OptimizerType(params["optimizer_type"])
            config.optimizer_type = optimizer_type.to_proto()
        except (ValueError, KeyError):
            pass
    if params.get("dataset_id"):
        config.dataset_id = params["dataset_id"]
    if "scorer_names" in params:
        try:
            # scorer_names may be stored either as a JSON-encoded string or as
            # a plain list; normalize to a list before extending the proto.
            scorer_names = params["scorer_names"]
            if isinstance(scorer_names, str):
                scorer_names = json.loads(scorer_names)
            if isinstance(scorer_names, list):
                config.scorers.extend(scorer_names)
        except (json.JSONDecodeError, TypeError):
            pass
    if params.get("optimizer_config"):
        # Stored either as a dict (re-serialize) or an already-encoded string.
        optimizer_config = params["optimizer_config"]
        if isinstance(optimizer_config, dict):
            config.optimizer_config_json = json.dumps(optimizer_config)
        elif isinstance(optimizer_config, str):
            config.optimizer_config_json = optimizer_config

    # Get optimized_prompt_uri from job result (only available when job succeeds)
    if job_entity.status.name == "SUCCEEDED" and job_entity.parsed_result:
        result = job_entity.parsed_result
        if isinstance(result, dict) and result.get("optimized_prompt_uri"):
            optimization_job.optimized_prompt_uri = result["optimized_prompt_uri"]

    # If job failed, add error message to state
    if job_entity.status.name == "FAILED" and job_entity.parsed_result:
        optimization_job.state.error_message = str(job_entity.parsed_result)

    return optimization_job
6723  
6724  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _get_prompt_optimization_job(job_id):
    """Fetch a prompt-optimization job and enrich it with run-derived scores.

    Evaluation scores and optimization progress are read from the metrics of
    the job's backing MLflow run; failures to read the run are non-fatal (the
    bare job is still returned).
    """
    from mlflow.server.jobs import get_job

    job_entity = get_job(job_id)
    optimization_job = _build_prompt_optimization_job_from_entity(job_entity)

    # Fetch MLflow run to get evaluation scores from metrics
    try:
        mlflow_run = _get_tracking_store().get_run(optimization_job.run_id)
        run_metrics = mlflow_run.data.metrics

        # Populate evaluation scores from run metrics
        # Aggregated scores are logged as "initial_eval_score" and "final_eval_score"
        # Per-scorer scores are logged as "initial_eval_score.<scorer_name>" and
        # "final_eval_score.<scorer_name>"
        total_metric_calls = None
        for metric_name, metric_value in run_metrics.items():
            # Split on the first "." only: scorer names may contain dots.
            match metric_name.split(".", 1):
                case ["initial_eval_score"]:
                    optimization_job.initial_eval_scores["aggregate"] = metric_value
                case ["final_eval_score"]:
                    optimization_job.final_eval_scores["aggregate"] = metric_value
                case ["initial_eval_score", scorer_name]:
                    optimization_job.initial_eval_scores[scorer_name] = metric_value
                case ["final_eval_score", scorer_name]:
                    optimization_job.final_eval_scores[scorer_name] = metric_value
                case ["total_metric_calls"]:
                    total_metric_calls = metric_value

        # Progress = fraction of the optimizer's metric-call budget consumed,
        # capped at 1.0 and reported as a string in the state metadata.
        # NOTE(review): assumes params["optimizer_config"] deserializes to a
        # dict here; a None value would raise and be swallowed below — confirm.
        if total_metric_calls is not None:
            params = json.loads(job_entity.params)
            optimizer_config = params.get("optimizer_config", {})
            if max_metric_calls := optimizer_config.get("max_metric_calls"):
                progress = round(min(total_metric_calls / max_metric_calls, 1.0), 2)
                optimization_job.state.metadata["progress"] = str(progress)

    except Exception as e:
        # Best-effort enrichment: the run may not exist yet (job pending).
        _logger.debug("Failed to fetch run details for optimization job %s: %s", job_id, e)

    response_message = GetPromptOptimizationJob.Response()
    response_message.job.CopyFrom(optimization_job)
    return _wrap_response(response_message)
6769  
6770  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _search_prompt_optimization_jobs():
    """List all prompt-optimization jobs belonging to the given experiment."""
    request_message = _get_request_message(
        SearchPromptOptimizationJobs(),
        schema={
            "experiment_id": [_assert_required, _assert_string],
        },
    )

    # Optimization jobs are stored as generic jobs named "optimize_prompts",
    # keyed by the experiment they were launched from.
    matching_jobs = _get_job_store().list_jobs(
        job_name="optimize_prompts",
        params={"experiment_id": request_message.experiment_id},
    )

    response_message = SearchPromptOptimizationJobs.Response()
    for entity in matching_jobs:
        response_message.jobs.append(_build_prompt_optimization_job_from_entity(entity))

    return _wrap_response(response_message)
6796  
6797  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _cancel_prompt_optimization_job(job_id):
    """Cancel a prompt-optimization job and kill its backing MLflow run."""
    # This import must be local to avoid circular import with mlflow.server.jobs
    from mlflow.server.jobs import cancel_job

    canceled_entity = cancel_job(job_id)
    optimization_job = _build_prompt_optimization_job_from_entity(canceled_entity)
    # cancel_job may not reflect the new status on the returned entity yet,
    # so force the proto state to CANCELED.
    optimization_job.state.status = JobStatus.JOB_STATUS_CANCELED

    # Best-effort: terminate the underlying MLflow run, if one was created.
    if optimization_job.run_id:
        try:
            _get_tracking_store().update_run_info(
                run_id=optimization_job.run_id,
                run_status=RunStatus.KILLED,
                end_time=get_current_time_millis(),
                run_name=None,
            )
        except Exception:
            # The run may not exist or may already be terminated; warn and continue.
            _logger.warning(
                "Failed to terminate MLflow run '%s' when canceling job '%s'",
                optimization_job.run_id,
                job_id,
            )

    response_message = CancelPromptOptimizationJob.Response()
    response_message.job.CopyFrom(optimization_job)
    return _wrap_response(response_message)
6829  
6830  
@catch_mlflow_exception
@_disable_if_artifacts_only
def _delete_prompt_optimization_job(job_id):
    """Delete a prompt-optimization job and, best-effort, its associated run."""
    job_store = _get_job_store()
    optimization_job = _build_prompt_optimization_job_from_entity(job_store.get_job(job_id))
    run_id = optimization_job.run_id

    job_store.delete_jobs(job_ids=[job_id])

    # Delete the associated MLflow run only if it still exists — the user may
    # have removed it manually before requesting job deletion.
    if run_id:
        try:
            _get_tracking_store().get_run(run_id)
            _get_tracking_store().delete_run(run_id)
        except MlflowException:
            pass

    return _wrap_response(DeletePromptOptimizationJob.Response())
6853  
6854  
# Registry mapping each protobuf request message class to the handler function
# that services it. Entries are grouped by API area via the section comments
# below; presumably consumed by the request-dispatch logic elsewhere in this
# module to route incoming REST calls — confirm against the dispatch code.
HANDLERS: dict[Any, Callable[..., Any]] = {
    # Tracking Server APIs
    CreateExperiment: _create_experiment,
    GetExperiment: _get_experiment,
    GetExperimentByName: _get_experiment_by_name,
    DeleteExperiment: _delete_experiment,
    RestoreExperiment: _restore_experiment,
    UpdateExperiment: _update_experiment,
    CreateRun: _create_run,
    UpdateRun: _update_run,
    DeleteRun: _delete_run,
    RestoreRun: _restore_run,
    LogParam: _log_param,
    LogMetric: _log_metric,
    SetExperimentTag: _set_experiment_tag,
    DeleteExperimentTag: _delete_experiment_tag,
    SetTag: _set_tag,
    DeleteTag: _delete_tag,
    LogBatch: _log_batch,
    LogModel: _log_model,
    GetRun: _get_run,
    SearchRuns: _search_runs,
    ListArtifacts: _list_artifacts,
    CreatePresignedUploadUrl: _create_presigned_upload_url,
    GetMetricHistory: _get_metric_history,
    GetMetricHistoryBulkInterval: get_metric_history_bulk_interval_handler,
    SearchExperiments: _search_experiments,
    LogInputs: _log_inputs,
    LogOutputs: _log_outputs,
    # Evaluation Dataset APIs
    CreateDataset: _create_dataset_handler,
    GetDataset: _get_dataset_handler,
    DeleteDataset: _delete_dataset_handler,
    SearchEvaluationDatasets: _search_evaluation_datasets_handler,
    SetDatasetTags: _set_dataset_tags_handler,
    DeleteDatasetTag: _delete_dataset_tag_handler,
    UpsertDatasetRecords: _upsert_dataset_records_handler,
    GetDatasetExperimentIds: _get_dataset_experiment_ids_handler,
    GetDatasetRecords: _get_dataset_records_handler,
    DeleteDatasetRecords: _delete_dataset_records_handler,
    AddDatasetToExperiments: _add_dataset_to_experiments_handler,
    RemoveDatasetFromExperiments: _remove_dataset_from_experiments_handler,
    # Model Registry APIs
    CreateRegisteredModel: _create_registered_model,
    GetRegisteredModel: _get_registered_model,
    DeleteRegisteredModel: _delete_registered_model,
    UpdateRegisteredModel: _update_registered_model,
    RenameRegisteredModel: _rename_registered_model,
    SearchRegisteredModels: _search_registered_models,
    GetLatestVersions: _get_latest_versions,
    CreateModelVersion: _create_model_version,
    GetModelVersion: _get_model_version,
    DeleteModelVersion: _delete_model_version,
    UpdateModelVersion: _update_model_version,
    TransitionModelVersionStage: _transition_stage,
    GetModelVersionDownloadUri: _get_model_version_download_uri,
    SearchModelVersions: _search_model_versions,
    SetRegisteredModelTag: _set_registered_model_tag,
    DeleteRegisteredModelTag: _delete_registered_model_tag,
    SetModelVersionTag: _set_model_version_tag,
    DeleteModelVersionTag: _delete_model_version_tag,
    SetRegisteredModelAlias: _set_registered_model_alias,
    DeleteRegisteredModelAlias: _delete_registered_model_alias,
    GetModelVersionByAlias: _get_model_version_by_alias,
    # Webhook APIs
    CreateWebhook: _create_webhook,
    ListWebhooks: _list_webhooks,
    GetWebhook: _get_webhook,
    UpdateWebhook: _update_webhook,
    DeleteWebhook: _delete_webhook,
    TestWebhook: _test_webhook,
    # MLflow Artifacts APIs
    DownloadArtifact: _download_artifact,
    UploadArtifact: _upload_artifact,
    ListArtifactsMlflowArtifacts: _list_artifacts_mlflow_artifacts,
    DeleteArtifact: _delete_artifact_mlflow_artifacts,
    CreateMultipartUpload: _create_multipart_upload_artifact,
    CompleteMultipartUpload: _complete_multipart_upload_artifact,
    AbortMultipartUpload: _abort_multipart_upload_artifact,
    GetPresignedDownloadUrl: _get_presigned_download_url,
    # MLflow Tracing APIs (V3)
    StartTraceV3: _start_trace_v3,
    GetTraceInfoV3: _get_trace_info_v3,
    SearchTracesV3: _search_traces_v3,
    DeleteTracesV3: _delete_traces,
    CalculateTraceFilterCorrelation: _calculate_trace_filter_correlation,
    SetTraceTagV3: _set_trace_tag_v3,
    DeleteTraceTagV3: _delete_trace_tag_v3,
    LinkTracesToRun: _link_traces_to_run,
    LinkPromptsToTrace: _link_prompts_to_trace,
    BatchGetTraces: _batch_get_traces,
    BatchGetTraceInfos: _batch_get_trace_infos,
    GetTrace: _get_trace,
    QueryTraceMetrics: _query_trace_metrics,
    # Assessment APIs
    CreateAssessment: _create_assessment,
    GetAssessmentRequest: _get_assessment,
    UpdateAssessment: _update_assessment,
    DeleteAssessment: _delete_assessment,
    # Issue APIs
    CreateIssue: _create_issue,
    UpdateIssue: _update_issue,
    GetIssue: _get_issue,
    SearchIssues: _search_issues,
    # Legacy MLflow Tracing V2 APIs. Kept for backward compatibility but do not use.
    StartTrace: _deprecated_start_trace_v2,
    EndTrace: _deprecated_end_trace_v2,
    GetTraceInfo: _deprecated_get_trace_info_v2,
    SearchTraces: _deprecated_search_traces_v2,
    # NOTE: V2 DeleteTraces shares the V3 implementation (same handler as DeleteTracesV3).
    DeleteTraces: _delete_traces,
    SetTraceTag: _set_trace_tag,
    DeleteTraceTag: _delete_trace_tag,
    # Logged Models APIs
    CreateLoggedModel: _create_logged_model,
    GetLoggedModel: _get_logged_model,
    FinalizeLoggedModel: _finalize_logged_model,
    DeleteLoggedModel: _delete_logged_model,
    SetLoggedModelTags: _set_logged_model_tags,
    DeleteLoggedModelTag: _delete_logged_model_tag,
    SearchLoggedModels: _search_logged_models,
    ListLoggedModelArtifacts: _list_logged_model_artifacts,
    LogLoggedModelParamsRequest: _log_logged_model_params,
    # Scorer APIs
    RegisterScorer: _register_scorer,
    ListScorers: _list_scorers,
    ListScorerVersions: _list_scorer_versions,
    GetScorer: _get_scorer,
    DeleteScorer: _delete_scorer,
    # Secrets APIs
    CreateGatewaySecret: _create_gateway_secret,
    GetGatewaySecretInfo: _get_gateway_secret_info,
    UpdateGatewaySecret: _update_gateway_secret,
    DeleteGatewaySecret: _delete_gateway_secret,
    ListGatewaySecretInfos: _list_gateway_secrets,
    # Endpoints APIs
    CreateGatewayEndpoint: _create_gateway_endpoint,
    GetGatewayEndpoint: _get_gateway_endpoint,
    UpdateGatewayEndpoint: _update_gateway_endpoint,
    DeleteGatewayEndpoint: _delete_gateway_endpoint,
    ListGatewayEndpoints: _list_gateway_endpoints,
    # Model Definitions APIs
    CreateGatewayModelDefinition: _create_gateway_model_definition,
    GetGatewayModelDefinition: _get_gateway_model_definition,
    ListGatewayModelDefinitions: _list_gateway_model_definitions,
    UpdateGatewayModelDefinition: _update_gateway_model_definition,
    DeleteGatewayModelDefinition: _delete_gateway_model_definition,
    # Endpoint Model Mappings APIs
    AttachModelToGatewayEndpoint: _attach_model_to_gateway_endpoint,
    DetachModelFromGatewayEndpoint: _detach_model_from_gateway_endpoint,
    # Endpoint Bindings APIs
    CreateGatewayEndpointBinding: _create_gateway_endpoint_binding,
    DeleteGatewayEndpointBinding: _delete_gateway_endpoint_binding,
    ListGatewayEndpointBindings: _list_gateway_endpoint_bindings,
    # Endpoint Tags APIs
    SetGatewayEndpointTag: _set_gateway_endpoint_tag,
    DeleteGatewayEndpointTag: _delete_gateway_endpoint_tag,
    # Budget Policy APIs
    CreateGatewayBudgetPolicy: _create_budget_policy,
    GetGatewayBudgetPolicy: _get_budget_policy,
    UpdateGatewayBudgetPolicy: _update_budget_policy,
    DeleteGatewayBudgetPolicy: _delete_budget_policy,
    ListGatewayBudgetPolicies: _list_budget_policies,
    ListGatewayBudgetWindows: _list_budget_windows,
    # Guardrail APIs
    CreateGatewayGuardrail: _create_gateway_guardrail,
    GetGatewayGuardrail: _get_gateway_guardrail,
    DeleteGatewayGuardrail: _delete_gateway_guardrail,
    ListGatewayGuardrails: _list_gateway_guardrails,
    AddGuardrailToEndpoint: _add_guardrail_to_endpoint,
    RemoveGuardrailFromEndpoint: _remove_guardrail_from_endpoint,
    ListEndpointGuardrailConfigs: _list_endpoint_guardrail_configs,
    UpdateEndpointGuardrailConfig: _update_endpoint_guardrail_config,
    # Prompt Optimization APIs
    CreatePromptOptimizationJob: _create_prompt_optimization_job,
    GetPromptOptimizationJob: _get_prompt_optimization_job,
    SearchPromptOptimizationJobs: _search_prompt_optimization_jobs,
    CancelPromptOptimizationJob: _cancel_prompt_optimization_job,
    DeletePromptOptimizationJob: _delete_prompt_optimization_job,
    # Workspace APIs
    ListWorkspaces: _list_workspaces_handler,
    CreateWorkspace: _create_workspace_handler,
    GetWorkspace: _get_workspace_handler,
    UpdateWorkspace: _update_workspace_handler,
    DeleteWorkspace: _delete_workspace_handler,
}