# tests/server/test_handlers.py
   1  import json
   2  import uuid
   3  from dataclasses import asdict
   4  from datetime import datetime, timezone
   5  from unittest import mock
   6  
   7  import pytest
   8  from opentelemetry.sdk.trace import ReadableSpan as OTelReadableSpan
   9  
  10  import mlflow
  11  from mlflow.entities import (
  12      GatewayBudgetPolicy,
  13      Issue,
  14      IssueSeverity,
  15      IssueStatus,
  16      RunStatus,
  17      ScorerVersion,
  18      Span,
  19      Trace,
  20      TraceData,
  21      TraceInfo,
  22      TraceState,
  23      ViewType,
  24  )
  25  from mlflow.entities._job import Job as JobEntity
  26  from mlflow.entities._job_status import JobStatus
  27  from mlflow.entities.gateway_budget_policy import (
  28      BudgetAction,
  29      BudgetDuration,
  30      BudgetDurationUnit,
  31      BudgetTargetScope,
  32      BudgetUnit,
  33  )
  34  from mlflow.entities.model_registry import (
  35      ModelVersion,
  36      ModelVersionTag,
  37      PromptVersion,
  38      RegisteredModel,
  39      RegisteredModelTag,
  40  )
  41  from mlflow.entities.model_registry.prompt_version import IS_PROMPT_TAG_KEY, PROMPT_TEXT_TAG_KEY
  42  from mlflow.entities.presigned_download import PresignedDownloadUrlResponse
  43  from mlflow.entities.presigned_upload import CreatePresignedUploadResponse
  44  from mlflow.entities.trace_location import TraceLocation as EntityTraceLocation
  45  from mlflow.entities.trace_metrics import (
  46      AggregationType,
  47      MetricAggregation,
  48      MetricDataPoint,
  49      MetricViewType,
  50  )
  51  from mlflow.environment_variables import MLFLOW_ENABLE_WORKSPACES
  52  from mlflow.exceptions import MlflowException, MlflowNotImplementedException
  53  from mlflow.gateway.budget_tracker.in_memory import InMemoryBudgetTracker
  54  from mlflow.genai.scorers.online.entities import OnlineScoringConfig
  55  from mlflow.protos.databricks_pb2 import (
  56      INTERNAL_ERROR,
  57      INVALID_PARAMETER_VALUE,
  58      NOT_IMPLEMENTED,
  59      RESOURCE_DOES_NOT_EXIST,
  60      ErrorCode,
  61  )
  62  from mlflow.protos.issues_pb2 import (
  63      CreateIssue,
  64      SearchIssues,
  65      UpdateIssue,
  66  )
  67  from mlflow.protos.model_registry_pb2 import (
  68      CreateModelVersion,
  69      CreateRegisteredModel,
  70      DeleteModelVersion,
  71      DeleteModelVersionTag,
  72      DeleteRegisteredModel,
  73      DeleteRegisteredModelAlias,
  74      DeleteRegisteredModelTag,
  75      GetLatestVersions,
  76      GetModelVersion,
  77      GetModelVersionByAlias,
  78      GetModelVersionDownloadUri,
  79      GetRegisteredModel,
  80      RenameRegisteredModel,
  81      SearchModelVersions,
  82      SearchRegisteredModels,
  83      SetModelVersionTag,
  84      SetRegisteredModelAlias,
  85      SetRegisteredModelTag,
  86      TransitionModelVersionStage,
  87      UpdateModelVersion,
  88      UpdateRegisteredModel,
  89  )
  90  from mlflow.protos.prompt_optimization_pb2 import (
  91      OPTIMIZER_TYPE_GEPA,
  92      OPTIMIZER_TYPE_METAPROMPT,
  93      OPTIMIZER_TYPE_UNSPECIFIED,
  94  )
  95  from mlflow.protos.service_pb2 import (
  96      BatchGetTraceInfos,
  97      BatchGetTraces,
  98      CalculateTraceFilterCorrelation,
  99      CreateExperiment,
 100      DeleteScorer,
 101      DeleteTraceTag,
 102      DeleteTraceTagV3,
 103      GatewayEndpoint,
 104      GetGatewayEndpoint,
 105      GetScorer,
 106      GetTrace,
 107      LinkPromptsToTrace,
 108      ListScorers,
 109      ListScorerVersions,
 110      QueryTraceMetrics,
 111      RegisterScorer,
 112      SearchExperiments,
 113      SearchLoggedModels,
 114      SearchRuns,
 115      SearchTraces,
 116      SearchTracesV3,
 117      SetTraceTag,
 118      SetTraceTagV3,
 119      TraceLocation,
 120  )
 121  from mlflow.protos.webhooks_pb2 import ListWebhooks
 122  from mlflow.server import (
 123      ARTIFACTS_DESTINATION_ENV_VAR,
 124      BACKEND_STORE_URI_ENV_VAR,
 125      SERVE_ARTIFACTS_ENV_VAR,
 126      app,
 127  )
 128  from mlflow.server.handlers import (
 129      ARTIFACT_STREAM_CHUNK_SIZE,
 130      STATIC_PREFIX_ENV_VAR,
 131      ModelRegistryStoreRegistryWrapper,
 132      TrackingStoreRegistryWrapper,
 133      _batch_get_trace_infos,
 134      _batch_get_traces,
 135      _calculate_trace_filter_correlation,
 136      _cancel_prompt_optimization_job,
 137      _convert_path_parameter_to_flask_format,
 138      _create_dataset_handler,
 139      _create_experiment,
 140      _create_issue,
 141      _create_model_version,
 142      _create_presigned_upload_url,
 143      _create_prompt_optimization_job,
 144      _create_registered_model,
 145      _delete_artifact_mlflow_artifacts,
 146      _delete_dataset_handler,
 147      _delete_dataset_tag_handler,
 148      _delete_model_version,
 149      _delete_model_version_tag,
 150      _delete_registered_model,
 151      _delete_registered_model_alias,
 152      _delete_registered_model_tag,
 153      _delete_scorer,
 154      _delete_trace_tag,
 155      _delete_trace_tag_v3,
 156      _deprecated_search_traces_v2,
 157      _download_artifact,
 158      _get_ajax_path,
 159      _get_dataset_experiment_ids_handler,
 160      _get_dataset_handler,
 161      _get_dataset_records_handler,
 162      _get_gateway_endpoint,
 163      _get_issue,
 164      _get_latest_versions,
 165      _get_model_version,
 166      _get_model_version_by_alias,
 167      _get_model_version_download_uri,
 168      _get_presigned_download_url,
 169      _get_registered_model,
 170      _get_request_message,
 171      _get_rest_path,
 172      _get_scorer,
 173      _get_trace,
 174      _get_trace_artifact_repo,
 175      _get_workspace_scoped_repo_path_if_enabled,
 176      _link_prompts_to_trace,
 177      _list_artifacts_for_proxied_run_artifact_root,
 178      _list_scorer_versions,
 179      _list_scorers,
 180      _list_webhooks,
 181      _log_batch,
 182      _query_trace_metrics,
 183      _register_scorer,
 184      _rename_registered_model,
 185      _search_evaluation_datasets_handler,
 186      _search_experiments,
 187      _search_issues,
 188      _search_logged_models,
 189      _search_model_versions,
 190      _search_registered_models,
 191      _search_runs,
 192      _search_traces_v3,
 193      _set_dataset_tags_handler,
 194      _set_model_version_tag,
 195      _set_registered_model_alias,
 196      _set_registered_model_tag,
 197      _set_trace_tag,
 198      _set_trace_tag_v3,
 199      _transition_stage,
 200      _update_issue,
 201      _update_model_version,
 202      _update_registered_model,
 203      _upsert_dataset_records_handler,
 204      _validate_source_run,
 205      catch_mlflow_exception,
 206      get_artifact_handler,
 207      get_endpoints,
 208      get_logged_model_artifact_handler,
 209      get_model_version_artifact_handler,
 210      get_trace_artifact_handler,
 211      get_ui_telemetry_handler,
 212      post_ui_telemetry_handler,
 213      upload_artifact_handler,
 214  )
 215  from mlflow.store._unity_catalog.registry.rest_store import UcModelRegistryStore
 216  from mlflow.store.artifact.azure_blob_artifact_repo import AzureBlobArtifactRepository
 217  from mlflow.store.artifact.local_artifact_repo import LocalArtifactRepository
 218  from mlflow.store.artifact.s3_artifact_repo import S3ArtifactRepository
 219  from mlflow.store.entities.paged_list import PagedList
 220  from mlflow.store.model_registry import (
 221      SEARCH_MODEL_VERSION_MAX_RESULTS_THRESHOLD,
 222      SEARCH_REGISTERED_MODEL_MAX_RESULTS_DEFAULT,
 223  )
 224  from mlflow.store.model_registry.rest_store import RestStore as ModelRegistryRestStore
 225  from mlflow.store.tracking import MAX_RESULTS_QUERY_TRACE_METRICS
 226  from mlflow.store.tracking.databricks_rest_store import DatabricksTracingRestStore
 227  from mlflow.telemetry.schemas import Record, Status
 228  from mlflow.tracing.analysis import TraceFilterCorrelationResult
 229  from mlflow.tracing.utils import build_otel_context
 230  from mlflow.utils.mlflow_tags import MLFLOW_ARTIFACT_LOCATION
 231  from mlflow.utils.proto_json_utils import message_to_json
 232  from mlflow.utils.validation import MAX_BATCH_LOG_REQUEST_SIZE
 233  from mlflow.utils.workspace_context import WorkspaceContext
 234  from mlflow.utils.workspace_utils import DEFAULT_WORKSPACE_NAME
 235  
 236  
 237  @pytest.fixture
 238  def mock_get_request_message():
 239      with mock.patch("mlflow.server.handlers._get_request_message") as m:
 240          yield m
 241  
 242  
 243  @pytest.fixture
 244  def mock_get_request_json():
 245      with mock.patch("mlflow.server.handlers._get_request_json") as m:
 246          yield m
 247  
 248  
 249  @pytest.fixture
 250  def mock_tracking_store():
 251      with mock.patch("mlflow.server.handlers._get_tracking_store") as m:
 252          mock_store = mock.MagicMock()
 253          m.return_value = mock_store
 254          yield mock_store
 255  
 256  
 257  @pytest.fixture
 258  def mock_model_registry_store():
 259      with mock.patch("mlflow.server.handlers._get_model_registry_store") as m:
 260          mock_store = mock.MagicMock()
 261          mock_store.list_webhooks_by_event.return_value = PagedList([], None)
 262          m.return_value = mock_store
 263          yield mock_store
 264  
 265  
@pytest.fixture
def enable_serve_artifacts(monkeypatch):
    # Turn on proxied artifact serving for the duration of a test.
    monkeypatch.setenv(SERVE_ARTIFACTS_ENV_VAR, "true")
 269  
 270  
 271  @pytest.fixture
 272  def mock_evaluation_dataset():
 273      from mlflow.protos.datasets_pb2 import Dataset as ProtoDataset
 274  
 275      dataset = mock.MagicMock()
 276      dataset.dataset_id = "d-1234567890abcdef1234567890abcdef"
 277      dataset.name = "test_dataset"
 278      dataset.digest = "abc123"
 279      dataset.created_time = 1234567890
 280      dataset.last_update_time = 1234567890
 281      dataset.created_by = "test_user"
 282      dataset.last_updated_by = "test_user"
 283      dataset.tags = {"env": "test", "version": "1.0"}
 284      dataset.experiment_ids = ["0", "1"]
 285      dataset._records = []
 286      dataset.schema = json.dumps({
 287          "inputs": {"question": "string"},
 288          "expectations": {"accuracy": "float"},
 289      })
 290      dataset.profile = json.dumps({"record_count": 0})
 291  
 292      proto_dataset = ProtoDataset()
 293      proto_dataset.dataset_id = dataset.dataset_id
 294      proto_dataset.name = dataset.name
 295      proto_dataset.digest = dataset.digest
 296      proto_dataset.created_time = dataset.created_time
 297      proto_dataset.last_update_time = dataset.last_update_time
 298      proto_dataset.created_by = dataset.created_by
 299      proto_dataset.last_updated_by = dataset.last_updated_by
 300      proto_dataset.schema = dataset.schema
 301      proto_dataset.profile = dataset.profile
 302  
 303      dataset.to_proto = mock.MagicMock(return_value=proto_dataset)
 304  
 305      return dataset
 306  
 307  
 308  @pytest.fixture
 309  def mock_telemetry_config_cache():
 310      with mock.patch("mlflow.server.handlers._telemetry_config_cache", {}) as m:
 311          yield m
 312  
 313  
 314  @pytest.fixture
 315  def bypass_telemetry_env_check(monkeypatch):
 316      monkeypatch.setattr(mlflow.telemetry.utils, "_IS_MLFLOW_TESTING_TELEMETRY", False)
 317      monkeypatch.setattr(mlflow.telemetry.utils, "_IS_IN_CI_ENV_OR_TESTING", False)
 318      monkeypatch.setattr(mlflow.telemetry.utils, "_IS_MLFLOW_DEV_VERSION", False)
 319  
 320  
 321  @pytest.fixture
 322  def mock_job_store():
 323      with mock.patch("mlflow.server.handlers._get_job_store") as m:
 324          mock_store = mock.MagicMock()
 325          m.return_value = mock_store
 326          yield mock_store
 327  
 328  
 329  def _create_mock_job(
 330      job_id="job-123",
 331      job_name="optimize_prompts",
 332      status_name="PENDING",
 333      params=None,
 334      result=None,
 335      creation_time=1234567890000,
 336      status_details=None,
 337  ):
 338      from mlflow.entities._job import Job
 339      from mlflow.entities._job_status import JobStatus
 340  
 341      if params is None:
 342          params = {
 343              "experiment_id": "exp-123",
 344              "prompt_uri": "prompts:/my-prompt/1",
 345              "run_id": "run-456",
 346          }
 347  
 348      return Job(
 349          job_id=job_id,
 350          creation_time=creation_time,
 351          job_name=job_name,
 352          params=json.dumps(params),
 353          timeout=None,
 354          status=JobStatus.from_str(status_name),
 355          result=json.dumps(result) if result and status_name == "SUCCEEDED" else result,
 356          retry_count=0,
 357          last_update_time=creation_time,
 358          status_details=status_details,
 359      )
 360  
 361  
 362  def _create_mock_run(run_id="run-456", params=None, metrics=None):
 363      mock_run = mock.MagicMock()
 364      mock_run.info.run_id = run_id
 365      mock_run.data.params = params or {}
 366      mock_run.data.metrics = metrics or {}
 367      return mock_run
 368  
 369  
 370  def test_health():
 371      with app.test_client() as c:
 372          response = c.get("/health")
 373          assert response.status_code == 200
 374          assert response.get_data().decode() == "OK"
 375  
 376  
 377  def test_version():
 378      with app.test_client() as c:
 379          response = c.get("/version")
 380          assert response.status_code == 200
 381          assert response.get_data().decode() == mlflow.__version__
 382  
 383  
 384  def test_server_info():
 385      with app.test_client() as c:
 386          response = c.get("/api/3.0/mlflow/server-info")
 387          assert response.status_code == 200
 388          data = response.get_json()
 389          assert data["store_type"] == "SqlStore"
 390          assert data["workspaces_enabled"] is False
 391  
 392  
 393  def test_get_endpoints():
 394      endpoints = get_endpoints()
 395      create_experiment_endpoint = [e for e in endpoints if e[1] == _create_experiment]
 396      assert len(create_experiment_endpoint) == 2
 397  
 398  
 399  def test_convert_path_parameter_to_flask_format():
 400      converted = _convert_path_parameter_to_flask_format("/mlflow/trace")
 401      assert "/mlflow/trace" == converted
 402  
 403      converted = _convert_path_parameter_to_flask_format("/mlflow/trace/{request_id}")
 404      assert "/mlflow/trace/<request_id>" == converted
 405  
 406      converted = _convert_path_parameter_to_flask_format("/mlflow/{foo}/{bar}/{baz}")
 407      assert "/mlflow/<foo>/<bar>/<baz>" == converted
 408  
 409  
 410  def test_all_model_registry_endpoints_available():
 411      endpoints = {handler: method for (path, handler, method) in get_endpoints()}
 412  
 413      # Test that each of the handler is enabled as an endpoint with appropriate method.
 414      expected_endpoints = {
 415          "POST": [
 416              _create_registered_model,
 417              _create_model_version,
 418              _rename_registered_model,
 419              _transition_stage,
 420          ],
 421          "PATCH": [_update_registered_model, _update_model_version],
 422          "DELETE": [_delete_registered_model, _delete_registered_model],
 423          "GET": [
 424              _search_model_versions,
 425              _get_latest_versions,
 426              _get_registered_model,
 427              _get_model_version,
 428              _get_model_version_download_uri,
 429          ],
 430      }
 431      # TODO: efficient mechanism to test endpoint path
 432      for method, handlers in expected_endpoints.items():
 433          for handler in handlers:
 434              assert handler in endpoints
 435              assert endpoints[handler] == [method]
 436  
 437  
 438  def test_can_parse_json():
 439      request = mock.MagicMock()
 440      request.method = "POST"
 441      request.content_type = "application/json"
 442      request.get_json = mock.MagicMock()
 443      request.get_json.return_value = {"name": "hello"}
 444      msg = _get_request_message(CreateExperiment(), flask_request=request)
 445      assert msg.name == "hello"
 446  
 447  
 448  def test_can_parse_post_json_with_unknown_fields():
 449      request = mock.MagicMock()
 450      request.method = "POST"
 451      request.content_type = "application/json"
 452      request.get_json = mock.MagicMock()
 453      request.get_json.return_value = {"name": "hello", "WHAT IS THIS FIELD EVEN": "DOING"}
 454      msg = _get_request_message(CreateExperiment(), flask_request=request)
 455      assert msg.name == "hello"
 456  
 457  
 458  def test_can_parse_post_json_with_content_type_params():
 459      request = mock.MagicMock()
 460      request.method = "POST"
 461      request.content_type = "application/json; charset=utf-8"
 462      request.get_json = mock.MagicMock()
 463      request.get_json.return_value = {"name": "hello"}
 464      msg = _get_request_message(CreateExperiment(), flask_request=request)
 465      assert msg.name == "hello"
 466  
 467  
 468  def test_can_parse_get_json_with_unknown_fields():
 469      request = mock.MagicMock()
 470      request.method = "GET"
 471      request.args = {"name": "hello", "superDuperUnknown": "field"}
 472      msg = _get_request_message(CreateExperiment(), flask_request=request)
 473      assert msg.name == "hello"
 474  
 475  
 476  # Previous versions of the client sent a doubly string encoded JSON blob,
 477  # so this test ensures continued compliance with such clients.
 478  def test_can_parse_json_string():
 479      request = mock.MagicMock()
 480      request.method = "POST"
 481      request.content_type = "application/json"
 482      request.get_json = mock.MagicMock()
 483      request.get_json.return_value = '{"name": "hello2"}'
 484      msg = _get_request_message(CreateExperiment(), flask_request=request)
 485      assert msg.name == "hello2"
 486  
 487  
 488  def test_can_block_post_request_with_invalid_content_type():
 489      request = mock.MagicMock()
 490      request.method = "POST"
 491      request.content_type = "text/plain"
 492      request.get_json = mock.MagicMock()
 493      request.get_json.return_value = {"name": "hello"}
 494      with pytest.raises(MlflowException, match=r"Bad Request. Content-Type"):
 495          _get_request_message(CreateExperiment(), flask_request=request)
 496  
 497  
 498  def test_can_block_post_request_with_missing_content_type():
 499      request = mock.MagicMock()
 500      request.method = "POST"
 501      request.content_type = None
 502      request.get_json = mock.MagicMock()
 503      request.get_json.return_value = {"name": "hello"}
 504      with pytest.raises(MlflowException, match=r"Bad Request. Content-Type"):
 505          _get_request_message(CreateExperiment(), flask_request=request)
 506  
 507  
 508  def test_search_runs_default_view_type(mock_get_request_message, mock_tracking_store):
 509      """
 510      Search Runs default view type is filled in as ViewType.ACTIVE_ONLY
 511      """
 512      mock_get_request_message.return_value = SearchRuns(experiment_ids=["0"])
 513      mock_tracking_store.search_runs.return_value = PagedList([], None)
 514      _search_runs()
 515      _, kwargs = mock_tracking_store.search_runs.call_args
 516      assert kwargs["run_view_type"] == ViewType.ACTIVE_ONLY
 517  
 518  
 519  def test_search_runs_empty_page_token(mock_get_request_message, mock_tracking_store):
 520      """
 521      Test that empty page_token from protobuf is converted to None before calling store
 522      """
 523      # Create proto without setting page_token
 524      search_runs_proto = SearchRuns()
 525      search_runs_proto.experiment_ids.append("0")
 526      search_runs_proto.max_results = 10
 527      # Verify protobuf returns empty string for unset field
 528      assert search_runs_proto.page_token == ""
 529  
 530      mock_get_request_message.return_value = search_runs_proto
 531      mock_tracking_store.search_runs.return_value = PagedList([], None)
 532  
 533      _search_runs()
 534  
 535      # Verify store was called with None, not empty string
 536      mock_tracking_store.search_runs.assert_called_once()
 537      call_kwargs = mock_tracking_store.search_runs.call_args.kwargs
 538      assert call_kwargs["page_token"] is None  # page_token should be None, not ""
 539  
 540  
 541  def test_log_batch_api_req(mock_get_request_json):
 542      mock_get_request_json.return_value = "a" * (MAX_BATCH_LOG_REQUEST_SIZE + 1)
 543      response = _log_batch()
 544      assert response.status_code == 400
 545      json_response = json.loads(response.get_data())
 546      assert json_response["error_code"] == ErrorCode.Name(INVALID_PARAMETER_VALUE)
 547      assert (
 548          f"Batched logging API requests must be at most {MAX_BATCH_LOG_REQUEST_SIZE} bytes"
 549          in json_response["message"]
 550      )
 551  
 552  
 553  def test_catch_mlflow_exception():
 554      @catch_mlflow_exception
 555      def test_handler():
 556          raise MlflowException("test error", error_code=INTERNAL_ERROR)
 557  
 558      response = test_handler()
 559      json_response = json.loads(response.get_data())
 560      assert response.status_code == 500
 561      assert json_response["error_code"] == ErrorCode.Name(INTERNAL_ERROR)
 562      assert json_response["message"] == "test error"
 563  
 564  
 565  def test_mlflow_server_with_installed_plugin(tmp_path, monkeypatch):
 566      pytest.skip("FileStore is no longer supported.")
 567      from mlflow_test_plugin.file_store import PluginFileStore
 568  
 569      monkeypatch.setenv(BACKEND_STORE_URI_ENV_VAR, f"file-plugin:{tmp_path}")
 570      monkeypatch.setattr(mlflow.server.handlers, "_tracking_store", None)
 571      plugin_file_store = mlflow.server.handlers._get_tracking_store()
 572      assert isinstance(plugin_file_store, PluginFileStore)
 573      assert plugin_file_store.is_plugin
 574  
 575  
 576  def jsonify(obj):
 577      def _jsonify(obj):
 578          return json.loads(message_to_json(obj.to_proto()))
 579  
 580      if isinstance(obj, list):
 581          return [_jsonify(o) for o in obj]
 582      else:
 583          return _jsonify(obj)
 584  
 585  
 586  # Tests for Model Registry handlers
 587  def test_create_registered_model(mock_get_request_message, mock_model_registry_store):
 588      tags = [
 589          RegisteredModelTag(key="key", value="value"),
 590          RegisteredModelTag(key="anotherKey", value="some other value"),
 591      ]
 592      mock_get_request_message.return_value = CreateRegisteredModel(
 593          name="model_1", tags=[tag.to_proto() for tag in tags]
 594      )
 595      rm = RegisteredModel("model_1", tags=tags)
 596      mock_model_registry_store.create_registered_model.return_value = rm
 597      resp = _create_registered_model()
 598      _, args = mock_model_registry_store.create_registered_model.call_args
 599      assert args["name"] == "model_1"
 600      assert {tag.key: tag.value for tag in args["tags"]} == {tag.key: tag.value for tag in tags}
 601      assert json.loads(resp.get_data()) == {"registered_model": jsonify(rm)}
 602  
 603  
 604  def test_get_registered_model(mock_get_request_message, mock_model_registry_store):
 605      name = "model1"
 606      mock_get_request_message.return_value = GetRegisteredModel(name=name)
 607      rmd = RegisteredModel(
 608          name=name,
 609          creation_timestamp=111,
 610          last_updated_timestamp=222,
 611          description="Test model",
 612          latest_versions=[],
 613      )
 614      mock_model_registry_store.get_registered_model.return_value = rmd
 615      resp = _get_registered_model()
 616      _, args = mock_model_registry_store.get_registered_model.call_args
 617      assert args == {"name": name}
 618      assert json.loads(resp.get_data()) == {"registered_model": jsonify(rmd)}
 619  
 620  
 621  def test_update_registered_model(mock_get_request_message, mock_model_registry_store):
 622      name = "model_1"
 623      description = "Test model"
 624      mock_get_request_message.return_value = UpdateRegisteredModel(
 625          name=name, description=description
 626      )
 627      rm2 = RegisteredModel(name, description=description)
 628      mock_model_registry_store.update_registered_model.return_value = rm2
 629      resp = _update_registered_model()
 630      _, args = mock_model_registry_store.update_registered_model.call_args
 631      assert args == {"name": name, "description": "Test model"}
 632      assert json.loads(resp.get_data()) == {"registered_model": jsonify(rm2)}
 633  
 634  
 635  def test_rename_registered_model(mock_get_request_message, mock_model_registry_store):
 636      name = "model_1"
 637      new_name = "model_2"
 638      mock_get_request_message.return_value = RenameRegisteredModel(name=name, new_name=new_name)
 639      rm2 = RegisteredModel(new_name)
 640      mock_model_registry_store.rename_registered_model.return_value = rm2
 641      resp = _rename_registered_model()
 642      _, args = mock_model_registry_store.rename_registered_model.call_args
 643      assert args == {"name": name, "new_name": new_name}
 644      assert json.loads(resp.get_data()) == {"registered_model": jsonify(rm2)}
 645  
 646  
 647  def test_delete_registered_model(mock_get_request_message, mock_model_registry_store):
 648      name = "model_1"
 649      mock_get_request_message.return_value = DeleteRegisteredModel(name=name)
 650      _delete_registered_model()
 651      _, args = mock_model_registry_store.delete_registered_model.call_args
 652      assert args == {"name": name}
 653  
 654  
 655  def test_search_registered_models(mock_get_request_message, mock_model_registry_store):
 656      rmds = [
 657          RegisteredModel(
 658              name="model_1",
 659              creation_timestamp=111,
 660              last_updated_timestamp=222,
 661              description="Test model",
 662              latest_versions=[],
 663          ),
 664          RegisteredModel(
 665              name="model_2",
 666              creation_timestamp=111,
 667              last_updated_timestamp=333,
 668              description="Another model",
 669              latest_versions=[],
 670          ),
 671      ]
 672      mock_get_request_message.return_value = SearchRegisteredModels()
 673      mock_model_registry_store.search_registered_models.return_value = PagedList(rmds, None)
 674      resp = _search_registered_models()
 675      _, args = mock_model_registry_store.search_registered_models.call_args
 676      assert args == {
 677          "filter_string": "",
 678          "max_results": SEARCH_REGISTERED_MODEL_MAX_RESULTS_DEFAULT,
 679          "order_by": [],
 680          "page_token": None,
 681      }
 682      assert json.loads(resp.get_data()) == {"registered_models": jsonify(rmds)}
 683  
 684      mock_get_request_message.return_value = SearchRegisteredModels(filter="hello")
 685      mock_model_registry_store.search_registered_models.return_value = PagedList(rmds[:1], "tok")
 686      resp = _search_registered_models()
 687      _, args = mock_model_registry_store.search_registered_models.call_args
 688      assert args == {
 689          "filter_string": "hello",
 690          "max_results": SEARCH_REGISTERED_MODEL_MAX_RESULTS_DEFAULT,
 691          "order_by": [],
 692          "page_token": None,
 693      }
 694      assert json.loads(resp.get_data()) == {
 695          "registered_models": jsonify(rmds[:1]),
 696          "next_page_token": "tok",
 697      }
 698  
 699      mock_get_request_message.return_value = SearchRegisteredModels(filter="hi", max_results=5)
 700      mock_model_registry_store.search_registered_models.return_value = PagedList([rmds[0]], "tik")
 701      resp = _search_registered_models()
 702      _, args = mock_model_registry_store.search_registered_models.call_args
 703      assert args == {"filter_string": "hi", "max_results": 5, "order_by": [], "page_token": None}
 704      assert json.loads(resp.get_data()) == {
 705          "registered_models": jsonify([rmds[0]]),
 706          "next_page_token": "tik",
 707      }
 708  
 709      mock_get_request_message.return_value = SearchRegisteredModels(
 710          filter="hey", max_results=500, order_by=["a", "B desc"], page_token="prev"
 711      )
 712      mock_model_registry_store.search_registered_models.return_value = PagedList(rmds, "DONE")
 713      resp = _search_registered_models()
 714      _, args = mock_model_registry_store.search_registered_models.call_args
 715      assert args == {
 716          "filter_string": "hey",
 717          "max_results": 500,
 718          "order_by": ["a", "B desc"],
 719          "page_token": "prev",
 720      }
 721      assert json.loads(resp.get_data()) == {
 722          "registered_models": jsonify(rmds),
 723          "next_page_token": "DONE",
 724      }
 725  
 726  
 727  def test_get_latest_versions(mock_get_request_message, mock_model_registry_store):
 728      name = "model1"
 729      mock_get_request_message.return_value = GetLatestVersions(name=name)
 730      mvds = [
 731          ModelVersion(
 732              name=name,
 733              version="5",
 734              creation_timestamp=1,
 735              last_updated_timestamp=12,
 736              description="v 5",
 737              user_id="u1",
 738              current_stage="Production",
 739              source="A/B",
 740              run_id=uuid.uuid4().hex,
 741              status="READY",
 742              status_message=None,
 743          ),
 744          ModelVersion(
 745              name=name,
 746              version="1",
 747              creation_timestamp=1,
 748              last_updated_timestamp=1200,
 749              description="v 1",
 750              user_id="u1",
 751              current_stage="Archived",
 752              source="A/B2",
 753              run_id=uuid.uuid4().hex,
 754              status="READY",
 755              status_message=None,
 756          ),
 757          ModelVersion(
 758              name=name,
 759              version="12",
 760              creation_timestamp=100,
 761              last_updated_timestamp=None,
 762              description="v 12",
 763              user_id="u2",
 764              current_stage="Staging",
 765              source="A/B3",
 766              run_id=uuid.uuid4().hex,
 767              status="READY",
 768              status_message=None,
 769          ),
 770      ]
 771      mock_model_registry_store.get_latest_versions.return_value = mvds
 772      resp = _get_latest_versions()
 773      _, args = mock_model_registry_store.get_latest_versions.call_args
 774      assert args == {"name": name, "stages": []}
 775      assert json.loads(resp.get_data()) == {"model_versions": jsonify(mvds)}
 776  
 777      for stages in [[], ["None"], ["Staging"], ["Staging", "Production"]]:
 778          mock_get_request_message.return_value = GetLatestVersions(name=name, stages=stages)
 779          _get_latest_versions()
 780          _, args = mock_model_registry_store.get_latest_versions.call_args
 781          assert args == {"name": name, "stages": stages}
 782  
 783  
 784  def test_create_model_version(mock_get_request_message, mock_model_registry_store):
 785      run_id = uuid.uuid4().hex
 786      tags = [
 787          ModelVersionTag(key="key", value="value"),
 788          ModelVersionTag(key="anotherKey", value="some other value"),
 789      ]
 790      run_link = "localhost:5000/path/to/run"
 791      mock_get_request_message.return_value = CreateModelVersion(
 792          name="model_1",
 793          source=f"runs:/{run_id}",
 794          run_id=run_id,
 795          run_link=run_link,
 796          tags=[tag.to_proto() for tag in tags],
 797      )
 798      mv = ModelVersion(
 799          name="model_1", version="12", creation_timestamp=123, tags=tags, run_link=run_link
 800      )
 801      mock_model_registry_store.create_model_version.return_value = mv
 802      resp = _create_model_version()
 803      _, args = mock_model_registry_store.create_model_version.call_args
 804      assert args["name"] == "model_1"
 805      assert args["source"] == f"runs:/{run_id}"
 806      assert args["run_id"] == run_id
 807      assert {tag.key: tag.value for tag in args["tags"]} == {tag.key: tag.value for tag in tags}
 808      assert args["run_link"] == run_link
 809      assert json.loads(resp.get_data()) == {"model_version": jsonify(mv)}
 810  
 811  
 812  @pytest.mark.parametrize(
 813      "source",
 814      [
 815          "file:///etc/passwd",
 816          "file:///",
 817          "/etc/passwd",
 818          "file:///proc/self/environ",
 819          "file://remote-host/etc/passwd",
 820          "file://remote-host/",
 821      ],
 822  )
 823  def test_create_model_version_rejects_local_source_for_prompts(
 824      mock_get_request_message, mock_model_registry_store, source
 825  ):
 826      mock_get_request_message.return_value = CreateModelVersion(
 827          name="model_1",
 828          source=source,
 829          tags=[ModelVersionTag(key=IS_PROMPT_TAG_KEY, value="true").to_proto()],
 830      )
 831      resp = _create_model_version()
 832      assert resp.status_code == 400
 833      assert "Invalid prompt source" in resp.get_json()["message"]
 834  
 835  
 836  @pytest.mark.parametrize(
 837      "source",
 838      [
 839          "https://example.com/../../etc/passwd",
 840          "http://example.com/path/..%2f..%2fsecret",
 841      ],
 842  )
 843  def test_create_model_version_rejects_traversal_source_for_prompts(
 844      mock_get_request_message, mock_model_registry_store, source
 845  ):
 846      mock_get_request_message.return_value = CreateModelVersion(
 847          name="model_1",
 848          source=source,
 849          tags=[ModelVersionTag(key=IS_PROMPT_TAG_KEY, value="true").to_proto()],
 850      )
 851      resp = _create_model_version()
 852      assert resp.status_code == 400
 853      assert "Invalid model version source" in resp.get_json()["message"]
 854  
 855  
 856  def test_set_registered_model_tag(mock_get_request_message, mock_model_registry_store):
 857      name = "model1"
 858      tag = RegisteredModelTag(key="some weird key", value="some value")
 859      mock_get_request_message.return_value = SetRegisteredModelTag(
 860          name=name, key=tag.key, value=tag.value
 861      )
 862      _set_registered_model_tag()
 863      _, args = mock_model_registry_store.set_registered_model_tag.call_args
 864      assert args == {"name": name, "tag": tag}
 865  
 866  
 867  def test_delete_registered_model_tag(mock_get_request_message, mock_model_registry_store):
 868      name = "model1"
 869      key = "some weird key"
 870      mock_get_request_message.return_value = DeleteRegisteredModelTag(name=name, key=key)
 871      _delete_registered_model_tag()
 872      _, args = mock_model_registry_store.delete_registered_model_tag.call_args
 873      assert args == {"name": name, "key": key}
 874  
 875  
 876  def test_get_model_version_details(mock_get_request_message, mock_model_registry_store):
 877      mock_get_request_message.return_value = GetModelVersion(name="model1", version="32")
 878      mvd = ModelVersion(
 879          name="model1",
 880          version="5",
 881          creation_timestamp=1,
 882          last_updated_timestamp=12,
 883          description="v 5",
 884          user_id="u1",
 885          current_stage="Production",
 886          source="A/B",
 887          run_id=uuid.uuid4().hex,
 888          status="READY",
 889          status_message=None,
 890      )
 891      mock_model_registry_store.get_model_version.return_value = mvd
 892      resp = _get_model_version()
 893      _, args = mock_model_registry_store.get_model_version.call_args
 894      assert args == {"name": "model1", "version": "32"}
 895      assert json.loads(resp.get_data()) == {"model_version": jsonify(mvd)}
 896  
 897  
 898  def test_update_model_version(mock_get_request_message, mock_model_registry_store):
 899      name = "model1"
 900      version = "32"
 901      description = "Great model!"
 902      mock_get_request_message.return_value = UpdateModelVersion(
 903          name=name, version=version, description=description
 904      )
 905  
 906      mv = ModelVersion(name=name, version=version, creation_timestamp=123, description=description)
 907      mock_model_registry_store.update_model_version.return_value = mv
 908      _update_model_version()
 909      _, args = mock_model_registry_store.update_model_version.call_args
 910      assert args == {"name": name, "version": version, "description": description}
 911  
 912  
 913  def test_transition_model_version_stage(mock_get_request_message, mock_model_registry_store):
 914      name = "model1"
 915      version = "32"
 916      stage = "Production"
 917      mock_get_request_message.return_value = TransitionModelVersionStage(
 918          name=name, version=version, stage=stage
 919      )
 920      mv = ModelVersion(name=name, version=version, creation_timestamp=123, current_stage=stage)
 921      mock_model_registry_store.transition_model_version_stage.return_value = mv
 922      _transition_stage()
 923      _, args = mock_model_registry_store.transition_model_version_stage.call_args
 924      assert args == {
 925          "name": name,
 926          "version": version,
 927          "stage": stage,
 928          "archive_existing_versions": False,
 929      }
 930  
 931  
 932  def test_delete_model_version(mock_get_request_message, mock_model_registry_store):
 933      name = "model1"
 934      version = "32"
 935      mock_get_request_message.return_value = DeleteModelVersion(name=name, version=version)
 936      _delete_model_version()
 937      _, args = mock_model_registry_store.delete_model_version.call_args
 938      assert args == {"name": name, "version": version}
 939  
 940  
 941  def test_get_model_version_download_uri(mock_get_request_message, mock_model_registry_store):
 942      name = "model1"
 943      version = "32"
 944      mock_get_request_message.return_value = GetModelVersionDownloadUri(name=name, version=version)
 945      mock_model_registry_store.get_model_version_download_uri.return_value = "some/download/path"
 946      resp = _get_model_version_download_uri()
 947      _, args = mock_model_registry_store.get_model_version_download_uri.call_args
 948      assert args == {"name": name, "version": version}
 949      assert json.loads(resp.get_data()) == {"artifact_uri": "some/download/path"}
 950  
 951  
def test_search_model_versions(mock_get_request_message, mock_model_registry_store):
    """Exercise _search_model_versions across filter, pagination, and ordering options.

    Each scenario stubs the parsed request proto and the store's return value,
    then verifies that the handler forwards the expected search arguments to the
    store and serializes the (possibly paginated) results into the response.
    """
    # Fixture versions spanning several model names, stages, and timestamps.
    mvds = [
        ModelVersion(
            name="model_1",
            version="5",
            creation_timestamp=100,
            last_updated_timestamp=3200,
            description="v 5",
            user_id="u1",
            current_stage="Production",
            source="A/B/CD",
            run_id=uuid.uuid4().hex,
            status="READY",
            status_message=None,
        ),
        ModelVersion(
            name="model_1",
            version="12",
            creation_timestamp=110,
            last_updated_timestamp=2000,
            description="v 12",
            user_id="u2",
            current_stage="Production",
            source="A/B/CD",
            run_id=uuid.uuid4().hex,
            status="READY",
            status_message=None,
        ),
        ModelVersion(
            name="ads_model",
            version="8",
            creation_timestamp=200,
            last_updated_timestamp=1000,
            description="v 8",
            user_id="u1",
            current_stage="Staging",
            source="A/B/CD",
            run_id=uuid.uuid4().hex,
            status="READY",
            status_message=None,
        ),
        ModelVersion(
            name="fraud_detection_model",
            version="345",
            creation_timestamp=1000,
            last_updated_timestamp=999,
            description="newest version",
            user_id="u12",
            current_stage="None",
            source="A/B/CD",
            run_id=uuid.uuid4().hex,
            status="READY",
            status_message=None,
        ),
    ]
    # Scenario 1: filter only -- the handler applies the server-side default
    # max_results threshold, and next_page_token is omitted when the store
    # returns a null token.
    mock_get_request_message.return_value = SearchModelVersions(filter="source_path = 'A/B/CD'")
    mock_model_registry_store.search_model_versions.return_value = PagedList(mvds, None)
    resp = _search_model_versions()
    mock_model_registry_store.search_model_versions.assert_called_with(
        filter_string="source_path = 'A/B/CD'",
        max_results=SEARCH_MODEL_VERSION_MAX_RESULTS_THRESHOLD,
        order_by=[],
        page_token=None,
    )
    assert json.loads(resp.get_data()) == {"model_versions": jsonify(mvds)}

    # Scenario 2: a non-null token from the store must surface as next_page_token.
    mock_get_request_message.return_value = SearchModelVersions(filter="name='model_1'")
    mock_model_registry_store.search_model_versions.return_value = PagedList(mvds[:1], "tok")
    resp = _search_model_versions()
    mock_model_registry_store.search_model_versions.assert_called_with(
        filter_string="name='model_1'",
        max_results=SEARCH_MODEL_VERSION_MAX_RESULTS_THRESHOLD,
        order_by=[],
        page_token=None,
    )
    assert json.loads(resp.get_data()) == {
        "model_versions": jsonify(mvds[:1]),
        "next_page_token": "tok",
    }

    # Scenario 3: a caller-supplied max_results is forwarded verbatim.
    mock_get_request_message.return_value = SearchModelVersions(filter="version<=12", max_results=2)
    mock_model_registry_store.search_model_versions.return_value = PagedList(
        [mvds[0], mvds[2]], "next"
    )
    resp = _search_model_versions()
    mock_model_registry_store.search_model_versions.assert_called_with(
        filter_string="version<=12", max_results=2, order_by=[], page_token=None
    )
    assert json.loads(resp.get_data()) == {
        "model_versions": jsonify([mvds[0], mvds[2]]),
        "next_page_token": "next",
    }

    # Scenario 4: order_by and page_token are forwarded verbatim as well.
    mock_get_request_message.return_value = SearchModelVersions(
        filter="version<=12", max_results=2, order_by=["version DESC"], page_token="prev"
    )
    mock_model_registry_store.search_model_versions.return_value = PagedList(mvds[1:3], "next")
    resp = _search_model_versions()
    mock_model_registry_store.search_model_versions.assert_called_with(
        filter_string="version<=12", max_results=2, order_by=["version DESC"], page_token="prev"
    )
    assert json.loads(resp.get_data()) == {
        "model_versions": jsonify(mvds[1:3]),
        "next_page_token": "next",
    }
1057  
1058  
def test_set_model_version_tag(mock_get_request_message, mock_model_registry_store):
    """The handler should forward name, version, and the parsed tag to the store."""
    expected_tag = ModelVersionTag(key="some weird key", value="some value")
    mock_get_request_message.return_value = SetModelVersionTag(
        name="model1", version="1", key=expected_tag.key, value=expected_tag.value
    )
    _set_model_version_tag()
    _, kwargs = mock_model_registry_store.set_model_version_tag.call_args
    assert kwargs == {"name": "model1", "version": "1", "tag": expected_tag}
1069  
1070  
def test_delete_model_version_tag(mock_get_request_message, mock_model_registry_store):
    """The handler should forward name, version, and tag key to the store."""
    request_fields = {"name": "model1", "version": "1", "key": "some weird key"}
    mock_get_request_message.return_value = DeleteModelVersionTag(**request_fields)
    _delete_model_version_tag()
    _, kwargs = mock_model_registry_store.delete_model_version_tag.call_args
    assert kwargs == request_fields
1081  
1082  
def test_set_registered_model_alias(mock_get_request_message, mock_model_registry_store):
    """The handler should forward name, alias, and version to the store."""
    request_fields = {"name": "model1", "alias": "test_alias", "version": "1"}
    mock_get_request_message.return_value = SetRegisteredModelAlias(**request_fields)
    _set_registered_model_alias()
    _, kwargs = mock_model_registry_store.set_registered_model_alias.call_args
    assert kwargs == request_fields
1093  
1094  
def test_delete_registered_model_alias(mock_get_request_message, mock_model_registry_store):
    """The handler should forward the model name and alias to the store."""
    request_fields = {"name": "model1", "alias": "test_alias"}
    mock_get_request_message.return_value = DeleteRegisteredModelAlias(**request_fields)
    _delete_registered_model_alias()
    _, kwargs = mock_model_registry_store.delete_registered_model_alias.call_args
    assert kwargs == request_fields
1102  
1103  
def test_get_model_version_by_alias(mock_get_request_message, mock_model_registry_store):
    """The handler should resolve the alias through the store and serialize the
    resulting model version into the JSON response."""
    mock_get_request_message.return_value = GetModelVersionByAlias(
        name="model1", alias="test_alias"
    )
    aliased_version = ModelVersion(
        name="model1",
        version="5",
        creation_timestamp=1,
        last_updated_timestamp=12,
        description="v 5",
        user_id="u1",
        current_stage="Production",
        source="A/B",
        run_id=uuid.uuid4().hex,
        status="READY",
        status_message=None,
        aliases=["test_alias"],
    )
    mock_model_registry_store.get_model_version_by_alias.return_value = aliased_version
    resp = _get_model_version_by_alias()
    _, kwargs = mock_model_registry_store.get_model_version_by_alias.call_args
    assert kwargs == {"name": "model1", "alias": "test_alias"}
    assert json.loads(resp.get_data()) == {"model_version": jsonify(aliased_version)}
1127  
1128  
@pytest.mark.parametrize(
    "path",
    [
        "/path",
        "path/../to/file",
        "/etc/passwd",
        "/etc/passwd%00.jpg",
        "/file://etc/passwd",
        "%2E%2E%2F%2E%2E%2Fpath",
    ],
)
def test_delete_artifact_mlflow_artifacts_throws_for_malicious_path(enable_serve_artifacts, path):
    """Traversal-style artifact paths must be rejected with a 400 response."""
    response = _delete_artifact_mlflow_artifacts(path)
    assert response.status_code == 400
    payload = json.loads(response.get_data())
    assert payload["error_code"] == ErrorCode.Name(INVALID_PARAMETER_VALUE)
    assert payload["message"] == "Invalid path"
1146  
1147  
def test_get_presigned_download_url_success(enable_serve_artifacts):
    """A repo implementing MultipartDownloadMixin should produce a 200 response
    carrying the presigned URL, headers, and file size from the repository."""
    from mlflow.store.artifact.artifact_repo import MultipartDownloadMixin

    # Minimal stand-in repo returning a canned presigned-download response.
    class MockMultipartDownloadRepo(MultipartDownloadMixin):
        def get_download_presigned_url(self, artifact_path, expiration=300):
            return PresignedDownloadUrlResponse(
                url="https://storage.example.com/presigned?token=abc",
                headers={"x-custom-header": "value"},
                file_size=1024,
            )

    artifact_path = "run_id/artifacts/model.pkl"
    # Patch the artifact-repo factory so the handler receives the mock repo.
    with (
        app.test_request_context(method="GET"),
        mock.patch(
            "mlflow.server.handlers._get_artifact_repo_mlflow_artifacts",
            return_value=MockMultipartDownloadRepo(),
        ),
    ):
        response = _get_presigned_download_url(artifact_path)

    assert response.status_code == 200
    data = json.loads(response.get_data())
    assert data["url"] == "https://storage.example.com/presigned?token=abc"
    assert data["headers"] == {"x-custom-header": "value"}
    assert data["file_size"] == 1024
1174  
1175  
@pytest.mark.parametrize(
    "path",
    [
        "/path",
        "path/../to/file",
        "/etc/passwd",
        "/etc/passwd%00.jpg",
        "/file://etc/passwd",
        "%2E%2E%2F%2E%2E%2Fpath",
    ],
)
def test_get_presigned_download_url_throws_for_malicious_path(enable_serve_artifacts, path):
    """Traversal-style download paths must be rejected with a 400 response."""
    response = _get_presigned_download_url(path)
    assert response.status_code == 400
    payload = json.loads(response.get_data())
    assert payload["error_code"] == ErrorCode.Name(INVALID_PARAMETER_VALUE)
    assert payload["message"] == "Invalid path"
1193  
1194  
def test_get_presigned_download_url_unsupported_repo(enable_serve_artifacts, tmp_path):
    """A repository without multipart-download support should yield a 501
    NOT_IMPLEMENTED response mentioning multipart."""
    # LocalArtifactRepository does not implement MultipartDownloadMixin.
    local_repo = LocalArtifactRepository(str(tmp_path))
    with (
        app.test_request_context(method="GET"),
        mock.patch(
            "mlflow.server.handlers._get_artifact_repo_mlflow_artifacts",
            return_value=local_repo,
        ),
    ):
        response = _get_presigned_download_url("some/artifact/path")

    assert response.status_code == 501
    payload = json.loads(response.get_data())
    assert payload["error_code"] == ErrorCode.Name(NOT_IMPLEMENTED)
    assert "multipart" in payload["message"].lower()
1209  
1210  
1211  # --- Presigned upload URL handler tests ---
1212  
1213  
def test_create_presigned_upload_url_success():
    """Happy path: a repo implementing PresignedUploadMixin yields a 200 response
    containing the presigned URL and upload headers."""
    from mlflow.store.artifact.artifact_repo import PresignedUploadMixin

    # Stand-in repo returning a canned presigned-upload response.
    class MockPresignedUploadRepo(PresignedUploadMixin):
        def create_presigned_upload_url(self, artifact_path, expiration=900):
            return CreatePresignedUploadResponse(
                presigned_url="https://s3.amazonaws.com/bucket/artifacts/model.pkl?X-Amz-Signature=abc",
                headers={"Content-Type": "application/octet-stream"},
            )

    # The handler looks the run up via the tracking store; only artifact_uri is read here.
    mock_run = mock.MagicMock()
    mock_run.info.artifact_uri = "s3://bucket/0/abc123/artifacts"

    from mlflow.protos.service_pb2 import CreatePresignedUploadUrl

    request_proto = CreatePresignedUploadUrl()
    request_proto.run_id = "abc123"
    request_proto.path = "model.pkl"

    # Patch request parsing, the tracking store, and the artifact-repo factory.
    with (
        app.test_request_context(method="POST", content_type="application/json"),
        mock.patch(
            "mlflow.server.handlers._get_request_message",
            return_value=request_proto,
        ),
        mock.patch(
            "mlflow.server.handlers._get_tracking_store",
        ) as mock_store,
        mock.patch(
            "mlflow.server.handlers._get_artifact_repo",
            return_value=MockPresignedUploadRepo(),
        ),
    ):
        mock_store.return_value.get_run.return_value = mock_run
        response = _create_presigned_upload_url()

    assert response.status_code == 200
    data = json.loads(response.get_data())
    assert "presigned_url" in data
    assert "X-Amz-Signature" in data["presigned_url"]
    assert data["headers"] == {"Content-Type": "application/octet-stream"}
1255  
1256  
def test_create_presigned_upload_url_unsupported_repo():
    """A repository without presigned-upload support should produce a 501
    NOT_IMPLEMENTED response mentioning presigned upload."""
    mock_run = mock.MagicMock()
    mock_run.info.artifact_uri = "file:///tmp/artifacts"

    from mlflow.protos.service_pb2 import CreatePresignedUploadUrl

    request_proto = CreatePresignedUploadUrl()
    request_proto.run_id = "abc123"
    request_proto.path = "model.pkl"

    # LocalArtifactRepository does not implement PresignedUploadMixin.
    with (
        app.test_request_context(method="POST", content_type="application/json"),
        mock.patch(
            "mlflow.server.handlers._get_request_message",
            return_value=request_proto,
        ),
        mock.patch(
            "mlflow.server.handlers._get_tracking_store",
        ) as mock_store,
        mock.patch(
            "mlflow.server.handlers._get_artifact_repo",
            return_value=LocalArtifactRepository("/tmp/artifacts"),
        ),
    ):
        mock_store.return_value.get_run.return_value = mock_run
        response = _create_presigned_upload_url()

    assert response.status_code == 501
    json_response = json.loads(response.get_data())
    assert json_response["error_code"] == ErrorCode.Name(NOT_IMPLEMENTED)
    assert "presigned upload" in json_response["message"].lower()
1288  
1289  
@pytest.mark.parametrize(
    "artifact_uri",
    [
        "mlflow-artifacts:/0/abc123/artifacts",
        "http://mlflow-server:5000/api/2.0/mlflow-artifacts/artifacts",
        "https://mlflow-server/api/2.0/mlflow-artifacts/artifacts",
    ],
)
def test_create_presigned_upload_url_rejects_proxy_artifact_uri(artifact_uri):
    """Runs whose artifact URIs point at the server's artifact proxy must be
    rejected with 400 INVALID_PARAMETER_VALUE (message mentions "proxied")."""
    mock_run = mock.MagicMock()
    mock_run.info.artifact_uri = artifact_uri

    from mlflow.protos.service_pb2 import CreatePresignedUploadUrl

    request_proto = CreatePresignedUploadUrl()
    request_proto.run_id = "abc123"
    request_proto.path = "model.pkl"

    # Note: _get_artifact_repo is NOT patched -- the handler is expected to
    # reject the proxied URI before constructing any artifact repository.
    with (
        app.test_request_context(method="POST", content_type="application/json"),
        mock.patch(
            "mlflow.server.handlers._get_request_message",
            return_value=request_proto,
        ),
        mock.patch(
            "mlflow.server.handlers._get_tracking_store",
        ) as mock_store,
    ):
        mock_store.return_value.get_run.return_value = mock_run
        response = _create_presigned_upload_url()

    assert response.status_code == 400
    json_response = json.loads(response.get_data())
    assert json_response["error_code"] == ErrorCode.Name(INVALID_PARAMETER_VALUE)
    assert "proxied" in json_response["message"].lower()
1325  
1326  
def test_create_presigned_upload_url_invalid_run_id():
    """A run-lookup failure should surface as 404 RESOURCE_DOES_NOT_EXIST."""
    from mlflow.protos.service_pb2 import CreatePresignedUploadUrl

    request_proto = CreatePresignedUploadUrl()
    request_proto.run_id = "nonexistent_run"
    request_proto.path = "model.pkl"

    with (
        app.test_request_context(method="POST", content_type="application/json"),
        mock.patch(
            "mlflow.server.handlers._get_request_message",
            return_value=request_proto,
        ),
        mock.patch(
            "mlflow.server.handlers._get_tracking_store",
        ) as mock_store,
    ):
        # The tracking store raises when the run cannot be found.
        mock_store.return_value.get_run.side_effect = MlflowException(
            "Run 'nonexistent_run' not found",
            error_code=RESOURCE_DOES_NOT_EXIST,
        )
        response = _create_presigned_upload_url()

    assert response.status_code == 404
    json_response = json.loads(response.get_data())
    assert json_response["error_code"] == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)
1353  
1354  
@pytest.mark.parametrize(
    "path",
    [
        "../../../etc/passwd",
        "path/../to/file",
        "/etc/passwd",
        "/etc/passwd%00.jpg",
        "%2E%2E%2F%2E%2E%2Fpath",
    ],
)
def test_create_presigned_upload_url_rejects_path_traversal(path):
    """Traversal-style upload paths must be rejected with 400."""
    from mlflow.protos.service_pb2 import CreatePresignedUploadUrl

    request_proto = CreatePresignedUploadUrl()
    request_proto.run_id = "abc123"
    request_proto.path = path

    # Only request parsing is patched; presumably path validation rejects the
    # request before the tracking store is consulted -- confirm in the handler.
    with (
        app.test_request_context(method="POST", content_type="application/json"),
        mock.patch(
            "mlflow.server.handlers._get_request_message",
            return_value=request_proto,
        ),
    ):
        response = _create_presigned_upload_url()

    assert response.status_code == 400
    json_response = json.loads(response.get_data())
    assert json_response["error_code"] == ErrorCode.Name(INVALID_PARAMETER_VALUE)
1384  
1385  
def test_create_presigned_upload_url_with_custom_expiration():
    """An expiration supplied in the request should reach the artifact repo."""
    from mlflow.store.artifact.artifact_repo import PresignedUploadMixin

    # Records the expiration value the repo method is invoked with.
    captured_expiration = {}

    class MockPresignedUploadRepo(PresignedUploadMixin):
        def create_presigned_upload_url(self, artifact_path, expiration=900):
            captured_expiration["value"] = expiration
            return CreatePresignedUploadResponse(
                presigned_url="https://example.com/presigned",
                headers={},
            )

    mock_run = mock.MagicMock()
    mock_run.info.artifact_uri = "s3://bucket/0/abc123/artifacts"

    from mlflow.protos.service_pb2 import CreatePresignedUploadUrl

    request_proto = CreatePresignedUploadUrl()
    request_proto.run_id = "abc123"
    request_proto.path = "model.pkl"
    request_proto.expiration = 60

    with (
        app.test_request_context(method="POST", content_type="application/json"),
        mock.patch(
            "mlflow.server.handlers._get_request_message",
            return_value=request_proto,
        ),
        mock.patch(
            "mlflow.server.handlers._get_tracking_store",
        ) as mock_store,
        mock.patch(
            "mlflow.server.handlers._get_artifact_repo",
            return_value=MockPresignedUploadRepo(),
        ),
    ):
        mock_store.return_value.get_run.return_value = mock_run
        response = _create_presigned_upload_url()

    assert response.status_code == 200
    # The 60-second expiration from the request must be forwarded to the repo.
    assert captured_expiration["value"] == 60
1428  
1429  
def test_create_presigned_upload_url_default_expiration():
    """Without an expiration in the request, the repo should see 900 seconds
    (whether passed explicitly by the handler or via the method default)."""
    from mlflow.store.artifact.artifact_repo import PresignedUploadMixin

    # Records the expiration value the repo method is invoked with.
    captured_expiration = {}

    class MockPresignedUploadRepo(PresignedUploadMixin):
        def create_presigned_upload_url(self, artifact_path, expiration=900):
            captured_expiration["value"] = expiration
            return CreatePresignedUploadResponse(
                presigned_url="https://example.com/presigned",
                headers={},
            )

    mock_run = mock.MagicMock()
    mock_run.info.artifact_uri = "s3://bucket/0/abc123/artifacts"

    from mlflow.protos.service_pb2 import CreatePresignedUploadUrl

    # Don't set expiration - should default to 900
    request_proto = CreatePresignedUploadUrl()
    request_proto.run_id = "abc123"
    request_proto.path = "model.pkl"

    with (
        app.test_request_context(method="POST", content_type="application/json"),
        mock.patch(
            "mlflow.server.handlers._get_request_message",
            return_value=request_proto,
        ),
        mock.patch(
            "mlflow.server.handlers._get_tracking_store",
        ) as mock_store,
        mock.patch(
            "mlflow.server.handlers._get_artifact_repo",
            return_value=MockPresignedUploadRepo(),
        ),
    ):
        mock_store.return_value.get_run.return_value = mock_run
        response = _create_presigned_upload_url()

    assert response.status_code == 200
    assert captured_expiration["value"] == 900
1472  
1473  
def test_create_presigned_upload_url_blocked_in_artifacts_only_mode(monkeypatch):
    """In artifacts-only mode the handler must refuse with a 503 response."""
    from mlflow.server import ARTIFACTS_ONLY_ENV_VAR

    monkeypatch.setenv(ARTIFACTS_ONLY_ENV_VAR, "true")
    with app.test_request_context(method="POST", content_type="application/json"):
        response = _create_presigned_upload_url()

    assert response.status_code == 503
    body = response.get_data(as_text=True)
    assert "artifacts-only" in body.lower()
1484  
1485  
@pytest.mark.parametrize(
    "uri",
    [
        "http://host#/abc/etc/",
        "http://host/;..%2F..%2Fetc",
    ],
)
def test_local_file_read_write_by_pass_vulnerability(uri):
    """Regression tests for local-file read/write bypass vectors.

    Part 1: experiment creation must reject artifact locations whose URL
    carries fragments or params. Part 2: _validate_source_run must refuse a
    local source path when the run's artifact URI is not a local path.
    """
    # Build a fake Flask request carrying the malicious artifact_location.
    request = mock.MagicMock()
    request.method = "POST"
    request.content_type = "application/json; charset=utf-8"
    request.get_json = mock.MagicMock()
    request.get_json.return_value = {
        "name": "hello",
        "artifact_location": uri,
    }
    msg = _get_request_message(CreateExperiment(), flask_request=request)
    with mock.patch("mlflow.server.handlers._get_request_message", return_value=msg):
        response = _create_experiment()
        json_response = json.loads(response.get_data())
        assert (
            json_response["message"] == "'artifact_location' URL can't include fragments or params."
        )

    # Test if source is a local filesystem path, `_validate_source` validates that the run
    # artifact_uri is also a local filesystem path.
    run_id = uuid.uuid4().hex
    with mock.patch("mlflow.server.handlers._get_tracking_store") as mock_get_tracking_store:
        # The stubbed run reports a non-local (HTTP) artifact URI.
        mock_get_tracking_store().get_run(
            run_id
        ).info.artifact_uri = f"http://host/{run_id}/artifacts/abc"

        with pytest.raises(
            MlflowException,
            match=(
                "the run_id request parameter has to be specified and the local "
                "path has to be contained within the artifact directory of the "
                "run specified by the run_id"
            ),
        ):
            _validate_source_run("/local/path/xyz", run_id)
1528  
@pytest.mark.parametrize(
    ("location", "expected_class", "expected_uri"),
    [
        ("file:///0/traces/123", LocalArtifactRepository, "file:///0/traces/123"),
        ("s3://bucket/0/traces/123", S3ArtifactRepository, "s3://bucket/0/traces/123"),
        (
            "wasbs://container@account.blob.core.windows.net/bucket/1/traces/123",
            AzureBlobArtifactRepository,
            "wasbs://container@account.blob.core.windows.net/bucket/1/traces/123",
        ),
        # Proxy URI must be resolved to the actual storage URI
        (
            "https://127.0.0.1/api/2.0/mlflow-artifacts/artifacts/2/traces/123",
            S3ArtifactRepository,
            "s3://bucket/2/traces/123",
        ),
        ("mlflow-artifacts:/1/traces/123", S3ArtifactRepository, "s3://bucket/1/traces/123"),
    ],
)
def test_get_trace_artifact_repo(location, expected_class, expected_uri, monkeypatch):
    """_get_trace_artifact_repo should build the right repository class and URI
    from the trace's MLFLOW_ARTIFACT_LOCATION tag, resolving proxied locations
    against the configured artifacts destination."""
    # Enable artifact serving with s3://bucket as the proxied destination.
    monkeypatch.setenv(SERVE_ARTIFACTS_ENV_VAR, "true")
    monkeypatch.setenv(ARTIFACTS_DESTINATION_ENV_VAR, "s3://bucket")
    trace_info = TraceInfo(
        trace_id="123",
        trace_location=EntityTraceLocation.from_experiment_id("0"),
        request_time=0,
        execution_duration=1,
        state=TraceState.OK,
        tags={MLFLOW_ARTIFACT_LOCATION: location},
    )
    repo = _get_trace_artifact_repo(trace_info)
    assert isinstance(repo, expected_class)
    assert repo.artifact_uri == expected_uri
1562  
1563  
1564  ### Prompt Registry Tests ###
def test_create_prompt_as_registered_model(mock_get_request_message, mock_model_registry_store):
    """A registered model tagged with IS_PROMPT_TAG_KEY is created like any other model."""
    prompt_tags = [RegisteredModelTag(key=IS_PROMPT_TAG_KEY, value="true")]
    mock_get_request_message.return_value = CreateRegisteredModel(
        name="model_1", tags=[t.to_proto() for t in prompt_tags]
    )
    registered = RegisteredModel("model_1", tags=prompt_tags)
    mock_model_registry_store.create_registered_model.return_value = registered

    response = _create_registered_model()

    _, kwargs = mock_model_registry_store.create_registered_model.call_args
    assert kwargs["name"] == "model_1"
    expected_tags = {t.key: t.value for t in prompt_tags}
    assert {t.key: t.value for t in kwargs["tags"]} == expected_tags
    assert json.loads(response.get_data()) == {"registered_model": jsonify(registered)}
1577  
1578  
def test_create_prompt_as_model_version(mock_get_request_message, mock_model_registry_store):
    """A model version carrying prompt tags is created with empty source/run fields."""
    prompt_tags = [
        ModelVersionTag(key=IS_PROMPT_TAG_KEY, value="true"),
        ModelVersionTag(key=PROMPT_TEXT_TAG_KEY, value="some prompt text"),
    ]
    mock_get_request_message.return_value = CreateModelVersion(
        name="model_1",
        tags=[t.to_proto() for t in prompt_tags],
        source=None,
        run_id=None,
        run_link=None,
    )
    version = ModelVersion(
        name="prompt_1", version="12", creation_timestamp=123, tags=prompt_tags, run_link=None
    )
    mock_model_registry_store.create_model_version.return_value = version

    response = _create_model_version()

    _, kwargs = mock_model_registry_store.create_model_version.call_args
    assert kwargs["name"] == "model_1"
    # Unset proto string fields arrive as empty strings, not None.
    assert kwargs["source"] == ""
    assert kwargs["run_id"] == ""
    assert kwargs["run_link"] == ""
    expected_tags = {t.key: t.value for t in prompt_tags}
    assert {t.key: t.value for t in kwargs["tags"]} == expected_tags
    assert json.loads(response.get_data()) == {"model_version": jsonify(version)}
1603  
1604  
def test_create_evaluation_dataset(mock_tracking_store, mock_evaluation_dataset):
    """POSTed dataset fields are forwarded (with tags JSON-decoded) to the store."""
    mock_tracking_store.create_dataset.return_value = mock_evaluation_dataset

    payload = {
        "name": "test_dataset",
        "experiment_ids": ["0", "1"],
        "tags": json.dumps({"env": "test"}),
    }
    with app.test_request_context(method="POST", json=payload):
        _create_dataset_handler()

    mock_tracking_store.create_dataset.assert_called_once_with(
        name="test_dataset",
        experiment_ids=["0", "1"],
        tags={"env": "test"},
    )
1623  
1624  
def test_get_evaluation_dataset(mock_tracking_store, mock_evaluation_dataset):
    """A GET for a dataset id is forwarded verbatim to the store."""
    mock_tracking_store.get_dataset.return_value = mock_evaluation_dataset
    dataset_id = "d-1234567890abcdef1234567890abcdef"

    with app.test_request_context(method="GET"):
        _get_dataset_handler(dataset_id)

    mock_tracking_store.get_dataset.assert_called_once_with(dataset_id)
1633  
1634  
def test_delete_evaluation_dataset(mock_tracking_store):
    """A DELETE for a dataset id is forwarded verbatim to the store."""
    dataset_id = "d-1234567890abcdef1234567890abcdef"

    with app.test_request_context(method="DELETE"):
        _delete_dataset_handler(dataset_id)

    mock_tracking_store.delete_dataset.assert_called_once_with(dataset_id)
1641  
1642  
def test_search_datasets(mock_tracking_store):
    """Search parameters from the request body are passed through to the store."""
    from mlflow.protos.datasets_pb2 import Dataset as ProtoDataset

    mock_datasets = []
    for idx in range(2):
        mock_ds = mock.MagicMock()
        mock_ds.name = f"dataset_{idx}"
        proto = ProtoDataset()
        proto.dataset_id = f"d-{idx:032d}"
        proto.name = mock_ds.name
        mock_ds.to_proto.return_value = proto
        mock_datasets.append(mock_ds)

    mock_tracking_store.search_datasets.return_value = PagedList(mock_datasets, "next_token")

    request_body = {
        "experiment_ids": ["0", "1"],
        "filter_string": "name = 'dataset_1'",
        "max_results": 10,
        "order_by": ["name DESC"],
        "page_token": "token123",
    }
    with app.test_request_context(method="POST", json=request_body):
        _search_evaluation_datasets_handler()

    mock_tracking_store.search_datasets.assert_called_once_with(
        experiment_ids=["0", "1"],
        filter_string="name = 'dataset_1'",
        max_results=10,
        order_by=["name DESC"],
        page_token="token123",
    )
1678  
1679  
def test_set_dataset_tags(mock_tracking_store):
    """Tags posted as a JSON string are deserialized before reaching the store."""
    dataset_id = "d-1234567890abcdef1234567890abcdef"
    body = {"tags": json.dumps({"env": "production", "version": "2.0"})}

    with app.test_request_context(method="POST", json=body):
        _set_dataset_tags_handler(dataset_id)

    mock_tracking_store.set_dataset_tags.assert_called_once_with(
        dataset_id=dataset_id,
        tags={"env": "production", "version": "2.0"},
    )
1694  
1695  
def test_delete_dataset_tag(mock_tracking_store):
    """Deleting a tag passes both the dataset id and the tag key to the store."""
    dataset_id = "d-1234567890abcdef1234567890abcdef"
    tag_key = "deprecated_tag"

    with app.test_request_context(method="DELETE"):
        _delete_dataset_tag_handler(dataset_id, tag_key)

    mock_tracking_store.delete_dataset_tag.assert_called_once_with(
        dataset_id=dataset_id,
        key=tag_key,
    )
1706  
1707  
def test_upsert_dataset_records(mock_tracking_store):
    """Upserted records are JSON-decoded and the insert/update counts are echoed back."""
    mock_tracking_store.upsert_dataset_records.return_value = {"inserted": 2, "updated": 0}

    dataset_id = "d-1234567890abcdef1234567890abcdef"
    record_payload = [
        {"inputs": {"q": "test1"}, "expectations": {"score": 0.9}},
        {"inputs": {"q": "test2"}, "expectations": {"score": 0.8}},
    ]

    with app.test_request_context(method="POST", json={"records": json.dumps(record_payload)}):
        resp = _upsert_dataset_records_handler(dataset_id)

    mock_tracking_store.upsert_dataset_records.assert_called_once_with(
        dataset_id=dataset_id,
        records=record_payload,
    )

    body = json.loads(resp.get_data())
    assert body["inserted_count"] == 2
    assert body["updated_count"] == 0
1736  
1737  
def test_get_dataset_experiment_ids(mock_tracking_store):
    """The handler returns the store's experiment ids for a dataset unchanged."""
    mock_tracking_store.get_dataset_experiment_ids.return_value = ["exp1", "exp2", "exp3"]
    dataset_id = "d-1234567890abcdef1234567890abcdef"

    with app.test_request_context(method="GET"):
        resp = _get_dataset_experiment_ids_handler(dataset_id)

    mock_tracking_store.get_dataset_experiment_ids.assert_called_once_with(dataset_id=dataset_id)
    assert json.loads(resp.get_data())["experiment_ids"] == ["exp1", "exp2", "exp3"]
1753  
1754  
def test_get_dataset_records(mock_tracking_store):
    """Page through dataset records via the handler, with and without page tokens."""

    def make_record(i):
        # Build a mock record whose to_dict() mirrors its attributes.
        rec = mock.MagicMock()
        rec.dataset_id = "d-1234567890abcdef1234567890abcdef"
        rec.dataset_record_id = f"r-00{i}"
        rec.inputs = {"question": f"test{i}"}
        rec.expectations = {"score": 0.9 - i * 0.1}
        rec.tags = {}
        rec.created_time = 1234567890 + i
        rec.last_update_time = 1234567890 + i
        rec.to_dict.return_value = {
            "dataset_id": rec.dataset_id,
            "dataset_record_id": rec.dataset_record_id,
            "inputs": rec.inputs,
            "expectations": rec.expectations,
            "tags": rec.tags,
            "created_time": rec.created_time,
            "last_update_time": rec.last_update_time,
        }
        return rec

    records = [make_record(i) for i in range(3)]
    dataset_id = "d-1234567890abcdef1234567890abcdef"

    # Default request: no paging arguments -> max_results falls back to 1000.
    mock_tracking_store._load_dataset_records.return_value = (records, None)
    with app.test_request_context(method="GET"):
        resp = _get_dataset_records_handler(dataset_id)

    mock_tracking_store._load_dataset_records.assert_called_with(
        dataset_id, max_results=1000, page_token=None
    )
    body = json.loads(resp.get_data())
    page = json.loads(body["records"])
    assert len(page) == 3
    assert page[0]["dataset_record_id"] == "r-000"

    # First page of two records, with a continuation token.
    mock_tracking_store._load_dataset_records.return_value = (records[:2], "token_page2")
    with app.test_request_context(
        method="GET",
        json={"max_results": 2, "page_token": None},
    ):
        resp = _get_dataset_records_handler(dataset_id)

    mock_tracking_store._load_dataset_records.assert_called_with(
        dataset_id, max_results=2, page_token=None
    )
    body = json.loads(resp.get_data())
    page = json.loads(body["records"])
    assert len(page) == 2
    assert body["next_page_token"] == "token_page2"

    # Final page: one record, no further token.
    mock_tracking_store._load_dataset_records.return_value = (records[2:], None)
    with app.test_request_context(
        method="GET",
        json={"max_results": 2, "page_token": "token_page2"},
    ):
        resp = _get_dataset_records_handler(dataset_id)

    mock_tracking_store._load_dataset_records.assert_called_with(
        dataset_id, max_results=2, page_token="token_page2"
    )
    body = json.loads(resp.get_data())
    page = json.loads(body["records"])
    assert len(page) == 1
    assert "next_page_token" not in body or body["next_page_token"] == ""
1831  
1832  
def test_get_dataset_records_empty(mock_tracking_store):
    """An empty record set serializes to an empty list with no continuation token."""
    mock_tracking_store._load_dataset_records.return_value = ([], None)

    with app.test_request_context(method="GET"):
        resp = _get_dataset_records_handler("d-1234567890abcdef1234567890abcdef")

    body = json.loads(resp.get_data())
    assert json.loads(body["records"]) == []
    assert "next_page_token" not in body or body["next_page_token"] == ""
1844  
1845  
def test_get_dataset_records_pagination(mock_tracking_store):
    """Walk three pages of 20/20/10 records and verify tokens and page boundaries."""
    dataset_id = "d-1234567890abcdef1234567890abcdef"

    def build_record(i):
        # Mock record whose to_dict() mirrors its attributes.
        rec = mock.Mock()
        rec.dataset_record_id = f"r-{i:03d}"
        rec.inputs = {"q": f"Question {i}"}
        rec.expectations = {"a": f"Answer {i}"}
        rec.tags = {}
        rec.source_type = "TRACE"
        rec.source_id = f"trace-{i}"
        rec.created_time = 1609459200 + i
        rec.to_dict.return_value = {
            "dataset_record_id": f"r-{i:03d}",
            "inputs": {"q": f"Question {i}"},
            "expectations": {"a": f"Answer {i}"},
            "tags": {},
            "source_type": "TRACE",
            "source_id": f"trace-{i}",
            "created_time": 1609459200 + i,
        }
        return rec

    all_records = [build_record(i) for i in range(50)]

    # Page 1: records 0-19, token for the next page.
    mock_tracking_store._load_dataset_records.return_value = (all_records[:20], "token_20")
    with app.test_request_context(method="GET", json={"max_results": 20}):
        resp = _get_dataset_records_handler(dataset_id)

    mock_tracking_store._load_dataset_records.assert_called_with(
        dataset_id, max_results=20, page_token=None
    )
    body = json.loads(resp.get_data())
    page = json.loads(body["records"])
    assert len(page) == 20
    assert body["next_page_token"] == "token_20"
    assert page[0]["dataset_record_id"] == "r-000"
    assert page[19]["dataset_record_id"] == "r-019"

    # Page 2: records 20-39.
    mock_tracking_store._load_dataset_records.return_value = (all_records[20:40], "token_40")
    with app.test_request_context(
        method="GET",
        json={"max_results": 20, "page_token": "token_20"},
    ):
        resp = _get_dataset_records_handler(dataset_id)

    mock_tracking_store._load_dataset_records.assert_called_with(
        dataset_id, max_results=20, page_token="token_20"
    )
    body = json.loads(resp.get_data())
    page = json.loads(body["records"])
    assert len(page) == 20
    assert body["next_page_token"] == "token_40"
    assert page[0]["dataset_record_id"] == "r-020"

    # Final page: records 40-49, no further token.
    mock_tracking_store._load_dataset_records.return_value = (all_records[40:], None)
    with app.test_request_context(
        method="GET",
        json={"max_results": 20, "page_token": "token_40"},
    ):
        resp = _get_dataset_records_handler(dataset_id)

    body = json.loads(resp.get_data())
    page = json.loads(body["records"])
    assert len(page) == 10
    assert "next_page_token" not in body or body["next_page_token"] == ""
    assert page[0]["dataset_record_id"] == "r-040"
    assert page[9]["dataset_record_id"] == "r-049"
1917  
1918  
def test_register_scorer(mock_get_request_message, mock_tracking_store):
    """Registering a scorer forwards positional args and serializes the new version."""
    experiment_id = "123"
    scorer_name = "accuracy_scorer"
    payload = '{"name": "accuracy_scorer"}'

    mock_get_request_message.return_value = RegisterScorer(
        experiment_id=experiment_id, name=scorer_name, serialized_scorer=payload
    )
    mock_tracking_store.register_scorer.return_value = ScorerVersion(
        experiment_id=experiment_id,
        scorer_name=scorer_name,
        scorer_version=1,
        serialized_scorer=payload,
        creation_time=1234567890,
        scorer_id="test-scorer-id",
    )

    resp = _register_scorer()

    mock_tracking_store.register_scorer.assert_called_once_with(
        experiment_id, scorer_name, payload
    )
    assert json.loads(resp.get_data()) == {
        "version": 1,
        "scorer_id": "test-scorer-id",
        "experiment_id": experiment_id,
        "name": scorer_name,
        "serialized_scorer": payload,
        "creation_time": 1234567890,
    }
1953  
1954  
def test_register_scorer_rejects_decorator_scorer(mock_get_request_message, mock_tracking_store):
    """Scorers serialized from decorator source code are rejected with a 400."""
    from mlflow.genai.scorers.scorer_utils import DECORATOR_SCORER_REGISTRATION_NOT_SUPPORTED_ERROR

    payload = json.dumps({"name": "my_scorer", "call_source": "    return 1.0\n"})
    mock_get_request_message.return_value = RegisterScorer(
        experiment_id="123", name="my_scorer", serialized_scorer=payload
    )

    resp = _register_scorer()

    assert resp.status_code == 400
    assert DECORATOR_SCORER_REGISTRATION_NOT_SUPPORTED_ERROR in resp.get_json()["message"]
    # The store must never be reached for a rejected scorer.
    mock_tracking_store.register_scorer.assert_not_called()
1966  
1967  
def test_list_scorers(mock_get_request_message, mock_tracking_store):
    """Listing scorers returns every registered scorer for the experiment."""
    experiment_id = "123"
    mock_get_request_message.return_value = ListScorers(experiment_id=experiment_id)

    mock_tracking_store.list_scorers.return_value = [
        ScorerVersion(
            experiment_id=123,
            scorer_name="accuracy_scorer",
            scorer_version=1,
            serialized_scorer="serialized_accuracy_scorer",
            creation_time=12345,
        ),
        ScorerVersion(
            experiment_id=123,
            scorer_name="safety_scorer",
            scorer_version=2,
            serialized_scorer="serialized_safety_scorer",
            creation_time=12345,
        ),
    ]

    resp = _list_scorers()

    mock_tracking_store.list_scorers.assert_called_once_with(experiment_id)

    scorers = json.loads(resp.get_data())["scorers"]
    assert len(scorers) == 2
    first, second = scorers
    assert first["scorer_name"] == "accuracy_scorer"
    assert first["scorer_version"] == 1
    assert first["serialized_scorer"] == "serialized_accuracy_scorer"
    assert second["scorer_name"] == "safety_scorer"
    assert second["scorer_version"] == 2
    assert second["serialized_scorer"] == "serialized_safety_scorer"
2007  
2008  
def test_list_scorer_versions(mock_get_request_message, mock_tracking_store):
    """Listing versions of one scorer returns each stored version in order."""
    experiment_id = "123"
    scorer_name = "accuracy_scorer"

    mock_get_request_message.return_value = ListScorerVersions(
        experiment_id=experiment_id, name=scorer_name
    )

    mock_tracking_store.list_scorer_versions.return_value = [
        ScorerVersion(
            experiment_id=123,
            scorer_name="accuracy_scorer",
            scorer_version=1,
            serialized_scorer="serialized_accuracy_scorer_v1",
            creation_time=12345,
        ),
        ScorerVersion(
            experiment_id=123,
            scorer_name="accuracy_scorer",
            scorer_version=2,
            serialized_scorer="serialized_accuracy_scorer_v2",
            creation_time=12345,
        ),
    ]

    resp = _list_scorer_versions()

    mock_tracking_store.list_scorer_versions.assert_called_once_with(experiment_id, scorer_name)

    scorers = json.loads(resp.get_data())["scorers"]
    assert len(scorers) == 2
    first, second = scorers
    assert first["scorer_version"] == 1
    assert first["serialized_scorer"] == "serialized_accuracy_scorer_v1"
    assert second["scorer_version"] == 2
    assert second["serialized_scorer"] == "serialized_accuracy_scorer_v2"
2049  
2050  
def test_get_scorer_with_version(mock_get_request_message, mock_tracking_store):
    """Fetching a specific scorer version passes the version positionally."""
    experiment_id = "123"
    scorer_name = "accuracy_scorer"
    version = 2

    mock_get_request_message.return_value = GetScorer(
        experiment_id=experiment_id, name=scorer_name, version=version
    )
    mock_tracking_store.get_scorer.return_value = ScorerVersion(
        experiment_id=123,
        scorer_name="accuracy_scorer",
        scorer_version=2,
        serialized_scorer="serialized_accuracy_scorer_v2",
        creation_time=1640995200000,
    )

    resp = _get_scorer()

    mock_tracking_store.get_scorer.assert_called_once_with(experiment_id, scorer_name, version)

    scorer = json.loads(resp.get_data())["scorer"]
    assert scorer["experiment_id"] == 123
    assert scorer["scorer_name"] == "accuracy_scorer"
    assert scorer["scorer_version"] == 2
    assert scorer["serialized_scorer"] == "serialized_accuracy_scorer_v2"
    assert scorer["creation_time"] == 1640995200000
2082  
2083  
def test_get_scorer_without_version(mock_get_request_message, mock_tracking_store):
    """Omitting the version fetches the latest; the store sees version=None."""
    experiment_id = "123"
    scorer_name = "accuracy_scorer"

    mock_get_request_message.return_value = GetScorer(experiment_id=experiment_id, name=scorer_name)
    mock_tracking_store.get_scorer.return_value = ScorerVersion(
        experiment_id=123,
        scorer_name="accuracy_scorer",
        scorer_version=3,
        serialized_scorer="serialized_accuracy_scorer_latest",
        creation_time=1640995200000,
    )

    resp = _get_scorer()

    mock_tracking_store.get_scorer.assert_called_once_with(experiment_id, scorer_name, None)

    scorer = json.loads(resp.get_data())["scorer"]
    assert scorer["experiment_id"] == 123
    assert scorer["scorer_name"] == "accuracy_scorer"
    assert scorer["scorer_version"] == 3
    assert scorer["serialized_scorer"] == "serialized_accuracy_scorer_latest"
    assert scorer["creation_time"] == 1640995200000
2112  
2113  
def test_delete_scorer_with_version(mock_get_request_message, mock_tracking_store):
    """Deleting a specific version forwards it positionally and returns an empty body."""
    experiment_id = "123"
    scorer_name = "accuracy_scorer"
    version = 2

    mock_get_request_message.return_value = DeleteScorer(
        experiment_id=experiment_id, name=scorer_name, version=version
    )

    resp = _delete_scorer()

    mock_tracking_store.delete_scorer.assert_called_once_with(experiment_id, scorer_name, version)
    # Delete responses carry no payload.
    assert json.loads(resp.get_data()) == {}
2131  
2132  
def test_delete_scorer_without_version(mock_get_request_message, mock_tracking_store):
    """Deleting without a version passes version=None to the store."""
    experiment_id = "123"
    scorer_name = "accuracy_scorer"

    mock_get_request_message.return_value = DeleteScorer(
        experiment_id=experiment_id, name=scorer_name
    )

    resp = _delete_scorer()

    mock_tracking_store.delete_scorer.assert_called_once_with(experiment_id, scorer_name, None)
    # Delete responses carry no payload.
    assert json.loads(resp.get_data()) == {}
2147  
2148  
def test_get_online_scoring_configs_batch(mock_tracking_store):
    """Batch-fetch online configs via repeated scorer_ids query parameters."""
    mock_tracking_store.get_online_scoring_configs.return_value = [
        OnlineScoringConfig(
            online_scoring_config_id="cfg-1",
            scorer_id="scorer-1",
            sample_rate=0.5,
            filter_string="status = 'OK'",
            experiment_id="exp1",
        ),
        OnlineScoringConfig(
            online_scoring_config_id="cfg-2",
            scorer_id="scorer-2",
            sample_rate=0.8,
            experiment_id="exp1",
        ),
    ]

    with app.test_client() as client:
        resp = client.get(
            "/ajax-api/3.0/mlflow/scorers/online-configs",
            query_string=[("scorer_ids", "scorer-1"), ("scorer_ids", "scorer-2")],
        )
        assert resp.status_code == 200
        data = resp.get_json()
        assert "configs" in data
        assert isinstance(data["configs"], list)
        assert len(data["configs"]) == 2
        by_scorer = {cfg["scorer_id"]: cfg for cfg in data["configs"]}
        assert by_scorer["scorer-1"]["sample_rate"] == 0.5
        assert by_scorer["scorer-1"]["filter_string"] == "status = 'OK'"
        assert by_scorer["scorer-2"]["sample_rate"] == 0.8
        # filter_string is optional and absent for cfg-2.
        assert by_scorer["scorer-2"].get("filter_string") is None

    mock_tracking_store.get_online_scoring_configs.assert_called_once_with(["scorer-1", "scorer-2"])
2184  
2185  
def test_get_online_scoring_configs_missing_param(mock_tracking_store):
    """Omitting scorer_ids yields a 400 whose message names the missing parameter."""
    with app.test_client() as client:
        resp = client.get("/ajax-api/3.0/mlflow/scorers/online-configs")
        assert resp.status_code == 400
        assert "scorer_ids" in resp.get_json()["message"]
2194  
2195  
def test_calculate_trace_filter_correlation(mock_get_request_message, mock_tracking_store):
    """NPMI correlation results are serialized field-by-field into the response."""
    exp_ids = ["123", "456"]
    first_filter = "span.type = 'LLM'"
    second_filter = "feedback.quality > 0.8"
    base = "request_time > 1000"

    mock_get_request_message.return_value = CalculateTraceFilterCorrelation(
        experiment_ids=exp_ids,
        filter_string1=first_filter,
        filter_string2=second_filter,
        base_filter=base,
    )
    mock_tracking_store.calculate_trace_filter_correlation.return_value = (
        TraceFilterCorrelationResult(
            npmi=0.456,
            npmi_smoothed=0.445,
            filter1_count=100,
            filter2_count=80,
            joint_count=50,
            total_count=200,
        )
    )

    resp = _calculate_trace_filter_correlation()

    mock_tracking_store.calculate_trace_filter_correlation.assert_called_once_with(
        experiment_ids=exp_ids,
        filter_string1=first_filter,
        filter_string2=second_filter,
        base_filter=base,
    )

    payload = json.loads(resp.get_data())
    expected = {
        "npmi": 0.456,
        "npmi_smoothed": 0.445,
        "filter1_count": 100,
        "filter2_count": 80,
        "joint_count": 50,
        "total_count": 200,
    }
    for field, value in expected.items():
        assert payload[field] == value
2236  
2237  
def test_calculate_trace_filter_correlation_without_base_filter(
    mock_get_request_message, mock_tracking_store
):
    """When no base filter is supplied, the store receives base_filter=None."""
    exp_ids = ["123"]
    first_filter = "span.type = 'LLM'"
    second_filter = "feedback.quality > 0.8"

    mock_get_request_message.return_value = CalculateTraceFilterCorrelation(
        experiment_ids=exp_ids,
        filter_string1=first_filter,
        filter_string2=second_filter,
    )
    mock_tracking_store.calculate_trace_filter_correlation.return_value = (
        TraceFilterCorrelationResult(
            npmi=0.789,
            npmi_smoothed=0.775,
            filter1_count=50,
            filter2_count=40,
            joint_count=30,
            total_count=100,
        )
    )

    resp = _calculate_trace_filter_correlation()

    mock_tracking_store.calculate_trace_filter_correlation.assert_called_once_with(
        experiment_ids=exp_ids,
        filter_string1=first_filter,
        filter_string2=second_filter,
        base_filter=None,
    )

    payload = json.loads(resp.get_data())
    expected = {
        "npmi": 0.789,
        "npmi_smoothed": 0.775,
        "filter1_count": 50,
        "filter2_count": 40,
        "joint_count": 30,
        "total_count": 100,
    }
    for field, value in expected.items():
        assert payload[field] == value
2278  
2279  
def test_calculate_trace_filter_correlation_with_nan_npmi(
    mock_get_request_message, mock_tracking_store
):
    """NaN/None NPMI values are dropped from the JSON response entirely."""
    exp_ids = ["123"]
    first_filter = "span.type = 'LLM'"
    second_filter = "feedback.quality > 0.8"

    mock_get_request_message.return_value = CalculateTraceFilterCorrelation(
        experiment_ids=exp_ids,
        filter_string1=first_filter,
        filter_string2=second_filter,
    )
    mock_tracking_store.calculate_trace_filter_correlation.return_value = (
        TraceFilterCorrelationResult(
            npmi=float("nan"),
            npmi_smoothed=None,
            filter1_count=0,
            filter2_count=0,
            joint_count=0,
            total_count=100,
        )
    )

    resp = _calculate_trace_filter_correlation()

    mock_tracking_store.calculate_trace_filter_correlation.assert_called_once_with(
        experiment_ids=exp_ids,
        filter_string1=first_filter,
        filter_string2=second_filter,
        base_filter=None,
    )

    payload = json.loads(resp.get_data())
    # NaN is not valid JSON, so the npmi fields must be omitted rather than serialized.
    assert "npmi" not in payload
    assert "npmi_smoothed" not in payload
    assert payload["filter1_count"] == 0
    assert payload["filter2_count"] == 0
    assert payload["joint_count"] == 0
    assert payload["total_count"] == 100
2320  
2321  
def test_databricks_tracking_store_registration():
    """Resolving the "databricks" scheme yields a DatabricksTracingRestStore whose
    credential factory is a partial over get_databricks_host_creds bound to that URI."""
    registry = TrackingStoreRegistryWrapper()
    store = registry.get_store("databricks", artifact_uri=None)

    assert isinstance(store, DatabricksTracingRestStore)
    assert hasattr(store, "get_host_creds")
    # get_host_creds is expected to be a functools.partial carrying the scheme.
    assert store.get_host_creds.func.__name__ == "get_databricks_host_creds"
    assert store.get_host_creds.args == ("databricks",)
2334  
2335  
def test_databricks_model_registry_store_registration():
    """Registry resolution for the "databricks" and "databricks-uc" URIs.

    Both stores authenticate via get_databricks_host_creds (the UC store does
    NOT use get_databricks_uc_host_creds), each bound to its own URI.
    """
    registry = ModelRegistryStoreRegistryWrapper()

    expectations = [
        ("databricks", ModelRegistryRestStore),
        ("databricks-uc", UcModelRegistryStore),
    ]
    for uri, expected_cls in expectations:
        store = registry.get_store(uri)
        assert isinstance(store, expected_cls)
        assert hasattr(store, "get_host_creds")
        assert store.get_host_creds.func.__name__ == "get_databricks_host_creds"
        assert store.get_host_creds.args == (uri,)

    # The UC store also carries a tracking URI; its exact value depends on the
    # test environment (e.g. a temporary sqlite database), so only presence is checked.
    uc_store = registry.get_store("databricks-uc")
    assert hasattr(uc_store, "tracking_uri")
    assert uc_store.tracking_uri is not None
2364  
2365  
def test_search_experiments_empty_page_token(mock_get_request_message, mock_tracking_store):
    """An unset proto page_token (empty string) must reach the store as None."""
    req = SearchExperiments()
    req.max_results = 10
    # Proto3 string fields default to "" when never assigned.
    assert req.page_token == ""

    mock_get_request_message.return_value = req
    mock_tracking_store.search_experiments.return_value = PagedList([], None)

    _search_experiments()

    mock_tracking_store.search_experiments.assert_called_once()
    kwargs = mock_tracking_store.search_experiments.call_args.kwargs
    assert kwargs.get("page_token") is None
    assert kwargs.get("max_results") == 10
2384  
2385  
def test_search_registered_models_empty_page_token(
    mock_get_request_message, mock_model_registry_store
):
    """An unset proto page_token (empty string) must reach the store as None."""
    req = SearchRegisteredModels()
    req.max_results = 10
    # Proto3 string fields default to "" when never assigned.
    assert req.page_token == ""

    mock_get_request_message.return_value = req
    mock_model_registry_store.search_registered_models.return_value = PagedList([], None)

    _search_registered_models()

    mock_model_registry_store.search_registered_models.assert_called_once()
    kwargs = mock_model_registry_store.search_registered_models.call_args.kwargs
    assert kwargs.get("page_token") is None
    assert kwargs.get("max_results") == 10
2406  
2407  
def test_search_model_versions_empty_page_token(
    mock_get_request_message, mock_model_registry_store
):
    """An unset proto page_token (empty string) must reach the store as None."""
    req = SearchModelVersions()
    req.max_results = 10
    # Proto3 string fields default to "" when never assigned.
    assert req.page_token == ""

    mock_get_request_message.return_value = req
    mock_model_registry_store.search_model_versions.return_value = PagedList([], None)

    _search_model_versions()

    mock_model_registry_store.search_model_versions.assert_called_once()
    kwargs = mock_model_registry_store.search_model_versions.call_args.kwargs
    assert kwargs.get("page_token") is None
    assert kwargs.get("max_results") == 10
2428  
2429  
def test_search_traces_v3_empty_page_token(mock_get_request_message, mock_tracking_store):
    """An unset proto page_token (empty string) must reach the store as None."""
    # SearchTracesV3 requires at least one trace location.
    req = SearchTracesV3()
    loc = TraceLocation()
    loc.mlflow_experiment.experiment_id = "1"
    req.locations.append(loc)
    req.max_results = 10

    # Proto3 string fields default to "" when never assigned.
    assert req.page_token == ""

    mock_get_request_message.return_value = req
    mock_tracking_store.search_traces.return_value = ([], None)

    _search_traces_v3()

    mock_tracking_store.search_traces.assert_called_once()
    kwargs = mock_tracking_store.search_traces.call_args.kwargs
    assert kwargs.get("page_token") is None
    assert kwargs.get("max_results") == 10
2452  
2453  
def test_deprecated_search_traces_v2_empty_page_token(
    mock_get_request_message, mock_tracking_store
):
    """The deprecated v2 endpoint also maps an empty page_token to None."""
    req = SearchTraces()
    req.max_results = 10
    # Proto3 string fields default to "" when never assigned.
    assert req.page_token == ""

    mock_get_request_message.return_value = req
    mock_tracking_store.search_traces.return_value = ([], None)

    _deprecated_search_traces_v2()

    mock_tracking_store.search_traces.assert_called_once()
    kwargs = mock_tracking_store.search_traces.call_args.kwargs
    assert kwargs.get("page_token") is None
    assert kwargs.get("max_results") == 10
2474  
2475  
def test_search_logged_models_empty_page_token(mock_get_request_message, mock_tracking_store):
    """An unset proto page_token (empty string) must reach the store as None."""
    req = SearchLoggedModels()
    req.max_results = 10
    # Proto3 string fields default to "" when never assigned.
    assert req.page_token == ""

    mock_get_request_message.return_value = req
    mock_tracking_store.search_logged_models.return_value = PagedList([], None)

    _search_logged_models()

    mock_tracking_store.search_logged_models.assert_called_once()
    kwargs = mock_tracking_store.search_logged_models.call_args.kwargs
    assert kwargs.get("page_token") is None
    assert kwargs.get("max_results") == 10
2494  
2495  
def test_list_webhooks_empty_page_token(mock_get_request_message, mock_model_registry_store):
    """An unset proto page_token (empty string) must reach the store as None."""
    req = ListWebhooks()
    req.max_results = 10
    # Proto3 string fields default to "" when never assigned.
    assert req.page_token == ""

    mock_get_request_message.return_value = req
    mock_model_registry_store.list_webhooks.return_value = PagedList([], None)

    _list_webhooks()

    mock_model_registry_store.list_webhooks.assert_called_once()
    kwargs = mock_model_registry_store.list_webhooks.call_args.kwargs
    assert kwargs.get("page_token") is None
    assert kwargs.get("max_results") == 10
2514  
2515  
def test_batch_get_traces_handler(mock_get_request_message, mock_tracking_store):
    """BatchGetTraces fetches every requested trace and serializes its spans."""
    ids = ["test-trace-123", "test-trace-456"]
    mock_get_request_message.return_value = BatchGetTraces(trace_ids=ids)

    span = Span(
        OTelReadableSpan(
            name="test",
            context=build_otel_context(123, 234),
            parent=None,
            start_time=100,
            end_time=200,
            attributes={
                "mlflow.spanInputs": json.dumps("inputs"),
                "mlflow.spanOutputs": json.dumps("outputs"),
                "mlflow.spanType": json.dumps("span_type"),
            },
        )
    )

    def make_trace(trace_id, duration):
        # Minimal OK trace in experiment "1" carrying the single span above.
        return Trace(
            info=TraceInfo(
                trace_id=trace_id,
                trace_location=EntityTraceLocation.from_experiment_id("1"),
                request_time=1234567890,
                execution_duration=duration,
                state=TraceState.OK,
            ),
            data=TraceData(spans=[span]),
        )

    mock_tracking_store.batch_get_traces.return_value = [
        make_trace(ids[0], 5000),
        make_trace(ids[1], 3000),
    ]

    response = _batch_get_traces()

    # The store receives the requested IDs plus a None location argument.
    mock_tracking_store.batch_get_traces.assert_called_once_with(ids, None)

    assert response is not None
    assert response.status_code == 200
    traces = json.loads(response.get_data())["traces"]
    assert len(traces) == 2
    assert len(traces[0]["spans"]) == 1
    assert len(traces[1]["spans"]) == 1
2576  
2577  
def test_batch_get_traces_handler_empty_list(mock_get_request_message, mock_tracking_store):
    """An empty trace_ids list is forwarded as-is and still yields a 200."""
    mock_get_request_message.return_value = BatchGetTraces()
    mock_tracking_store.batch_get_traces.return_value = []

    response = _batch_get_traces()

    mock_tracking_store.batch_get_traces.assert_called_once_with([], None)
    assert response is not None
    assert response.status_code == 200
2591  
2592  
def test_batch_get_trace_infos_handler(mock_get_request_message, mock_tracking_store):
    """BatchGetTraceInfos returns one TraceInfo entry per requested ID, in order."""
    ids = ["test-trace-123", "test-trace-456"]
    mock_get_request_message.return_value = BatchGetTraceInfos(trace_ids=ids)

    def info_for(trace_id, duration):
        # Minimal OK trace info in experiment "1".
        return TraceInfo(
            trace_id=trace_id,
            trace_location=EntityTraceLocation.from_experiment_id("1"),
            request_time=1234567890,
            execution_duration=duration,
            state=TraceState.OK,
        )

    mock_tracking_store.batch_get_trace_infos.return_value = [
        info_for(ids[0], 5000),
        info_for(ids[1], 3000),
    ]

    response = _batch_get_trace_infos()

    mock_tracking_store.batch_get_trace_infos.assert_called_once_with(ids)

    assert response is not None
    assert response.status_code == 200
    infos = json.loads(response.get_data())["trace_infos"]
    assert len(infos) == 2
    assert [i["trace_id"] for i in infos] == ids
2629  
2630  
def test_get_trace_handler(mock_get_request_message, mock_tracking_store):
    """GetTrace with allow_partial=True returns the serialized trace and its spans."""
    trace_id = "test-trace-123"
    mock_get_request_message.return_value = GetTrace(trace_id=trace_id, allow_partial=True)

    span = Span(
        OTelReadableSpan(
            name="test",
            context=build_otel_context(123, 234),
            parent=None,
            start_time=100,
            end_time=200,
            attributes={
                "mlflow.spanInputs": json.dumps("inputs"),
                "mlflow.spanOutputs": json.dumps("outputs"),
                "mlflow.spanType": json.dumps("span_type"),
            },
        )
    )
    mock_tracking_store.get_trace.return_value = Trace(
        info=TraceInfo(
            trace_id=trace_id,
            trace_location=EntityTraceLocation.from_experiment_id("1"),
            request_time=1234567890,
            execution_duration=5000,
            state=TraceState.OK,
        ),
        data=TraceData(spans=[span]),
    )

    response = _get_trace()

    # The allow_partial flag from the request is forwarded verbatim.
    mock_tracking_store.get_trace.assert_called_once_with(trace_id, allow_partial=True)

    assert response is not None
    assert response.status_code == 200
    body = json.loads(response.get_data())
    assert "trace" in body
    assert body["trace"]["trace_info"]["trace_id"] == trace_id
    assert len(body["trace"]["spans"]) == 1
2675  
2676  
def test_get_trace_handler_with_allow_partial_false(mock_get_request_message, mock_tracking_store):
    """allow_partial=False in the request must be forwarded to the store call."""
    trace_id = "test-trace-456"
    mock_get_request_message.return_value = GetTrace(trace_id=trace_id, allow_partial=False)

    span = Span(
        OTelReadableSpan(
            name="test",
            context=build_otel_context(123, 234),
            parent=None,
            start_time=100,
            end_time=200,
            attributes={},
        )
    )
    mock_tracking_store.get_trace.return_value = Trace(
        info=TraceInfo(
            trace_id=trace_id,
            trace_location=EntityTraceLocation.from_experiment_id("1"),
            request_time=1234567890,
            execution_duration=5000,
            state=TraceState.OK,
        ),
        data=TraceData(spans=[span]),
    )

    response = _get_trace()

    mock_tracking_store.get_trace.assert_called_once_with(trace_id, allow_partial=False)

    assert response is not None
    assert response.status_code == 200
    assert "trace" in json.loads(response.get_data())
2714  
2715  
def test_get_trace_handler_not_found(mock_get_request_message, mock_tracking_store):
    """A missing trace surfaces as a 404 with RESOURCE_DOES_NOT_EXIST."""
    trace_id = "non-existent-trace"
    mock_get_request_message.return_value = GetTrace(trace_id=trace_id)

    mock_tracking_store.get_trace.side_effect = MlflowException(
        f"Trace with ID {trace_id} is not found.",
        error_code=RESOURCE_DOES_NOT_EXIST,
    )

    response = _get_trace()

    # allow_partial defaults to False when unset on the proto.
    mock_tracking_store.get_trace.assert_called_once_with(trace_id, allow_partial=False)

    assert response is not None
    assert response.status_code == 404
    body = json.loads(response.get_data())
    assert "error_code" in body
    assert body["error_code"] == "RESOURCE_DOES_NOT_EXIST"
2736  
2737  
def test_get_trace_artifact_handler(mock_tracking_store):
    """Happy path: the artifact download endpoint serves the trace as a JSON attachment."""
    trace_id = "test-trace-artifact-123"

    span = Span(
        OTelReadableSpan(
            name="test_span",
            context=build_otel_context(123, 234),
            parent=None,
            start_time=100,
            end_time=200,
            attributes={
                "mlflow.spanInputs": json.dumps({"input": "test_input"}),
                "mlflow.spanOutputs": json.dumps({"output": "test_output"}),
            },
        )
    )
    trace = Trace(
        info=TraceInfo(
            trace_id=trace_id,
            trace_location=EntityTraceLocation.from_experiment_id("1"),
            request_time=1234567890,
            execution_duration=5000,
            state=TraceState.OK,
        ),
        data=TraceData(spans=[span]),
    )
    mock_tracking_store.get_trace.return_value = trace
    # The fallback path is stubbed too, although get_trace succeeds here.
    mock_tracking_store.batch_get_traces.return_value = [trace]

    with app.test_request_context(method="GET", query_string={"request_id": trace_id}):
        response = get_trace_artifact_handler()

    mock_tracking_store.get_trace.assert_called_once_with(trace_id, allow_partial=True)

    assert response is not None
    assert response.status_code == 200
    assert response.headers["Content-Disposition"] == "attachment; filename=traces.json"
2778  
2779  
def test_get_trace_artifact_handler_missing_request_id(mock_tracking_store):
    """Omitting the request_id query parameter yields a 400 BAD_REQUEST."""
    with app.test_request_context(method="GET"):
        response = get_trace_artifact_handler()

    assert response.status_code == 400
    body = json.loads(response.get_data())
    assert "error_code" in body
    assert body["error_code"] == "BAD_REQUEST"
    assert 'must include the "request_id" query parameter' in body["message"]
2789  
2790  
def test_get_trace_artifact_handler_trace_not_found(mock_tracking_store):
    """A store-side RESOURCE_DOES_NOT_EXIST error propagates as a 404."""
    trace_id = "non-existent-trace"
    mock_tracking_store.get_trace.side_effect = MlflowException(
        f"Trace with ID {trace_id} is not found.",
        error_code=RESOURCE_DOES_NOT_EXIST,
    )

    with app.test_request_context(method="GET", query_string={"request_id": trace_id}):
        response = get_trace_artifact_handler()

    mock_tracking_store.get_trace.assert_called_once_with(trace_id, allow_partial=True)

    assert response.status_code == 404
    body = json.loads(response.get_data())
    assert "error_code" in body
    assert body["error_code"] == "RESOURCE_DOES_NOT_EXIST"
    assert f"Trace with ID {trace_id} is not found" in body["message"]
2808  
2809  
def test_get_trace_artifact_handler_fallback_to_batch_get_traces(mock_tracking_store):
    """If get_trace is unimplemented, the handler falls back to batch_get_traces."""
    trace_id = "test-trace-batch-123"

    span = Span(
        OTelReadableSpan(
            name="test_span_batch",
            context=build_otel_context(456, 789),
            parent=None,
            start_time=100,
            end_time=200,
            attributes={
                "mlflow.spanInputs": json.dumps({"input": "batch_input"}),
                "mlflow.spanOutputs": json.dumps({"output": "batch_output"}),
            },
        )
    )
    trace = Trace(
        info=TraceInfo(
            trace_id=trace_id,
            trace_location=EntityTraceLocation.from_experiment_id("2"),
            request_time=1234567890,
            execution_duration=3000,
            state=TraceState.OK,
        ),
        data=TraceData(spans=[span]),
    )

    # Primary retrieval path is unavailable; fallback succeeds.
    mock_tracking_store.get_trace.side_effect = MlflowNotImplementedException(
        "get_trace is not implemented"
    )
    mock_tracking_store.batch_get_traces.return_value = [trace]

    with app.test_request_context(method="GET", query_string={"request_id": trace_id}):
        response = get_trace_artifact_handler()

    # get_trace was attempted first, then batch_get_traces.
    mock_tracking_store.get_trace.assert_called_once_with(trace_id, allow_partial=True)
    mock_tracking_store.batch_get_traces.assert_called_once_with([trace_id], None)

    assert response is not None
    assert response.status_code == 200
    assert response.headers["Content-Disposition"] == "attachment; filename=traces.json"
2854  
2855  
def test_get_trace_artifact_handler_batch_get_traces_not_found(mock_tracking_store):
    """Fallback batch_get_traces returning an empty list yields a 404."""
    trace_id = "non-existent-batch-trace"

    mock_tracking_store.get_trace.side_effect = MlflowNotImplementedException(
        "get_trace is not implemented"
    )
    mock_tracking_store.batch_get_traces.return_value = []

    with app.test_request_context(method="GET", query_string={"request_id": trace_id}):
        response = get_trace_artifact_handler()

    # Both retrieval methods were attempted before giving up.
    mock_tracking_store.get_trace.assert_called_once_with(trace_id, allow_partial=True)
    mock_tracking_store.batch_get_traces.assert_called_once_with([trace_id], None)

    assert response.status_code == 404
    body = json.loads(response.get_data())
    assert "error_code" in body
    assert body["error_code"] == "RESOURCE_DOES_NOT_EXIST"
    assert f"Trace with id={trace_id} not found" in body["message"]
2879  
2880  
def test_get_trace_artifact_handler_fallback_to_artifact_repo(mock_tracking_store):
    """When neither get_trace nor batch_get_traces is implemented, the handler
    downloads the trace data from the artifact repository instead."""
    trace_id = "test-trace-artifact-repo-123"

    mock_tracking_store.get_trace_info.return_value = TraceInfo(
        trace_id=trace_id,
        trace_location=EntityTraceLocation.from_experiment_id("3"),
        request_time=1234567890,
        execution_duration=4000,
        state=TraceState.OK,
    )

    trace_data = {
        "spans": [
            {
                "name": "artifact_span",
                "context": {"trace_id": trace_id, "span_id": "123"},
                "parent_id": None,
                "start_time": 100,
                "end_time": 200,
                "status_code": "OK",
                "status_message": "",
                "attributes": {},
                "events": [],
            }
        ]
    }

    # Both store-side retrieval methods report "not implemented".
    mock_tracking_store.get_trace.side_effect = MlflowNotImplementedException(
        "get_trace is not implemented"
    )
    mock_tracking_store.batch_get_traces.side_effect = MlflowNotImplementedException(
        "batch_get_traces is not implemented"
    )

    artifact_repo = mock.MagicMock()
    artifact_repo.download_trace_data.return_value = trace_data

    with mock.patch(
        "mlflow.server.handlers._get_trace_artifact_repo", return_value=artifact_repo
    ):
        with app.test_request_context(method="GET", query_string={"request_id": trace_id}):
            response = get_trace_artifact_handler()

    # Full fallback chain: get_trace -> batch_get_traces -> artifact repo.
    mock_tracking_store.get_trace.assert_called_once_with(trace_id, allow_partial=True)
    mock_tracking_store.batch_get_traces.assert_called_once_with([trace_id], None)
    mock_tracking_store.get_trace_info.assert_called_once_with(trace_id)
    artifact_repo.download_trace_data.assert_called_once()

    assert response is not None
    assert response.status_code == 200
    assert response.headers["Content-Disposition"] == "attachment; filename=traces.json"
2937  
2938  
def test_get_trace_artifact_handler_with_attachment_path(mock_tracking_store):
    """A "path" query parameter routes to the trace-attachment download path."""
    trace_id = "tr-test-attachment-123"
    attachment_id = "a1b2c3d4-e5f6-4890-abcd-ef1234567890"

    mock_tracking_store.get_trace_info.return_value = TraceInfo(
        trace_id=trace_id,
        trace_location=EntityTraceLocation.from_experiment_id("3"),
        request_time=1234567890,
        execution_duration=4000,
        state=TraceState.OK,
    )

    artifact_repo = mock.MagicMock()
    artifact_repo.download_trace_attachment.return_value = b"\x89PNG fake image"

    with mock.patch(
        "mlflow.server.handlers._get_trace_artifact_repo", return_value=artifact_repo
    ):
        with app.test_request_context(
            method="GET", query_string={"request_id": trace_id, "path": attachment_id}
        ):
            response = get_trace_artifact_handler()

    mock_tracking_store.get_trace_info.assert_called_once_with(trace_id)
    artifact_repo.download_trace_attachment.assert_called_once_with(attachment_id)
    assert response.status_code == 200
    # Attachments are served as opaque octet-streams with content sniffing disabled.
    assert response.headers["Content-Type"] == "application/octet-stream"
    assert response.headers["Content-Disposition"] == f"attachment; filename={attachment_id}"
    assert response.headers["X-Content-Type-Options"] == "nosniff"
2969  
2970  
def test_get_trace_artifact_handler_attachment_missing_request_id():
    """Attachment downloads still require the request_id query parameter."""
    with app.test_request_context(
        method="GET", query_string={"path": "a1b2c3d4-e5f6-4890-abcd-ef1234567890"}
    ):
        response = get_trace_artifact_handler()
    assert response.status_code == 400
2976  
2977  
def test_get_trace_artifact_handler_attachment_trace_not_found(mock_tracking_store):
    """An unknown trace id on the attachment path yields a 404."""
    mock_tracking_store.get_trace_info.return_value = None

    with app.test_request_context(
        method="GET",
        query_string={"request_id": "tr-nonexistent", "path": "a1b2c3d4-e5f6-4890-abcd-ef1234567890"},
    ):
        response = get_trace_artifact_handler()
    assert response.status_code == 404
2985  
2986  
def test_delete_trace_tag_v2_handler(mock_get_request_message, mock_tracking_store):
    """The v2 route (request_id path parameter) delegates to store.delete_trace_tag."""
    request_id = "tr-123v2"
    tag_key = "tk"

    mock_get_request_message.return_value = DeleteTraceTag(key=tag_key)

    response = _delete_trace_tag(request_id=request_id)

    mock_tracking_store.delete_trace_tag.assert_called_once_with(request_id, tag_key)
    assert response is not None
    assert response.status_code == 200
3009  
3010  
def test_delete_trace_tag_v3_handler(mock_get_request_message, mock_tracking_store):
    """The v3 route (trace_id path parameter) uses the v3 proto but delegates to
    the same store.delete_trace_tag method as v2."""
    trace_id = "tr-v3-456"
    tag_key = "tk"

    mock_get_request_message.return_value = DeleteTraceTagV3(key=tag_key)

    response = _delete_trace_tag_v3(trace_id=trace_id)

    mock_tracking_store.delete_trace_tag.assert_called_once_with(trace_id, tag_key)
    assert response is not None
    assert response.status_code == 200
3035  
3036  
def test_set_trace_tag_v2_handler(mock_get_request_message, mock_tracking_store):
    """The v2 route (request_id path parameter) delegates to store.set_trace_tag."""
    trace_id = "tr-test-v2-123"
    tag_key, tag_value = "tk", "tv"

    mock_get_request_message.return_value = SetTraceTag(key=tag_key, value=tag_value)

    response = _set_trace_tag(request_id=trace_id)

    mock_tracking_store.set_trace_tag.assert_called_once_with(trace_id, tag_key, tag_value)
    assert response is not None
    assert response.status_code == 200
3060  
3061  
def test_set_trace_tag_v3_handler(mock_get_request_message, mock_tracking_store):
    """Exercise the v3 set_trace_tag handler.

    The v3 Flask route supplies a ``trace_id`` path parameter; the handler
    should delegate to ``store.set_trace_tag()`` exactly as v2 does,
    differing only in proto message and route parameter naming.
    """
    trace_id, tag_key, tag_value = "tr-test-v3-456", "tk", "tv"

    # Stub the parsed v3 request proto handed to the handler.
    mock_get_request_message.return_value = SetTraceTagV3(key=tag_key, value=tag_value)

    resp = _set_trace_tag_v3(trace_id=trace_id)

    # v2 and v3 share the same store entry point.
    mock_tracking_store.set_trace_tag.assert_called_once_with(trace_id, tag_key, tag_value)

    assert resp is not None
    assert resp.status_code == 200
3087  
3088  
def test_link_prompts_to_trace_handler(mock_get_request_message, mock_tracking_store):
    """Exercise the link_prompts_to_trace handler.

    The handler must convert each proto ``PromptVersionRef`` into a
    ``PromptVersion`` entity before calling ``store.link_prompts_to_trace()``.
    """
    trace_id = "tr-test-123"
    refs = [
        LinkPromptsToTrace.PromptVersionRef(name="prompt1", version="1"),
        LinkPromptsToTrace.PromptVersionRef(name="prompt2", version="2"),
    ]

    # Stub the parsed request proto handed to the handler.
    mock_get_request_message.return_value = LinkPromptsToTrace(
        trace_id=trace_id, prompt_versions=refs
    )

    resp = _link_prompts_to_trace()

    kwargs = mock_tracking_store.link_prompts_to_trace.call_args[1]
    assert kwargs["trace_id"] == trace_id

    versions = kwargs["prompt_versions"]
    assert len(versions) == 2
    # Refs arrive as strings but must be converted to PromptVersion entities
    # with integer version numbers.
    for pv, (expected_name, expected_version) in zip(versions, [("prompt1", 1), ("prompt2", 2)]):
        assert isinstance(pv, PromptVersion)
        assert pv.name == expected_name
        assert pv.version == expected_version

    assert resp is not None
    assert resp.status_code == 200
3125  
3126  
def test_list_providers():
    """Supported-providers endpoint returns a non-empty list including openai."""
    with app.test_client() as client:
        resp = client.get("/ajax-api/3.0/mlflow/gateway/supported-providers")
        assert resp.status_code == 200
        payload = resp.get_json()
        assert "providers" in payload
        providers = payload["providers"]
        assert isinstance(providers, list)
        assert providers
        assert "openai" in providers
3136  
3137  
def test_list_providers_with_allowed_filter(monkeypatch):
    """Only providers named in MLFLOW_GATEWAY_ALLOWED_PROVIDERS are exposed."""
    monkeypatch.setenv("MLFLOW_GATEWAY_ALLOWED_PROVIDERS", "openai,anthropic")
    with app.test_client() as client:
        resp = client.get("/ajax-api/3.0/mlflow/gateway/supported-providers")
        assert resp.status_code == 200
        providers = resp.get_json()["providers"]
        for allowed in ("openai", "anthropic"):
            assert allowed in providers
        for excluded in ("gemini", "bedrock"):
            assert excluded not in providers
3148  
3149  
def test_list_models():
    """Supported-models endpoint returns a non-empty list for a single provider."""
    with app.test_client() as client:
        resp = client.get("/ajax-api/3.0/mlflow/gateway/supported-models?provider=openai")
        assert resp.status_code == 200
        payload = resp.get_json()
        assert "models" in payload
        models = payload["models"]
        assert isinstance(models, list)
        assert models
3158  
3159  
def test_list_models_all_providers():
    """Supported-models endpoint without a provider filter still returns models."""
    with app.test_client() as client:
        resp = client.get("/ajax-api/3.0/mlflow/gateway/supported-models")
        assert resp.status_code == 200
        payload = resp.get_json()
        assert "models" in payload
        models = payload["models"]
        assert isinstance(models, list)
        assert models
3168  
3169  
def test_get_provider_config():
    """Provider-config for openai defaults to api_key auth."""
    with app.test_client() as client:
        resp = client.get("/ajax-api/3.0/mlflow/gateway/provider-config?provider=openai")
        assert resp.status_code == 200
        payload = resp.get_json()
        assert "auth_modes" in payload
        assert "default_mode" in payload
        assert payload["default_mode"] == "api_key"
        modes = payload["auth_modes"]
        assert len(modes) >= 1
        # The first listed mode is the api_key mode.
        assert modes[0]["mode"] == "api_key"
3181  
3182  
def test_get_provider_config_with_multiple_auth_modes():
    """Bedrock exposes multiple auth modes, each with its own field sets."""
    with app.test_client() as client:
        resp = client.get("/ajax-api/3.0/mlflow/gateway/provider-config?provider=bedrock")
        assert resp.status_code == 200
        payload = resp.get_json()

        assert "auth_modes" in payload
        assert payload["default_mode"] == "api_key"
        modes = payload["auth_modes"]
        assert len(modes) >= 2

        # The access_keys mode carries AWS key secrets plus a region config field.
        access_keys = next(m for m in modes if m["mode"] == "access_keys")
        assert len(access_keys["secret_fields"]) == 2
        assert "aws_secret_access_key" in [f["name"] for f in access_keys["secret_fields"]]
        assert "aws_region_name" in [f["name"] for f in access_keys["config_fields"]]

        # The iam_role mode is configured via a role name.
        iam_role = next(m for m in modes if m["mode"] == "iam_role")
        assert "aws_role_name" in [f["name"] for f in iam_role["config_fields"]]
3200  
3201  
def test_get_provider_config_missing_provider():
    """Omitting the required provider query parameter yields a 400."""
    with app.test_client() as client:
        resp = client.get("/ajax-api/3.0/mlflow/gateway/provider-config")
        assert resp.status_code == 400
3206  
3207  
def test_get_provider_config_with_allowed_filter(monkeypatch):
    """Provider-config is rejected for providers outside the allow-list."""
    monkeypatch.setenv("MLFLOW_GATEWAY_ALLOWED_PROVIDERS", "anthropic")
    with app.test_client() as client:
        # A provider not on the allow-list is rejected with a clear message.
        blocked = client.get("/ajax-api/3.0/mlflow/gateway/provider-config?provider=openai")
        assert blocked.status_code == 400
        assert "not allowed" in blocked.get_json()["message"]

        # A provider on the allow-list works normally.
        allowed = client.get("/ajax-api/3.0/mlflow/gateway/provider-config?provider=anthropic")
        assert allowed.status_code == 200
3218  
3219  
@pytest.mark.parametrize(
    "invalid_name",
    [
        "invalid name",  # space
        "invalid/name",  # slash
        "invalid?name",  # question mark
        "invalid&name",  # ampersand
        "invalid#name",  # hash
        "invalid@name",  # at sign
        "invalid:name",  # colon
        "日本語",  # unicode (Japanese)
        "naïve",  # unicode (accented)
    ],
)
def test_create_gateway_endpoint_rejects_invalid_name(mock_get_request_message, invalid_name):
    """Endpoint names with unsafe or non-ASCII characters are rejected with 400."""
    from mlflow.protos.service_pb2 import CreateGatewayEndpoint
    from mlflow.server.handlers import _create_gateway_endpoint

    mock_get_request_message.return_value = CreateGatewayEndpoint(name=invalid_name)

    response = _create_gateway_endpoint()

    assert response.status_code == 400
    body = json.loads(response.get_data())
    assert body["error_code"] == "INVALID_PARAMETER_VALUE"
    assert "Invalid endpoint name" in body["message"]
3248  
3249  
@pytest.mark.parametrize(
    "invalid_name",
    [
        "invalid name",  # space
        "invalid/name",  # slash
        "invalid?name",  # question mark
        "invalid&name",  # ampersand
    ],
)
def test_update_gateway_endpoint_rejects_invalid_name(mock_get_request_message, invalid_name):
    """Renaming an endpoint to an unsafe name is rejected with 400."""
    from mlflow.protos.service_pb2 import UpdateGatewayEndpoint
    from mlflow.server.handlers import _update_gateway_endpoint

    mock_get_request_message.return_value = UpdateGatewayEndpoint(
        endpoint_id="test-endpoint-id", name=invalid_name
    )

    response = _update_gateway_endpoint()

    assert response.status_code == 400
    body = json.loads(response.get_data())
    assert body["error_code"] == "INVALID_PARAMETER_VALUE"
    assert "Invalid endpoint name" in body["message"]
3274  
3275  
def test_get_gateway_endpoint_by_endpoint_id(mock_get_request_message, mock_tracking_store):
    """Looking up a gateway endpoint by id forwards endpoint_id and a None name."""
    mock_get_request_message.return_value = GetGatewayEndpoint(endpoint_id="ep-123")

    endpoint = mock.MagicMock()
    endpoint.to_proto.return_value = GatewayEndpoint(endpoint_id="ep-123")
    mock_tracking_store.get_gateway_endpoint.return_value = endpoint

    response = _get_gateway_endpoint()

    mock_tracking_store.get_gateway_endpoint.assert_called_once_with(
        endpoint_id="ep-123", name=None
    )
    assert response.status_code == 200
3291  
3292  
def test_get_gateway_endpoint_by_name(mock_get_request_message, mock_tracking_store):
    """Looking up a gateway endpoint by name forwards the name and a None id."""
    mock_get_request_message.return_value = GetGatewayEndpoint(name="my-endpoint")

    endpoint = mock.MagicMock()
    endpoint.to_proto.return_value = GatewayEndpoint(endpoint_id="ep-456", name="my-endpoint")
    mock_tracking_store.get_gateway_endpoint.return_value = endpoint

    response = _get_gateway_endpoint()

    mock_tracking_store.get_gateway_endpoint.assert_called_once_with(
        endpoint_id=None, name="my-endpoint"
    )
    assert response.status_code == 200
3309  
3310  
def test_query_trace_metrics_handler(mock_get_request_message, mock_tracking_store):
    """The handler forwards all query fields to the store and serializes results."""
    experiment_ids = ["exp1", "exp2"]
    metric_name = "latency"

    # Requested aggregations: a plain average and a 95th percentile.
    avg_agg = MetricAggregation(aggregation_type=AggregationType.AVG)
    p95_agg = MetricAggregation(
        aggregation_type=AggregationType.PERCENTILE, percentile_value=95.0
    )

    # Stub the parsed request proto handed to the handler.
    mock_get_request_message.return_value = QueryTraceMetrics(
        experiment_ids=experiment_ids,
        view_type=MetricViewType.TRACES.to_proto(),
        metric_name=metric_name,
        aggregations=[avg_agg.to_proto(), p95_agg.to_proto()],
        dimensions=["status", "model"],
        filters=["status = 'OK'"],
        time_interval_seconds=3600,
        start_time_ms=1000000,
        end_time_ms=2000000,
        max_results=100,
        page_token="token123",
    )

    data_points = [
        MetricDataPoint(
            metric_name="latency",
            dimensions={"status": "OK", "model": "gpt-4"},
            values={"AVG": 150.5, "P95.0": 200.0},
        ),
        MetricDataPoint(
            metric_name="latency",
            dimensions={"status": "ERROR", "model": "gpt-4"},
            values={"AVG": 50.0, "P95.0": 75.0},
        ),
    ]

    # The store returns an iterable of data points with a `token` attribute
    # carrying the pagination token.
    store_result = mock.MagicMock()
    store_result.__iter__ = mock.MagicMock(return_value=iter(data_points))
    store_result.token = "next_token"
    mock_tracking_store.query_trace_metrics.return_value = store_result

    response = _query_trace_metrics()

    # Every request field arrives at the store converted to entity types.
    mock_tracking_store.query_trace_metrics.assert_called_once_with(
        experiment_ids=experiment_ids,
        view_type=MetricViewType.TRACES,
        metric_name=metric_name,
        aggregations=[avg_agg, p95_agg],
        dimensions=["status", "model"],
        filters=["status = 'OK'"],
        time_interval_seconds=3600,
        start_time_ms=1000000,
        end_time_ms=2000000,
        max_results=100,
        page_token="token123",
    )

    assert response is not None
    assert response.status_code == 200
    body = json.loads(response.get_data())
    assert "data_points" in body
    assert len(body["data_points"]) == 2
    # Data points are serialized via dataclasses.asdict.
    for actual, expected in zip(body["data_points"], data_points):
        assert actual == asdict(expected)
    assert body["next_page_token"] == "next_token"
3387  
3388  
def test_query_trace_metrics_handler_empty_result(mock_get_request_message, mock_tracking_store):
    """An empty store result serializes to an empty JSON object, with defaults applied."""
    # Minimal request: only the required fields are populated.
    mock_get_request_message.return_value = QueryTraceMetrics(
        experiment_ids=["exp1"],
        view_type=MetricViewType.TRACES.to_proto(),
        metric_name="latency",
        aggregations=[MetricAggregation(aggregation_type=AggregationType.AVG).to_proto()],
    )

    empty_result = mock.MagicMock()
    empty_result.__iter__ = mock.MagicMock(return_value=iter([]))
    empty_result.token = None
    mock_tracking_store.query_trace_metrics.return_value = empty_result

    response = _query_trace_metrics()

    # Unset optional fields reach the store as None, except max_results which
    # falls back to the server-side default.
    mock_tracking_store.query_trace_metrics.assert_called_once_with(
        experiment_ids=["exp1"],
        view_type=MetricViewType.TRACES,
        metric_name="latency",
        aggregations=[MetricAggregation(aggregation_type=AggregationType.AVG)],
        dimensions=None,
        filters=None,
        time_interval_seconds=None,
        start_time_ms=None,
        end_time_ms=None,
        max_results=MAX_RESULTS_QUERY_TRACE_METRICS,
        page_token=None,
    )

    assert response is not None
    assert response.status_code == 200
    assert json.loads(response.get_data()) == {}
3423  
3424  
def test_invoke_scorer_missing_experiment_id():
    """Scorer invocation without experiment_id is rejected with a 400."""
    with app.test_client() as client:
        resp = client.post(
            "/ajax-api/3.0/mlflow/scorer/invoke",
            json={"serialized_scorer": "test", "trace_ids": ["trace1"]},
        )
        assert resp.status_code == 400
        assert "experiment_id" in resp.get_json()["message"]
3434  
3435  
def test_invoke_scorer_missing_serialized_scorer():
    """Scorer invocation without serialized_scorer is rejected with a 400."""
    with app.test_client() as client:
        resp = client.post(
            "/ajax-api/3.0/mlflow/scorer/invoke",
            json={"experiment_id": "123", "trace_ids": ["trace1"]},
        )
        assert resp.status_code == 400
        assert "serialized_scorer" in resp.get_json()["message"]
3445  
3446  
def test_invoke_scorer_missing_trace_ids():
    """Scorer invocation without trace_ids yields a user-facing 400 message."""
    with app.test_client() as client:
        resp = client.post(
            "/ajax-api/3.0/mlflow/scorer/invoke",
            json={"experiment_id": "123", "serialized_scorer": "test"},
        )
        assert resp.status_code == 400
        assert "Please select at least one trace to evaluate" in resp.get_json()["message"]
3456  
3457  
def test_invoke_scorer_submits_jobs(mock_tracking_store):
    """A valid invoke request submits one job covering all requested traces."""
    # Minimal serialized instructions-judge payload accepted by the endpoint.
    serialized_scorer = json.dumps({
        "name": "test_judge",
        "aggregations": [],
        "description": None,
        "is_session_level_scorer": False,
        "mlflow_version": mlflow.__version__,
        "serialization_version": 1,
        "builtin_scorer_class": None,
        "builtin_scorer_pydantic_data": None,
        "call_source": None,
        "call_signature": None,
        "original_func_name": None,
        "instructions_judge_pydantic_data": {
            "instructions": "Test: {{ inputs }}",
            "model": "openai:/gpt-4",
            "feedback_value_type": {
                "enum": ["Yes", "No"],
                "title": "Result",
                "type": "string",
            },
        },
    })

    with mock.patch("mlflow.server.jobs.submit_job") as mock_submit:
        submitted_job = mock.MagicMock()
        submitted_job.job_id = "test-job-123"
        mock_submit.return_value = submitted_job

        with app.test_client() as client:
            resp = client.post(
                "/ajax-api/3.0/mlflow/scorer/invoke",
                json={
                    "experiment_id": "exp-123",
                    "serialized_scorer": serialized_scorer,
                    "trace_ids": ["trace1", "trace2"],
                },
            )
            assert resp.status_code == 200
            payload = resp.get_json()
            assert "jobs" in payload
            assert len(payload["jobs"]) == 1
            # The single job covers both requested traces.
            job = payload["jobs"][0]
            assert job["job_id"] == "test-job-123"
            assert job["trace_ids"] == ["trace1", "trace2"]

        mock_submit.assert_called_once()
3504  
3505  
def test_get_ui_telemetry_handler(
    test_app_context, mock_telemetry_config_cache, bypass_telemetry_env_check
):
    """UI telemetry config is fetched once, cached, and served from cache afterwards.

    The first call must fetch the remote config and populate the cache; the
    second call must return the same payload without refetching.
    """
    config = {
        "disable_telemetry": False,
        "disable_ui_telemetry": False,
        "disable_ui_events": ["event1", "event2"],
        "ui_rollout_percentage": 50,
    }

    with mock.patch(
        "mlflow.server.handlers.fetch_ui_telemetry_config", return_value=config
    ) as mock_fetch:
        response = get_ui_telemetry_handler()

        assert response is not None
        assert response.status_code == 200

        response_data = json.loads(response.get_data())

        assert response_data["disable_ui_telemetry"] is False
        assert response_data["disable_ui_events"] == ["event1", "event2"]
        # rollout percent gets converted to a float as that is the proto definition
        assert response_data["ui_rollout_percentage"] == 50.0
        assert "config" in mock_telemetry_config_cache
        assert mock_fetch.call_count == 1
        mock_fetch.reset_mock()

        # Subsequent call should hit the cache without refetching. Re-parse the
        # new response so the cached payload itself is verified (previously the
        # stale first-call response_data was re-asserted, which checked nothing).
        response = get_ui_telemetry_handler()
        mock_fetch.assert_not_called()
        response_data = json.loads(response.get_data())
        assert response_data["disable_ui_telemetry"] is False
        assert response_data["disable_ui_events"] == ["event1", "event2"]
        assert response_data["ui_rollout_percentage"] == 50.0
3540  
3541  
def test_get_ui_telemetry_handler_disabled_by_config(
    test_app_context, mock_telemetry_config_cache, bypass_telemetry_env_check
):
    """A global disable_telemetry flag forces UI telemetry off in the response."""
    config = {
        "disable_telemetry": True,
        "disable_ui_telemetry": False,
        "disable_ui_events": [],
        "ui_rollout_percentage": 0,
    }

    with mock.patch(
        "mlflow.server.handlers.fetch_ui_telemetry_config", return_value=config
    ) as mock_fetch:
        response = get_ui_telemetry_handler()
        assert response is not None
        assert response.status_code == 200
        body = json.loads(response.get_data())

        # disable_telemetry=True wins over disable_ui_telemetry=False: the
        # server always reports UI telemetry as disabled.
        assert body["disable_ui_telemetry"] is True
        assert body["ui_rollout_percentage"] == 0.0
        assert body["disable_ui_events"] == []
        assert mock_fetch.call_count == 1
3566  
3567  
def test_get_ui_telemetry_handler_disabled_by_env(
    test_app_context, mock_telemetry_config_cache, bypass_telemetry_env_check, monkeypatch
):
    """DO_NOT_TRACK short-circuits the handler: no fetch, telemetry reported off."""
    monkeypatch.setenv("DO_NOT_TRACK", "true")
    with mock.patch("mlflow.server.handlers.fetch_ui_telemetry_config") as mock_fetch:
        response = get_ui_telemetry_handler()
        assert response is not None
        assert response.status_code == 200
        body = json.loads(response.get_data())

        # The env var disables telemetry outright, so the remote config is
        # never consulted and the response reports everything disabled.
        mock_fetch.assert_not_called()
        assert body["disable_ui_telemetry"] is True
        assert body["ui_rollout_percentage"] == 0.0
        assert body["disable_ui_events"] == []
3584  
3585  
def test_get_ui_telemetry_handler_fallback_values(
    test_app_context, mock_telemetry_config_cache, bypass_telemetry_env_check
):
    """Handler falls back to safe "disabled" defaults when the config is unusable.

    Covers two failure shapes: a config payload missing the UI-specific
    fields, and an outright failed fetch (non-200 response). In both cases
    the handler must report UI telemetry as disabled with empty defaults.
    """
    config_without_ui_fields = {
        "disable_telemetry": False,
        "rollout_percentage": 100,
    }

    # test fallback values if we forget to define UI config fields
    # NOTE(review): this patches requests.get to return a plain dict rather
    # than a Response object, so the fetch path presumably errors before
    # reading any fields — confirm this exercises "missing UI fields" rather
    # than just the generic fetch-failure fallback below.
    with mock.patch("requests.get", return_value=config_without_ui_fields):
        response = get_ui_telemetry_handler()

        assert response is not None
        assert response.status_code == 200

        response_data = json.loads(response.get_data())

        assert response_data["disable_ui_telemetry"] is True
        assert response_data["ui_rollout_percentage"] == 0
        assert response_data["disable_ui_events"] == []

    # test fallback values if we fail to fetch the config
    with mock.patch("requests.get", return_value=mock.Mock(status_code=404)):
        response = get_ui_telemetry_handler()

        assert response.status_code == 200

        response_data = json.loads(response.get_data())
        assert response_data["disable_ui_telemetry"] is True
        assert response_data["ui_rollout_percentage"] == 0
        assert response_data["disable_ui_events"] == []
3617  
3618  
def test_post_ui_telemetry_handler_success(
    test_app, mock_telemetry_config_cache, bypass_telemetry_env_check
):
    """Posted UI events are converted to Records and forwarded in one batch."""
    event1 = {
        "event_name": "test_event_1",
        "timestamp_ns": 1234567890000000,
        "params": {"key1": "value1"},
        "installation_id": "install-123",
        "session_id": "session-456",
    }
    event2 = {
        "event_name": "test_event_2",
        "timestamp_ns": 1234567890000001,
        "params": {"key2": "value2"},
        "installation_id": "install-123",
        "session_id": "session-456",
    }

    body = json.dumps({"records": [event1, event2]})
    config = {"disable_ui_telemetry": False, "disable_telemetry": False}
    telemetry_client = mock.MagicMock()
    server_install_id = "server-install-789"

    with (
        test_app.test_request_context(
            "/ui-telemetry", method="POST", data=body, content_type="application/json"
        ),
        mock.patch("mlflow.server.handlers.fetch_ui_telemetry_config", return_value=config),
        mock.patch("mlflow.server.handlers.get_telemetry_client", return_value=telemetry_client),
        mock.patch(
            "mlflow.server.handlers.get_or_create_installation_id",
            return_value=server_install_id,
        ),
    ):
        response = post_ui_telemetry_handler()

        assert response is not None
        assert response.status_code == 200
        assert json.loads(response.get_data())["status"] == "success"

        # Each posted event becomes a SUCCESS Record tagged with the server's
        # installation id, and all records go to the client in a single call.
        expected = [
            Record(
                **event,
                duration_ms=0,
                status=Status.SUCCESS,
                server_installation_id=server_install_id,
            )
            for event in (event1, event2)
        ]
        assert telemetry_client.add_records.call_count == 1
        assert telemetry_client.add_records.call_args[0][0] == expected
3676  
3677  
def test_post_ui_telemetry_handler_telemetry_disabled_by_config(
    test_app, mock_telemetry_config_cache, bypass_telemetry_env_check
):
    """When the fetched config disables UI telemetry, posted events are dropped.

    The handler should report status "disabled" and never forward any records
    to the telemetry client.
    """
    event = {
        "event_name": "test_event_1",
        "timestamp_ns": 1234567890000000,
        "params": {"key1": "value1"},
        "installation_id": "install-123",
        "session_id": "session-456",
    }

    request = json.dumps({"records": [event]})

    config = {"disable_ui_telemetry": True}

    mock_client = mock.MagicMock()

    with (
        test_app.test_request_context(
            "/ui-telemetry", method="POST", data=request, content_type="application/json"
        ),
        mock.patch("mlflow.server.handlers.fetch_ui_telemetry_config", return_value=config),
        mock.patch("mlflow.server.handlers.get_telemetry_client", return_value=mock_client),
    ):
        response = post_ui_telemetry_handler()

        assert response is not None
        assert response.status_code == 200

        response_data = json.loads(response.get_data())

        assert response_data["status"] == "disabled"
        # Bug fix: the client API is add_records (plural) — see the success
        # test above. Asserting on the never-used add_record attribute of a
        # MagicMock passed vacuously and would not catch a regression.
        mock_client.add_records.assert_not_called()
3711  
3712  
def test_post_ui_telemetry_handler_telemetry_disabled_by_env(
    test_app, mock_telemetry_config_cache, bypass_telemetry_env_check, monkeypatch
):
    """DO_NOT_TRACK short-circuits the POST path: no fetch, no client lookup."""
    monkeypatch.setenv("DO_NOT_TRACK", "true")
    body = json.dumps({"records": []})
    with (
        test_app.test_request_context(
            "/ui-telemetry", method="POST", data=body, content_type="application/json"
        ),
        mock.patch("mlflow.server.handlers.fetch_ui_telemetry_config") as mock_fetch,
        mock.patch("mlflow.server.handlers.get_telemetry_client") as mock_get_client,
    ):
        response = post_ui_telemetry_handler()

        assert response is not None
        assert response.status_code == 200
        assert json.loads(response.get_data())["status"] == "disabled"

        # Neither the remote config nor the telemetry client should be touched.
        mock_fetch.assert_not_called()
        mock_get_client.assert_not_called()
3737  
3738  
def test_download_artifact_streams_in_chunks(enable_serve_artifacts, tmp_path):
    """Artifact downloads stream fixed-size binary chunks, not line splits."""
    # The fixture file is two full chunks plus a 1000-byte remainder.
    file_size = ARTIFACT_STREAM_CHUNK_SIZE * 2 + 1000
    payload = b"x" * file_size

    artifact_path = "test_model/model.pkl"
    local_file = tmp_path / "model.pkl"
    local_file.write_bytes(payload)

    with (
        app.test_request_context(method="GET"),
        mock.patch("mlflow.server.handlers._get_artifact_repo_mlflow_artifacts") as mock_repo,
        mock.patch("mlflow.server.handlers.tempfile.TemporaryDirectory") as mock_tmp_dir,
    ):
        # Point the handler's temp dir and artifact repo at our fixture file.
        tmp_dir_instance = mock.MagicMock()
        tmp_dir_instance.name = str(tmp_path)
        mock_tmp_dir.return_value = tmp_dir_instance

        artifact_repo = mock.MagicMock()
        artifact_repo.download_artifacts.return_value = str(local_file)
        mock_repo.return_value = artifact_repo

        response = _download_artifact(artifact_path)

        # Collect the streamed pieces from the response body.
        response_chunks = list(response.response)

        # Line-based streaming of a 2MB+ binary blob would yield many tiny
        # pieces; chunk-based streaming yields exactly 3 (2 full + remainder).
        assert len(response_chunks) == 3, f"Expected 3 chunks, got {len(response_chunks)}"
        assert [len(c) for c in response_chunks] == [
            ARTIFACT_STREAM_CHUNK_SIZE,
            ARTIFACT_STREAM_CHUNK_SIZE,
            1000,
        ]

        # Reassembled stream matches the original bytes exactly.
        assert b"".join(response_chunks) == payload
3781  
3782  
def test_create_prompt_optimization_job(mock_tracking_store):
    # Happy path: the handler creates a run, logs params, submits the job,
    # and echoes job metadata back in the response.
    job_entity = JobEntity(
        job_id="job-123",
        creation_time=1234567890,
        job_name="optimize_prompts",
        params='{"run_id": "run-456"}',
        timeout=None,
        status=JobStatus.PENDING,
        result=None,
        retry_count=0,
        last_update_time=1234567890,
        status_details=None,
    )

    fake_run = mock.MagicMock()
    fake_run.info.run_id = "run-456"
    mock_tracking_store.create_run.return_value = fake_run

    fake_dataset = mock.MagicMock()
    fake_dataset._to_mlflow_entity.return_value = mock.MagicMock()

    request_body = {
        "experiment_id": "exp-123",
        "source_prompt_uri": "prompts:/my-prompt/1",
        "config": {
            "optimizer_type": OPTIMIZER_TYPE_GEPA,
            "dataset_id": "dataset-123",
            "scorers": ["Correctness", "Safety"],
            "optimizer_config_json": '{"reflection_model": "openai:/gpt-4"}',
        },
        "tags": [{"key": "env", "value": "test"}],
    }

    with (
        mock.patch("mlflow.server.jobs.submit_job", return_value=job_entity),
        mock.patch("mlflow.server.handlers._get_user", return_value="test_user"),
        mock.patch(
            "mlflow.genai.datasets.get_dataset", return_value=fake_dataset
        ) as mock_get_dataset,
    ):
        with app.test_request_context(method="POST", json=request_body):
            response = _create_prompt_optimization_job()

        mock_get_dataset.assert_called_once_with(dataset_id="dataset-123")

    # A backing run is created for the requesting user in the target experiment.
    mock_tracking_store.create_run.assert_called_once()
    create_run_kwargs = mock_tracking_store.create_run.call_args[1]
    assert create_run_kwargs["experiment_id"] == "exp-123"
    assert create_run_kwargs["user_id"] == "test_user"

    # The optimization config is persisted as run params.
    mock_tracking_store.log_batch.assert_called_once()
    params_by_key = {
        p.key: p.value for p in mock_tracking_store.log_batch.call_args[1]["params"]
    }
    assert params_by_key["source_prompt_uri"] == "prompts:/my-prompt/1"
    assert params_by_key["optimizer_type"] == "gepa"
    assert params_by_key["dataset_id"] == "dataset-123"
    assert params_by_key["scorer_names"] == '["Correctness", "Safety"]'

    job = json.loads(response.get_data())["job"]
    assert job["job_id"] == "job-123"
    assert job["run_id"] == "run-456"
    assert job["state"]["status"] == "JOB_STATUS_PENDING"
    assert job["experiment_id"] == "exp-123"
    assert job["source_prompt_uri"] == "prompts:/my-prompt/1"
3848  
3849  
def test_create_prompt_optimization_job_zero_shot(mock_tracking_store):
    # Zero-shot path: no dataset and no scorers are provided; the handler
    # should still create the job and log empty placeholders as run params.
    job_entity = JobEntity(
        job_id="job-999",
        creation_time=1234567890,
        job_name="optimize_prompts",
        params='{"run_id": "run-999"}',
        timeout=None,
        status=JobStatus.PENDING,
        result=None,
        retry_count=0,
        last_update_time=1234567890,
        status_details=None,
    )

    fake_run = mock.MagicMock()
    fake_run.info.run_id = "run-999"
    mock_tracking_store.create_run.return_value = fake_run

    request_body = {
        "experiment_id": "exp-123",
        "source_prompt_uri": "prompts:/my-prompt/1",
        "config": {
            "optimizer_type": OPTIMIZER_TYPE_METAPROMPT,
            "scorers": [],  # Empty scorers for zero-shot
            # No dataset_id - zero-shot optimization
        },
    }

    with (
        mock.patch("mlflow.server.jobs.submit_job", return_value=job_entity),
        mock.patch("mlflow.server.handlers._get_user", return_value="test_user"),
    ):
        with app.test_request_context(method="POST", json=request_body):
            response = _create_prompt_optimization_job()

    job = json.loads(response.get_data())["job"]
    assert job["job_id"] == "job-999"
    assert job["run_id"] == "run-999"
    assert job["state"]["status"] == "JOB_STATUS_PENDING"

    mock_tracking_store.log_batch.assert_called_once()
    params_by_key = {
        p.key: p.value for p in mock_tracking_store.log_batch.call_args[1]["params"]
    }
    assert params_by_key["dataset_id"] == ""  # Empty string for None
    assert params_by_key["scorer_names"] == "[]"  # Empty list
3896  
3897  
def test_create_prompt_optimization_job_missing_prompt_uri(mock_tracking_store):
    # Omitting source_prompt_uri must produce a 400 with a helpful message.
    request_body = {
        "experiment_id": "exp-123",
        "config": {
            "optimizer_type": 1,
            "dataset_id": "dataset-123",
            "scorers": ["Correctness"],
        },
    }
    with app.test_request_context(method="POST", json=request_body):
        response = _create_prompt_optimization_job()
        assert response.status_code == 400
        payload = json.loads(response.get_data())
        assert payload["error_code"] == ErrorCode.Name(INVALID_PARAMETER_VALUE)
        assert "source_prompt_uri" in payload["message"]
3915  
3916  
def test_create_prompt_optimization_job_unspecified_optimizer_type(mock_tracking_store):
    # An UNSPECIFIED optimizer type is rejected as an invalid parameter.
    request_body = {
        "experiment_id": "exp-123",
        "source_prompt_uri": "prompts:/my-prompt/1",
        "config": {
            "optimizer_type": OPTIMIZER_TYPE_UNSPECIFIED,
            "dataset_id": "dataset-123",
            "scorers": ["Correctness"],
        },
    }
    with app.test_request_context(method="POST", json=request_body):
        response = _create_prompt_optimization_job()
        assert response.status_code == 400
        payload = json.loads(response.get_data())
        assert payload["error_code"] == ErrorCode.Name(INVALID_PARAMETER_VALUE)
        assert "optimizer_type is required" in payload["message"]
3935  
3936  
def test_create_prompt_optimization_job_invalid_optimizer_config_json(mock_tracking_store):
    # Malformed optimizer_config_json must be rejected with a 400.
    request_body = {
        "experiment_id": "exp-123",
        "source_prompt_uri": "prompts:/my-prompt/1",
        "config": {
            "optimizer_type": 1,
            "dataset_id": "dataset-123",
            "scorers": ["Correctness"],
            "optimizer_config_json": "invalid json {",
        },
    }
    with app.test_request_context(method="POST", json=request_body):
        response = _create_prompt_optimization_job()
        assert response.status_code == 400
        payload = json.loads(response.get_data())
        assert payload["error_code"] == ErrorCode.Name(INVALID_PARAMETER_VALUE)
        assert "Invalid JSON in optimizer_config_json" in payload["message"]
3956  
3957  
def test_create_prompt_optimization_job_missing_experiment_id(mock_tracking_store):
    # An empty experiment_id is treated as missing and rejected.
    request_body = {
        "experiment_id": "",  # Empty experiment_id
        "source_prompt_uri": "prompts:/my-prompt/1",
        "config": {
            "optimizer_type": 1,
            "dataset_id": "dataset-123",
            "scorers": ["Correctness"],
        },
    }
    with app.test_request_context(method="POST", json=request_body):
        response = _create_prompt_optimization_job()
        assert response.status_code == 400
        payload = json.loads(response.get_data())
        assert payload["error_code"] == ErrorCode.Name(INVALID_PARAMETER_VALUE)
        assert "experiment_id is required" in payload["message"]
3976  
3977  
def test_cancel_prompt_optimization_job():
    # Cancelling a job must also terminate (KILL) the backing MLflow run.
    canceled_job = JobEntity(
        job_id="job-123",
        creation_time=1234567890,
        job_name="optimize_prompts",
        params=(
            '{"experiment_id": "exp-123", "prompt_uri": "prompts:/my-prompt/1", '
            '"run_id": "run-456"}'
        ),
        timeout=None,
        status=JobStatus.CANCELED,
        result=None,
        retry_count=0,
        last_update_time=1234567890,
        status_details=None,
    )

    tracking_store = mock.Mock()
    with (
        mock.patch("mlflow.server.jobs.cancel_job", return_value=canceled_job),
        mock.patch("mlflow.server.handlers._get_tracking_store", return_value=tracking_store),
    ):
        with app.test_request_context(method="POST"):
            response = _cancel_prompt_optimization_job("job-123")

        # Verify that the underlying run was terminated
        tracking_store.update_run_info.assert_called_once()
        update_kwargs = tracking_store.update_run_info.call_args.kwargs
        assert update_kwargs["run_id"] == "run-456"
        assert update_kwargs["run_status"] == RunStatus.KILLED
        assert update_kwargs["run_name"] is None
        assert "end_time" in update_kwargs

    job = json.loads(response.get_data())["job"]
    assert job["job_id"] == "job-123"
    assert job["state"]["status"] == "JOB_STATUS_CANCELED"
    assert job["experiment_id"] == "exp-123"
    assert job["source_prompt_uri"] == "prompts:/my-prompt/1"
    assert job["run_id"] == "run-456"
4018  
4019  
def test_get_prompt_optimization_job_pending(mock_tracking_store):
    # A pending job's GET response is assembled from job state + run params.
    pending_job = _create_mock_job(status_name="PENDING")
    mock_tracking_store.get_run.return_value = _create_mock_run(
        params={
            "source_prompt_uri": "prompts:/my-prompt/1",
            "optimizer_type": "gepa",
            "dataset_id": "dataset-789",
            "scorer_names": '["Correctness"]',
        }
    )

    with mock.patch("mlflow.server.jobs.get_job", return_value=pending_job):
        with app.test_client() as client:
            response = client.get("/ajax-api/3.0/mlflow/prompt-optimization/jobs/job-123")
            assert response.status_code == 200

            payload = response.get_json()
            assert "job" in payload
            job = payload["job"]
            assert job["job_id"] == "job-123"
            assert job["run_id"] == "run-456"
            assert job["experiment_id"] == "exp-123"
            assert job["source_prompt_uri"] == "prompts:/my-prompt/1"
            assert job["state"]["status"] == "JOB_STATUS_PENDING"
4046  
4047  
def test_get_prompt_optimization_job_succeeded_with_result(mock_tracking_store):
    # A succeeded job surfaces the optimized prompt URI and the eval-score
    # metrics logged on the backing run.
    completed_job = _create_mock_job(
        status_name="SUCCEEDED",
        result={"optimized_prompt_uri": "prompts:/my-prompt/2"},
    )
    mock_tracking_store.get_run.return_value = _create_mock_run(
        params={
            "source_prompt_uri": "prompts:/my-prompt/1",
            "optimizer_type": "gepa",
            "dataset_id": "dataset-789",
            "scorer_names": '["Correctness", "Safety"]',
        },
        metrics={
            "initial_eval_score.Correctness": 0.65,
            "initial_eval_score.Safety": 0.80,
            "final_eval_score.Correctness": 0.89,
            "final_eval_score.Safety": 0.95,
        },
    )

    with mock.patch("mlflow.server.jobs.get_job", return_value=completed_job):
        with app.test_client() as client:
            response = client.get("/ajax-api/3.0/mlflow/prompt-optimization/jobs/job-123")
            assert response.status_code == 200

            job = response.get_json()["job"]
            assert job["state"]["status"] == "JOB_STATUS_COMPLETED"
            assert job["optimized_prompt_uri"] == "prompts:/my-prompt/2"
            # Verify metrics are populated from the run
            assert job["initial_eval_scores"]["Correctness"] == 0.65
            assert job["initial_eval_scores"]["Safety"] == 0.80
            assert job["final_eval_scores"]["Correctness"] == 0.89
            assert job["final_eval_scores"]["Safety"] == 0.95
4084  
4085  
def test_get_prompt_optimization_job_succeeded_run_fetch_fails(mock_tracking_store):
    # If the backing run can't be fetched, the job payload is still returned
    # (status + result), just without run-derived metrics.
    completed_job = _create_mock_job(
        status_name="SUCCEEDED",
        result={"optimized_prompt_uri": "prompts:/my-prompt/2"},
    )
    # Simulate run fetch failing (e.g., run not found)
    mock_tracking_store.get_run.side_effect = Exception("Run not found")

    with mock.patch("mlflow.server.jobs.get_job", return_value=completed_job):
        with app.test_client() as client:
            response = client.get("/ajax-api/3.0/mlflow/prompt-optimization/jobs/job-123")
            assert response.status_code == 200

            job = response.get_json()["job"]
            assert job["state"]["status"] == "JOB_STATUS_COMPLETED"
            assert job["optimized_prompt_uri"] == "prompts:/my-prompt/2"
            # Metrics should not be present when run fetch fails
            assert "initial_eval_scores" not in job or job["initial_eval_scores"] == {}
4106  
4107  
def test_get_prompt_optimization_job_failed_with_error(mock_tracking_store):
    # A failed job's string result is surfaced as the state's error message.
    failed_job = _create_mock_job(
        status_name="FAILED",
        result="Optimization failed: Invalid scorer",
    )
    mock_tracking_store.get_run.return_value = _create_mock_run()

    with mock.patch("mlflow.server.jobs.get_job", return_value=failed_job):
        with app.test_client() as client:
            response = client.get("/ajax-api/3.0/mlflow/prompt-optimization/jobs/job-123")
            assert response.status_code == 200

            job = response.get_json()["job"]
            assert job["state"]["status"] == "JOB_STATUS_FAILED"
            assert "Optimization failed" in job["state"]["error_message"]
4126  
4127  
def test_get_prompt_optimization_job_without_run_id(mock_tracking_store):
    # Jobs created without a backing run simply omit run_id from the payload.
    job_without_run = _create_mock_job(
        params={"experiment_id": "exp-123", "prompt_uri": "prompts:/my-prompt/1"}
    )

    with mock.patch("mlflow.server.jobs.get_job", return_value=job_without_run):
        with app.test_client() as client:
            response = client.get("/ajax-api/3.0/mlflow/prompt-optimization/jobs/job-123")
            assert response.status_code == 200
            job = response.get_json()["job"]
            assert job["job_id"] == "job-123"
            assert job["experiment_id"] == "exp-123"
            assert "run_id" not in job  # run_id is not set
4142  
4143  
def test_get_prompt_optimization_job_with_progress(mock_tracking_store):
    # Progress = total_metric_calls / max_metric_calls, reported in metadata.
    running_job = _create_mock_job(
        status_name="RUNNING",
        params={
            "experiment_id": "exp-123",
            "prompt_uri": "prompts:/my-prompt/1",
            "run_id": "run-456",
            "optimizer_config": {"max_metric_calls": 200, "reflection_model": "openai:/gpt-4o"},
        },
    )
    mock_tracking_store.get_run.return_value = _create_mock_run(
        params={
            "source_prompt_uri": "prompts:/my-prompt/1",
            "optimizer_type": "gepa",
        },
        metrics={
            "total_metric_calls": 86,
            "eval_score": 0.75,
        },
    )

    with mock.patch("mlflow.server.jobs.get_job", return_value=running_job):
        with app.test_client() as client:
            response = client.get("/ajax-api/3.0/mlflow/prompt-optimization/jobs/job-123")
            assert response.status_code == 200

            job = response.get_json()["job"]
            assert job["state"]["status"] == "JOB_STATUS_IN_PROGRESS"
            # Progress should be 86 / 200 = 0.43
            assert job["state"]["metadata"]["progress"] == "0.43"
4177  
4178  
def test_get_prompt_optimization_job_progress_capped_at_one(mock_tracking_store):
    # Even if the run logs more metric calls than the configured maximum,
    # the reported progress never exceeds 1.0.
    running_job = _create_mock_job(
        status_name="RUNNING",
        params={
            "experiment_id": "exp-123",
            "prompt_uri": "prompts:/my-prompt/1",
            "run_id": "run-456",
            "optimizer_config": {"max_metric_calls": 100, "reflection_model": "openai:/gpt-4o"},
        },
    )
    mock_tracking_store.get_run.return_value = _create_mock_run(
        metrics={
            "total_metric_calls": 150,  # Exceeds max_metric_calls
        },
    )

    with mock.patch("mlflow.server.jobs.get_job", return_value=running_job):
        with app.test_client() as client:
            response = client.get("/ajax-api/3.0/mlflow/prompt-optimization/jobs/job-123")
            assert response.status_code == 200

            job = response.get_json()["job"]
            # Progress should be capped at 1.0, not 1.5
            assert job["state"]["metadata"]["progress"] == "1.0"
4206  
4207  
def test_get_prompt_optimization_job_no_progress_without_max_metric_calls(mock_tracking_store):
    # When the optimizer config has no max_metric_calls, the handler cannot
    # compute a progress ratio, so the response must not report one.
    mock_job = _create_mock_job(
        status_name="RUNNING",
        params={
            "experiment_id": "exp-123",
            "prompt_uri": "prompts:/my-prompt/1",
            "run_id": "run-456",
            "optimizer_config": {"reflection_model": "openai:/gpt-4o"},
        },
    )

    mock_run = _create_mock_run(
        metrics={
            "total_metric_calls": 50,
        },
    )
    mock_tracking_store.get_run.return_value = mock_run

    with mock.patch("mlflow.server.jobs.get_job", return_value=mock_job):
        with app.test_client() as c:
            response = c.get("/ajax-api/3.0/mlflow/prompt-optimization/jobs/job-123")
            assert response.status_code == 200

            data = response.get_json()
            job = data["job"]
            # Progress should NOT be set when max_metric_calls is not configured.
            # BUGFIX: progress is reported under state.metadata (see the sibling
            # progress tests), so checking state.status_details here passed
            # vacuously regardless of handler behavior.
            assert "progress" not in job["state"].get("metadata", {})
4235  
4236  
def test_search_prompt_optimization_jobs_returns_multiple_jobs(mock_job_store):
    # Search returns every matching job regardless of status, and queries the
    # store filtered by job name + experiment id.
    mock_job_store.list_jobs.return_value = iter(
        [
            _create_mock_job(
                job_id="job-1",
                status_name="SUCCEEDED",
                result={"optimized_prompt_uri": "prompts:/opt/1"},
            ),
            _create_mock_job(job_id="job-2", status_name="RUNNING"),
            _create_mock_job(job_id="job-3", status_name="PENDING"),
        ]
    )

    with app.test_client() as client:
        response = client.post(
            "/ajax-api/3.0/mlflow/prompt-optimization/jobs/search",
            json={"experiment_id": "exp-123"},
        )
        assert response.status_code == 200

        payload = response.get_json()
        assert "jobs" in payload
        assert len(payload["jobs"]) == 3

        returned_ids = {job["job_id"] for job in payload["jobs"]}
        assert "job-1" in returned_ids
        assert "job-2" in returned_ids
        assert "job-3" in returned_ids

    mock_job_store.list_jobs.assert_called_once_with(
        job_name="optimize_prompts",
        params={"experiment_id": "exp-123"},
    )
4269  
4270  
def test_search_prompt_optimization_jobs_returns_empty_list(mock_job_store):
    # No matching jobs -> an empty (or absent) jobs list, still HTTP 200.
    mock_job_store.list_jobs.return_value = iter([])

    with app.test_client() as client:
        response = client.post(
            "/ajax-api/3.0/mlflow/prompt-optimization/jobs/search",
            json={"experiment_id": "exp-456"},
        )
        assert response.status_code == 200
        assert response.get_json().get("jobs", []) == []
4283  
4284  
def test_search_prompt_optimization_jobs_missing_experiment_id():
    # experiment_id is mandatory for search; an empty body is a 400.
    with app.test_client() as client:
        response = client.post(
            "/ajax-api/3.0/mlflow/prompt-optimization/jobs/search",
            json={},
        )
        assert response.status_code == 400
4292  
4293  
def test_search_prompt_optimization_jobs_includes_succeeded_job_result(mock_job_store):
    # Succeeded jobs in search results carry their optimized prompt URI.
    succeeded = _create_mock_job(
        job_id="job-1",
        status_name="SUCCEEDED",
        result={"optimized_prompt_uri": "prompts:/optimized/1"},
    )
    mock_job_store.list_jobs.return_value = iter([succeeded])

    with app.test_client() as client:
        response = client.post(
            "/ajax-api/3.0/mlflow/prompt-optimization/jobs/search",
            json={"experiment_id": "exp-123"},
        )
        assert response.status_code == 200

        jobs = response.get_json()["jobs"]
        assert len(jobs) == 1
        assert jobs[0]["optimized_prompt_uri"] == "prompts:/optimized/1"
4312  
4313  
def test_search_prompt_optimization_jobs_includes_failed_job_error(mock_job_store):
    # Failed jobs in search results carry their error message.
    failed = _create_mock_job(
        job_id="job-1",
        status_name="FAILED",
        result="Some error occurred",
    )
    mock_job_store.list_jobs.return_value = iter([failed])

    with app.test_client() as client:
        response = client.post(
            "/ajax-api/3.0/mlflow/prompt-optimization/jobs/search",
            json={"experiment_id": "exp-123"},
        )
        assert response.status_code == 200

        jobs = response.get_json()["jobs"]
        assert len(jobs) == 1
        assert "Some error occurred" in jobs[0]["state"]["error_message"]
4332  
4333  
def test_delete_prompt_optimization_job_success(mock_job_store, mock_tracking_store):
    # Deleting a job also deletes its backing run when the run exists.
    mock_job_store.get_job.return_value = _create_mock_job(
        status_name="SUCCEEDED",
        result={"optimized_prompt_uri": "prompts:/optimized/1"},
    )
    mock_tracking_store.get_run.return_value = mock.MagicMock()  # Run exists

    with app.test_client() as client:
        response = client.delete("/ajax-api/3.0/mlflow/prompt-optimization/jobs/job-123")
        assert response.status_code == 200

    mock_job_store.delete_jobs.assert_called_once_with(job_ids=["job-123"])
    mock_tracking_store.get_run.assert_called_once_with("run-456")
    mock_tracking_store.delete_run.assert_called_once_with("run-456")
4349  
4350  
def test_delete_prompt_optimization_job_without_run_id(mock_job_store, mock_tracking_store):
    # Jobs without a backing run are deleted without touching the run store.
    mock_job_store.get_job.return_value = _create_mock_job(
        params={"experiment_id": "exp-123", "prompt_uri": "prompts:/my-prompt/1"}
    )

    with app.test_client() as client:
        response = client.delete("/ajax-api/3.0/mlflow/prompt-optimization/jobs/job-123")
        assert response.status_code == 200

    mock_job_store.delete_jobs.assert_called_once_with(job_ids=["job-123"])
    mock_tracking_store.delete_run.assert_not_called()
4363  
4364  
def test_delete_prompt_optimization_job_skips_run_deletion_when_run_not_found(
    mock_job_store, mock_tracking_store
):
    # If the run lookup raises (already deleted / missing), job deletion
    # proceeds but no run deletion is attempted.
    mock_job_store.get_job.return_value = _create_mock_job(
        status_name="SUCCEEDED",
        result={"optimized_prompt_uri": "prompts:/optimized/1"},
    )
    mock_tracking_store.get_run.side_effect = MlflowException("Run not found")

    with app.test_client() as client:
        response = client.delete("/ajax-api/3.0/mlflow/prompt-optimization/jobs/job-123")
        assert response.status_code == 200

    mock_job_store.delete_jobs.assert_called_once_with(job_ids=["job-123"])
    # delete_run should not be called since run doesn't exist
    mock_tracking_store.delete_run.assert_not_called()
4382  
4383  
def test_get_workspace_scoped_repo_path_if_enabled_allows_legacy_default_artifacts(monkeypatch):
    # Pre-workspace artifact paths pass through unmodified in the default workspace.
    monkeypatch.setenv(MLFLOW_ENABLE_WORKSPACES.name, "true")
    legacy_path = "1/legacy/artifact"
    with WorkspaceContext(DEFAULT_WORKSPACE_NAME):
        assert _get_workspace_scoped_repo_path_if_enabled(legacy_path) == legacy_path
4390  
4391  
def test_get_workspace_scoped_repo_path_if_enabled_still_scopes_non_default(monkeypatch):
    # Non-default workspaces get a workspaces/<name>/ prefix prepended.
    monkeypatch.setenv(MLFLOW_ENABLE_WORKSPACES.name, "true")
    with WorkspaceContext("team-blue"):
        result = _get_workspace_scoped_repo_path_if_enabled("2/new/artifact")
    assert result.startswith("workspaces/team-blue/2/new/artifact")
4397  
4398  
def test_get_workspace_scoped_repo_path_if_enabled_prevents_cross_workspace_access(monkeypatch):
    # Paths addressing another workspace are rejected outright.
    monkeypatch.setenv(MLFLOW_ENABLE_WORKSPACES.name, "true")

    cases = [
        ("workspaces/team-b/secret.txt", "targets workspace 'team-b'"),
        ("workspaces/other/data/model.pkl", "targets workspace 'other'"),
    ]
    with WorkspaceContext("team-a"):
        for foreign_path, expected_match in cases:
            with pytest.raises(MlflowException, match=expected_match):
                _get_workspace_scoped_repo_path_if_enabled(foreign_path)
4408  
4409  
def test_get_workspace_scoped_repo_path_if_enabled_rejects_empty_workspace_in_path(monkeypatch):
    # A "workspaces/" prefix with no workspace segment is malformed.
    monkeypatch.setenv(MLFLOW_ENABLE_WORKSPACES.name, "true")

    with WorkspaceContext("team-a"):
        for malformed_path in ("workspaces/", "workspaces//data.txt"):
            with pytest.raises(MlflowException, match="must include a workspace name"):
                _get_workspace_scoped_repo_path_if_enabled(malformed_path)
4419  
4420  
def test_get_workspace_scoped_repo_path_if_enabled_allows_matching_workspace_prefix(monkeypatch):
    # Paths already scoped to the active workspace are accepted (and a
    # leading slash is normalized away).
    monkeypatch.setenv(MLFLOW_ENABLE_WORKSPACES.name, "true")

    with WorkspaceContext("team-a"):
        assert (
            _get_workspace_scoped_repo_path_if_enabled("workspaces/team-a/data.txt")
            == "workspaces/team-a/data.txt"
        )
        assert (
            _get_workspace_scoped_repo_path_if_enabled("/workspaces/team-a/nested/path")
            == "workspaces/team-a/nested/path"
        )
4430  
4431  
def test_get_workspace_scoped_repo_path_if_enabled_default_workspace_cross_access_blocked(
    monkeypatch,
):
    # The default workspace keeps legacy paths but still may not reach into
    # another workspace's prefix.
    monkeypatch.setenv(MLFLOW_ENABLE_WORKSPACES.name, "true")

    with WorkspaceContext(DEFAULT_WORKSPACE_NAME):
        assert (
            _get_workspace_scoped_repo_path_if_enabled("legacy/artifact.txt")
            == "legacy/artifact.txt"
        )

        with pytest.raises(MlflowException, match="targets workspace 'team-b'"):
            _get_workspace_scoped_repo_path_if_enabled("workspaces/team-b/data.txt")
4443  
4444  
def test_get_workspace_scoped_repo_path_if_enabled_requires_active_workspace(monkeypatch):
    # With workspaces enabled but no WorkspaceContext active, scoping fails.
    monkeypatch.setenv(MLFLOW_ENABLE_WORKSPACES.name, "true")
    with pytest.raises(MlflowException, match="Active workspace is required"):
        _get_workspace_scoped_repo_path_if_enabled("some/path")
4450  
4451  
def test_get_artifact_handler_applies_workspace_scoping(monkeypatch):
    # With workspaces enabled, the artifact path handed to _send_artifact is
    # prefixed with the active workspace.
    monkeypatch.setenv(MLFLOW_ENABLE_WORKSPACES.name, "true")
    monkeypatch.setenv(SERVE_ARTIFACTS_ENV_VAR, "true")
    monkeypatch.setenv(ARTIFACTS_DESTINATION_ENV_VAR, "s3://bucket")

    fake_run = mock.MagicMock()
    fake_run.info.artifact_uri = "mlflow-artifacts:/exp1/run1/artifacts"

    fake_repo = mock.MagicMock()
    fake_repo.download_artifacts.return_value = "/tmp/artifact.txt"

    with (
        mock.patch("mlflow.server.handlers._get_tracking_store") as mock_store,
        mock.patch("mlflow.server.handlers._get_artifact_repo_mlflow_artifacts") as mock_repo,
        mock.patch("mlflow.server.handlers._send_artifact") as mock_send,
    ):
        mock_store.return_value.get_run.return_value = fake_run
        mock_repo.return_value = fake_repo

        with WorkspaceContext("team-blue"):
            with app.test_request_context(
                method="GET", query_string={"run_id": "run1", "path": "model/weights.bin"}
            ):
                get_artifact_handler()

        mock_send.assert_called_once()
        sent_path = mock_send.call_args[0][1]
        assert sent_path.startswith("workspaces/team-blue/")
4480  
4481  
def test_get_artifact_handler_no_scoping_when_workspaces_disabled(monkeypatch):
    # With workspaces disabled, artifact paths are never workspace-prefixed.
    monkeypatch.setenv(MLFLOW_ENABLE_WORKSPACES.name, "false")
    monkeypatch.setenv(SERVE_ARTIFACTS_ENV_VAR, "true")
    monkeypatch.setenv(ARTIFACTS_DESTINATION_ENV_VAR, "s3://bucket")

    fake_run = mock.MagicMock()
    fake_run.info.artifact_uri = "mlflow-artifacts:/exp1/run1/artifacts"

    fake_repo = mock.MagicMock()

    with (
        mock.patch("mlflow.server.handlers._get_tracking_store") as mock_store,
        mock.patch("mlflow.server.handlers._get_artifact_repo_mlflow_artifacts") as mock_repo,
        mock.patch("mlflow.server.handlers._send_artifact") as mock_send,
    ):
        mock_store.return_value.get_run.return_value = fake_run
        mock_repo.return_value = fake_repo

        with app.test_request_context(
            method="GET", query_string={"run_id": "run1", "path": "model/weights.bin"}
        ):
            get_artifact_handler()

        mock_send.assert_called_once()
        sent_path = mock_send.call_args[0][1]
        assert not sent_path.startswith("workspaces/")
4508  
4509  
def test_get_model_version_artifact_handler_applies_workspace_scoping(monkeypatch):
    # Model-version artifact downloads are scoped to the active workspace too.
    monkeypatch.setenv(MLFLOW_ENABLE_WORKSPACES.name, "true")
    monkeypatch.setenv(SERVE_ARTIFACTS_ENV_VAR, "true")
    monkeypatch.setenv(ARTIFACTS_DESTINATION_ENV_VAR, "s3://bucket")

    fake_repo = mock.MagicMock()

    with (
        mock.patch("mlflow.server.handlers._get_model_registry_store") as mock_store,
        mock.patch("mlflow.server.handlers._get_artifact_repo_mlflow_artifacts") as mock_repo,
        mock.patch("mlflow.server.handlers._send_artifact") as mock_send,
    ):
        mock_store.return_value.get_model_version_download_uri.return_value = (
            "mlflow-artifacts:/models/MyModel/1"
        )
        mock_repo.return_value = fake_repo

        with WorkspaceContext("team-red"):
            with app.test_request_context(
                method="GET", query_string={"name": "MyModel", "version": "1", "path": "model.pkl"}
            ):
                get_model_version_artifact_handler()

        mock_send.assert_called_once()
        sent_path = mock_send.call_args[0][1]
        assert sent_path.startswith("workspaces/team-red/")
4536  
4537  
def test_get_logged_model_artifact_handler_applies_workspace_scoping(monkeypatch):
    """Logged-model artifact downloads are scoped under the active workspace."""
    monkeypatch.setenv(MLFLOW_ENABLE_WORKSPACES.name, "true")
    monkeypatch.setenv(SERVE_ARTIFACTS_ENV_VAR, "true")
    monkeypatch.setenv(ARTIFACTS_DESTINATION_ENV_VAR, "s3://bucket")

    logged_model = mock.MagicMock()
    logged_model.artifact_location = "mlflow-artifacts:/exp1/run1/artifacts/model"

    artifact_repo = mock.MagicMock()

    with (
        mock.patch("mlflow.server.handlers._get_tracking_store") as store_patch,
        mock.patch("mlflow.server.handlers._get_artifact_repo_mlflow_artifacts") as repo_patch,
        mock.patch("mlflow.server.handlers._send_artifact") as send_patch,
    ):
        store_patch.return_value.get_logged_model.return_value = logged_model
        repo_patch.return_value = artifact_repo

        with WorkspaceContext("team-green"), app.test_request_context(
            method="GET", query_string={"artifact_file_path": "MLmodel"}
        ):
            get_logged_model_artifact_handler("model123")

        send_patch.assert_called_once()
        resolved_path = send_patch.call_args[0][1]
        assert resolved_path.startswith("workspaces/team-green/")
4565  
4566  
def test_upload_artifact_handler_applies_workspace_scoping(monkeypatch):
    """Uploaded run artifacts are logged under a workspace-scoped path."""
    monkeypatch.setenv(MLFLOW_ENABLE_WORKSPACES.name, "true")
    monkeypatch.setenv(SERVE_ARTIFACTS_ENV_VAR, "true")
    monkeypatch.setenv(ARTIFACTS_DESTINATION_ENV_VAR, "s3://bucket")

    run = mock.MagicMock()
    run.info.artifact_uri = "mlflow-artifacts:/exp1/run1/artifacts"
    run.info.experiment_id = "exp1"
    run.info.run_id = "run1"

    artifact_repo = mock.MagicMock()

    with (
        mock.patch("mlflow.server.handlers._get_tracking_store") as store_patch,
        mock.patch("mlflow.server.handlers._get_artifact_repo_mlflow_artifacts") as repo_patch,
    ):
        store_patch.return_value.get_run.return_value = run
        repo_patch.return_value = artifact_repo

        with WorkspaceContext("team-purple"), app.test_request_context(
            method="POST",
            query_string={"run_uuid": "run1", "path": "output.txt"},
            data=b"test data",
        ):
            upload_artifact_handler()

        artifact_repo.log_artifact.assert_called_once()
        # Second positional argument to log_artifact is the destination path.
        destination = artifact_repo.log_artifact.call_args[0][1]
        assert destination.startswith("workspaces/team-purple/")
4597  
4598  
def test_list_artifacts_for_proxied_run_artifact_root_applies_workspace_scoping(monkeypatch):
    """Proxied artifact listing prefixes the listed path with the workspace."""
    from mlflow.store.artifact.artifact_repo import ArtifactRepository

    monkeypatch.setenv(MLFLOW_ENABLE_WORKSPACES.name, "true")
    monkeypatch.setenv(SERVE_ARTIFACTS_ENV_VAR, "true")
    monkeypatch.setenv(ARTIFACTS_DESTINATION_ENV_VAR, "s3://bucket")

    artifact_repo = mock.MagicMock(spec=ArtifactRepository)
    artifact_repo.list_artifacts.return_value = []

    with (
        mock.patch("mlflow.server.handlers._get_artifact_repo_mlflow_artifacts") as repo_patch,
        WorkspaceContext("team-orange"),
    ):
        repo_patch.return_value = artifact_repo

        _list_artifacts_for_proxied_run_artifact_root(
            proxied_artifact_root="mlflow-artifacts:/exp1/run1/artifacts",
            relative_path="model",
        )

        artifact_repo.list_artifacts.assert_called_once()
        listed = artifact_repo.list_artifacts.call_args[0][0]
        assert listed.startswith("workspaces/team-orange/")
4623  
4624  
4625  # ==================== Budget Window Tests ====================
4626  
4627  
def _make_budget_policy(
    budget_policy_id="bp-test",
    budget_amount=100.0,
    duration=None,
):
    """Build a USD, global-scope, alert-action budget policy for tests.

    ``duration`` defaults to a single day when not supplied.
    """
    if duration is None:
        duration = BudgetDuration(unit=BudgetDurationUnit.DAYS, value=1)
    return GatewayBudgetPolicy(
        budget_policy_id=budget_policy_id,
        budget_unit=BudgetUnit.USD,
        budget_amount=budget_amount,
        duration=duration,
        target_scope=BudgetTargetScope.GLOBAL,
        budget_action=BudgetAction.ALERT,
        created_at=0,
        last_updated_at=0,
    )
4643  
4644  
def test_list_budget_windows_empty():
    """With no tracked windows, the endpoint returns an empty window list."""
    with (
        app.test_client() as client,
        mock.patch("mlflow.server.handlers.get_budget_tracker") as tracker_patch,
        mock.patch("mlflow.server.handlers.maybe_refresh_budget_policies"),
    ):
        tracker_patch.return_value.get_all_windows.return_value = []
        resp = client.get("/ajax-api/3.0/mlflow/gateway/budgets/windows")

    assert resp.status_code == 200
    assert resp.json.get("windows", []) == []
4656  
4657  
def test_list_budget_windows_returns_window_data():
    """A single policy with recorded spend surfaces a fully-populated window."""
    tracker = InMemoryBudgetTracker()
    tracker.refresh_policies([_make_budget_policy(budget_policy_id="bp-1", budget_amount=50.0)])
    tracker.record_cost(12.5)

    with (
        app.test_client() as client,
        mock.patch("mlflow.server.handlers.get_budget_tracker", return_value=tracker),
        mock.patch("mlflow.server.handlers.maybe_refresh_budget_policies"),
    ):
        resp = client.get("/ajax-api/3.0/mlflow/gateway/budgets/windows")

    assert resp.status_code == 200
    windows = resp.json["windows"]
    assert len(windows) == 1
    window = windows[0]
    assert window["budget_policy_id"] == "bp-1"
    assert window["current_spend"] == 12.5
    # The window start must be a plausible epoch-millisecond value...
    epoch_2000_ms = int(datetime(2000, 1, 1, tzinfo=timezone.utc).timestamp() * 1000)
    assert window["window_start_ms"] >= epoch_2000_ms
    assert window["window_end_ms"] > window["window_start_ms"]
    # ...and its span must match the policy duration (1 day) exactly.
    assert window["window_end_ms"] - window["window_start_ms"] == 86_400_000
4682  
4683  
def test_list_budget_windows_multiple_policies():
    """Each active policy gets its own window; recorded spend applies to all."""
    tracker = InMemoryBudgetTracker()
    tracker.refresh_policies(
        [
            _make_budget_policy(budget_policy_id="bp-1", budget_amount=100.0),
            _make_budget_policy(budget_policy_id="bp-2", budget_amount=200.0),
        ]
    )
    tracker.record_cost(30.0)

    with (
        app.test_client() as client,
        mock.patch("mlflow.server.handlers.get_budget_tracker", return_value=tracker),
        mock.patch("mlflow.server.handlers.maybe_refresh_budget_policies"),
    ):
        resp = client.get("/ajax-api/3.0/mlflow/gateway/budgets/windows")

    assert resp.status_code == 200
    by_policy = {w["budget_policy_id"]: w for w in resp.json["windows"]}
    assert set(by_policy) == {"bp-1", "bp-2"}
    assert by_policy["bp-1"]["current_spend"] == 30.0
    assert by_policy["bp-2"]["current_spend"] == 30.0
4705  
4706  
def test_list_budget_windows_zero_spend():
    """A policy with no recorded cost reports zero current spend."""
    tracker = InMemoryBudgetTracker()
    tracker.refresh_policies([_make_budget_policy(budget_amount=100.0)])

    with (
        app.test_client() as client,
        mock.patch("mlflow.server.handlers.get_budget_tracker", return_value=tracker),
        mock.patch("mlflow.server.handlers.maybe_refresh_budget_policies"),
    ):
        resp = client.get("/ajax-api/3.0/mlflow/gateway/budgets/windows")

    assert resp.status_code == 200
    window = resp.json["windows"][0]
    assert window["budget_policy_id"] == "bp-test"
    assert window["current_spend"] == 0.0
4723  
4724  
def test_create_issue_with_all_fields():
    """All optional CreateIssue fields are forwarded to the store and echoed back."""
    req = CreateIssue()
    req.experiment_id = "exp-123"
    req.name = "High latency"
    req.description = "API calls are taking too long"
    req.status = "pending"
    req.source_run_id = "run-123"
    req.root_causes.extend(["Database query inefficiency", "Network latency"])
    req.categories.extend(["performance", "database"])
    req.severity = IssueSeverity.HIGH.value
    req.created_by = "user@example.com"

    stored = Issue(
        issue_id="iss-123",
        experiment_id="exp-123",
        name="High latency",
        description="API calls are taking too long",
        status=IssueStatus.PENDING,
        source_run_id="run-123",
        root_causes=["Database query inefficiency", "Network latency"],
        categories=["performance", "database"],
        severity=IssueSeverity.HIGH,
        created_timestamp=1234567890,
        last_updated_timestamp=1234567890,
        created_by="user@example.com",
    )

    with (
        mock.patch("mlflow.server.handlers._get_tracking_store") as store_patch,
        mock.patch("mlflow.server.handlers._get_request_message", return_value=req),
    ):
        store = store_patch.return_value
        store.create_issue.return_value = stored

        response = _create_issue()

        store.create_issue.assert_called_once()
        kwargs = store.create_issue.call_args[1]
        assert kwargs["experiment_id"] == "exp-123"
        assert kwargs["name"] == "High latency"
        assert kwargs["description"] == "API calls are taking too long"
        assert kwargs["status"] == IssueStatus.PENDING
        assert kwargs["source_run_id"] == "run-123"
        assert kwargs["root_causes"] == ["Database query inefficiency", "Network latency"]
        assert kwargs["categories"] == ["performance", "database"]
        assert kwargs["severity"] == IssueSeverity.HIGH.value
        assert kwargs["created_by"] == "user@example.com"

        issue_json = json.loads(response.get_data())["issue"]
        assert issue_json["issue_id"] == "iss-123"
        assert issue_json["root_causes"] == [
            "Database query inefficiency",
            "Network latency",
        ]
        assert issue_json["categories"] == ["performance", "database"]
4779  
4780  
def test_create_issue_without_optional_fields():
    """Omitted optional fields reach the store as None (or are left out entirely)."""
    req = CreateIssue()
    req.experiment_id = "exp-456"
    req.name = "Error handling issue"
    req.description = "Errors are not being caught properly"

    stored = Issue(
        issue_id="iss-456",
        experiment_id="exp-456",
        name="Error handling issue",
        description="Errors are not being caught properly",
        status=IssueStatus.PENDING,
        created_timestamp=1234567890,
        last_updated_timestamp=1234567890,
    )

    with (
        mock.patch("mlflow.server.handlers._get_tracking_store") as store_patch,
        mock.patch("mlflow.server.handlers._get_request_message", return_value=req),
    ):
        store = store_patch.return_value
        store.create_issue.return_value = stored

        response = _create_issue()

        store.create_issue.assert_called_once()
        kwargs = store.create_issue.call_args[1]
        assert kwargs["source_run_id"] is None
        assert kwargs["root_causes"] is None
        assert "severity" not in kwargs

        body = json.loads(response.get_data())
        assert body["issue"]["issue_id"] == "iss-456"
4813  
4814  
def test_create_issue_with_default_status():
    """When the request carries no status, none is forwarded to the store."""
    req = CreateIssue()
    req.experiment_id = "exp-789"
    req.name = "Test issue"
    req.description = "Test description"

    stored = Issue(
        issue_id="iss-789",
        experiment_id="exp-789",
        name="Test issue",
        description="Test description",
        status=IssueStatus.PENDING,
        created_timestamp=1234567890,
        last_updated_timestamp=1234567890,
    )

    with (
        mock.patch("mlflow.server.handlers._get_tracking_store") as store_patch,
        mock.patch("mlflow.server.handlers._get_request_message", return_value=req),
    ):
        store = store_patch.return_value
        store.create_issue.return_value = stored

        _create_issue()

        # Status absent from the request → omitted so the store applies its default.
        assert "status" not in store.create_issue.call_args[1]
4842  
4843  
def test_get_issue():
    """_get_issue fetches by id and serializes enum and list fields to JSON."""
    stored = Issue(
        issue_id="iss-get-123",
        experiment_id="exp-123",
        name="Test issue",
        description="Test description",
        status=IssueStatus.RESOLVED,
        severity=IssueSeverity.HIGH,
        root_causes=["Root cause 1"],
        created_timestamp=1234567890,
        last_updated_timestamp=1234567890,
    )

    with mock.patch("mlflow.server.handlers._get_tracking_store") as store_patch:
        store = store_patch.return_value
        store.get_issue.return_value = stored

        with app.test_request_context():
            response = _get_issue("iss-get-123")

        store.get_issue.assert_called_once_with("iss-get-123")

        issue_json = json.loads(response.get_data())["issue"]
        assert issue_json["issue_id"] == "iss-get-123"
        assert issue_json["name"] == "Test issue"
        assert issue_json["severity"] == "high"
        assert issue_json["root_causes"] == ["Root cause 1"]
4870  
4871  
def test_get_issue_not_found():
    """A RESOURCE_DOES_NOT_EXIST store error surfaces as a 404 JSON response."""
    with mock.patch("mlflow.server.handlers._get_tracking_store") as store_patch:
        store_patch.return_value.get_issue.side_effect = MlflowException(
            "Issue not found", error_code=RESOURCE_DOES_NOT_EXIST
        )

        with app.test_request_context():
            response = _get_issue("nonexistent-id")

        # @catch_mlflow_exception converts the exception into a JSON error response.
        assert response.status_code == 404
        body = json.loads(response.get_data())
        assert body["error_code"] == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)
        assert "Issue not found" in body["message"]
4886  
4887  
def test_update_issue():
    """Updated fields are forwarded to the store and the new values serialized back."""
    req = UpdateIssue()
    req.issue_id = "iss-update-123"
    req.name = "Updated issue name"
    req.description = "Updated description"
    req.status = "resolved"
    req.severity = "medium"

    updated = Issue(
        issue_id="iss-update-123",
        experiment_id="exp-123",
        name="Updated issue name",
        description="Updated description",
        status=IssueStatus.RESOLVED,
        severity=IssueSeverity.MEDIUM,
        created_timestamp=1234567890,
        last_updated_timestamp=1234567900,
    )

    with (
        mock.patch("mlflow.server.handlers._get_tracking_store") as store_patch,
        mock.patch("mlflow.server.handlers._get_request_message", return_value=req),
    ):
        store = store_patch.return_value
        store.update_issue.return_value = updated

        response = _update_issue("iss-update-123")

        store.update_issue.assert_called_once()
        kwargs = store.update_issue.call_args[1]
        assert kwargs["issue_id"] == "iss-update-123"
        assert kwargs["name"] == "Updated issue name"
        assert kwargs["description"] == "Updated description"
        assert kwargs["status"] == IssueStatus.RESOLVED
        assert kwargs["severity"] == IssueSeverity.MEDIUM.value

        issue_json = json.loads(response.get_data())["issue"]
        assert issue_json["issue_id"] == "iss-update-123"
        assert issue_json["name"] == "Updated issue name"
        assert issue_json["severity"] == "medium"
4927  
4928  
def test_search_issues_all():
    """An empty SearchIssues request lists everything using store-side defaults."""
    req = SearchIssues()

    issues = [
        Issue(
            issue_id=f"iss-{idx}",
            experiment_id="exp-1",
            name=f"Issue {idx}",
            description=f"Description {idx}",
            status=status,
            created_timestamp=ts,
            last_updated_timestamp=ts,
        )
        for idx, status, ts in (
            (1, IssueStatus.PENDING, 1234567890),
            (2, IssueStatus.RESOLVED, 1234567891),
        )
    ]

    with (
        mock.patch("mlflow.server.handlers._get_tracking_store") as store_patch,
        mock.patch("mlflow.server.handlers._get_request_message", return_value=req),
    ):
        store = store_patch.return_value
        store.search_issues.return_value = PagedList(issues, token="next-token")

        response = _search_issues()

        store.search_issues.assert_called_once()
        kwargs = store.search_issues.call_args[1]
        # max_results not specified in request, so it's not passed to store;
        # the store then uses its own default (SEARCH_ISSUES_DEFAULT_MAX_RESULTS).
        assert "max_results" not in kwargs
        assert kwargs["experiment_id"] is None
        assert kwargs["filter_string"] is None

        body = json.loads(response.get_data())
        assert [i["issue_id"] for i in body["issues"]] == ["iss-1", "iss-2"]
        assert body["next_page_token"] == "next-token"
4974  
4975  
def test_search_issues_with_filters():
    """experiment_id, filter_string and max_results are all forwarded to the store."""
    filter_expr = "status = 'resolved' AND source_run_id = 'run-specific'"
    req = SearchIssues()
    req.experiment_id = "exp-specific"
    req.filter_string = filter_expr
    req.max_results = 50

    matched = Issue(
        issue_id="iss-filtered",
        experiment_id="exp-specific",
        name="Filtered issue",
        description="Description",
        status=IssueStatus.RESOLVED,
        source_run_id="run-specific",
        created_timestamp=1234567890,
        last_updated_timestamp=1234567890,
    )

    with (
        mock.patch("mlflow.server.handlers._get_tracking_store") as store_patch,
        mock.patch("mlflow.server.handlers._get_request_message", return_value=req),
    ):
        store = store_patch.return_value
        store.search_issues.return_value = PagedList([matched], token=None)

        response = _search_issues()

        kwargs = store.search_issues.call_args[1]
        assert kwargs["experiment_id"] == "exp-specific"
        assert kwargs["filter_string"] == filter_expr
        assert kwargs["max_results"] == 50

        body = json.loads(response.get_data())
        assert len(body["issues"]) == 1
        assert body["issues"][0]["issue_id"] == "iss-filtered"
        assert body["next_page_token"] == ""
5014  
5015  
def test_search_issues_with_pagination():
    """max_results and page_token from the request drive the store query."""
    req = SearchIssues()
    req.max_results = 10
    req.page_token = "token-123"

    issues = []
    for idx in range(10):
        ts = 1234567890 + idx
        issues.append(
            Issue(
                issue_id=f"iss-{idx}",
                experiment_id="exp-1",
                name=f"Issue {idx}",
                description=f"Description {idx}",
                status=IssueStatus.PENDING,
                created_timestamp=ts,
                last_updated_timestamp=ts,
            )
        )

    with (
        mock.patch("mlflow.server.handlers._get_tracking_store") as store_patch,
        mock.patch("mlflow.server.handlers._get_request_message", return_value=req),
    ):
        store = store_patch.return_value
        store.search_issues.return_value = PagedList(issues, token="token-456")

        response = _search_issues()

        kwargs = store.search_issues.call_args[1]
        assert kwargs["max_results"] == 10
        assert kwargs["page_token"] == "token-123"

        body = json.loads(response.get_data())
        assert len(body["issues"]) == 10
        assert body["next_page_token"] == "token-456"
5049  
5050  
def test_search_issues_empty_results():
    """No matches yields an empty issue list and an empty page token."""
    req = SearchIssues()

    with (
        mock.patch("mlflow.server.handlers._get_tracking_store") as store_patch,
        mock.patch("mlflow.server.handlers._get_request_message", return_value=req),
    ):
        store_patch.return_value.search_issues.return_value = PagedList([], token=None)

        response = _search_issues()

        body = json.loads(response.get_data())
        assert body.get("issues", []) == []
        assert body["next_page_token"] == ""
5065  
5066  
def test_search_issues_with_trace_count():
    """include_trace_count is forwarded and per-issue counts appear in the response."""
    req = SearchIssues()
    req.include_trace_count = True

    def _issue(issue_id, name, description, ts, count):
        # Small local factory to keep the two fixtures readable.
        return Issue(
            issue_id=issue_id,
            experiment_id="exp-1",
            name=name,
            description=description,
            status=IssueStatus.PENDING,
            created_timestamp=ts,
            last_updated_timestamp=ts,
            trace_count=count,
        )

    issues = [
        _issue("iss-1", "Issue with traces", "Has 2 traces", 1234567890, 2),
        _issue("iss-2", "Issue without traces", "Has no traces", 1234567891, 0),
    ]

    with (
        mock.patch("mlflow.server.handlers._get_tracking_store") as store_patch,
        mock.patch("mlflow.server.handlers._get_request_message", return_value=req),
    ):
        store = store_patch.return_value
        store.search_issues.return_value = PagedList(issues, token=None)

        response = _search_issues()

        assert store.search_issues.call_args[1]["include_trace_count"] is True

        body = json.loads(response.get_data())
        assert len(body["issues"]) == 2
        assert body["issues"][0]["trace_count"] == 2
        assert body["issues"][1]["trace_count"] == 0
5109  
5110  
def test_create_issue_with_empty_lists():
    """Empty repeated request fields are normalized to None before hitting the store."""
    req = CreateIssue()
    req.experiment_id = "exp-123"
    req.name = "Test issue"
    req.description = "Test description"

    stored = Issue(
        issue_id="iss-empty-lists",
        experiment_id="exp-123",
        name="Test issue",
        description="Test description",
        status=IssueStatus.PENDING,
        created_timestamp=1234567890,
        last_updated_timestamp=1234567890,
    )

    with (
        mock.patch("mlflow.server.handlers._get_tracking_store") as store_patch,
        mock.patch("mlflow.server.handlers._get_request_message", return_value=req),
    ):
        store = store_patch.return_value
        store.create_issue.return_value = stored

        _create_issue()

        # Empty lists should be passed as None.
        assert store.create_issue.call_args[1]["root_causes"] is None
5138  
5139  
def test_invoke_issue_detection_handler_success(monkeypatch):
    """A fully-specified request submits a detection job and returns job/run ids."""
    monkeypatch.setenv("MLFLOW_SERVER_ENABLE_JOB_EXECUTION", "true")

    job = JobEntity(
        job_id="job-123",
        creation_time=1234567890000,
        job_name="invoke_issue_detection",
        params='{"experiment_id": "exp-123"}',
        timeout=None,
        status=JobStatus.PENDING,
        result=None,
        retry_count=0,
        last_update_time=1234567890000,
        status_details=None,
    )

    run = mock.MagicMock()
    run.info.run_id = "run-123"

    payload = {
        "experiment_id": "exp-123",
        "trace_ids": ["trace-1", "trace-2"],
        "categories": ["correctness", "safety"],
        "provider": "openai",
        "model": "gpt-4o",
        "secret_id": "secret-123",
    }

    with (
        mock.patch("mlflow.server.handlers._get_tracking_store") as store_patch,
        mock.patch(
            "mlflow.genai.discovery.job._fetch_provider_credentials",
            return_value={"OPENAI_API_KEY": "test-key"},
        ) as creds_patch,
        mock.patch("mlflow.server.jobs.submit_job", return_value=job) as submit_patch,
        mock.patch("mlflow.start_run", return_value=run),
        mock.patch("mlflow.set_tag"),
        mock.patch("mlflow.end_run"),
        app.test_client() as client,
    ):
        resp = client.post("/ajax-api/3.0/mlflow/issues/invoke", json=payload)
        assert resp.status_code == 200
        body = resp.get_json()

        assert body["job_id"] == "job-123"
        assert body["run_id"] == "run-123"

        creds_patch.assert_called_once_with(store_patch.return_value, "openai", "secret-123")
        submit_patch.assert_called_once()
        kwargs = submit_patch.call_args.kwargs
        params = kwargs["params"]
        assert params["experiment_id"] == "exp-123"
        assert params["trace_ids"] == ["trace-1", "trace-2"]
        assert params["categories"] == ["correctness", "safety"]
        # Provider and model are combined into a provider-prefixed model URI.
        assert params["model"] == "openai:/gpt-4o"
        assert kwargs["extra_envs"] == {"OPENAI_API_KEY": "test-key"}
5200  
5201  
def test_invoke_issue_detection_handler_with_endpoint(monkeypatch):
    """Supplying endpoint_name instead of model produces a gateway model URI."""
    monkeypatch.setenv("MLFLOW_SERVER_ENABLE_JOB_EXECUTION", "true")

    job = JobEntity(
        job_id="job-456",
        creation_time=1234567890000,
        job_name="invoke_issue_detection",
        params='{"experiment_id": "exp-123"}',
        timeout=None,
        status=JobStatus.PENDING,
        result=None,
        retry_count=0,
        last_update_time=1234567890000,
        status_details=None,
    )

    run = mock.MagicMock()
    run.info.run_id = "run-456"

    payload = {
        "experiment_id": "exp-123",
        "trace_ids": ["trace-1"],
        "categories": ["correctness"],
        "provider": "openai",
        "endpoint_name": "my-endpoint",
        "secret_id": "secret-123",
    }

    with (
        mock.patch(
            "mlflow.genai.discovery.job._fetch_provider_credentials",
            return_value={"OPENAI_API_KEY": "test-key"},
        ),
        mock.patch("mlflow.server.jobs.submit_job", return_value=job) as submit_patch,
        mock.patch("mlflow.start_run", return_value=run),
        mock.patch("mlflow.set_tag"),
        mock.patch("mlflow.end_run"),
        app.test_client() as client,
    ):
        resp = client.post("/ajax-api/3.0/mlflow/issues/invoke", json=payload)
        assert resp.status_code == 200
        body = resp.get_json()

        assert body["job_id"] == "job-456"
        assert body["run_id"] == "run-456"

        # An endpoint name maps to a gateway-prefixed model URI.
        kwargs = submit_patch.call_args.kwargs
        assert kwargs["params"]["model"] == "gateway:/my-endpoint"
5255  
5256  
def test_invoke_issue_detection_handler_missing_required_params(monkeypatch):
    """Omitting both 'model' and 'endpoint_name' yields a 500 with a clear message."""
    monkeypatch.setenv("MLFLOW_SERVER_ENABLE_JOB_EXECUTION", "true")

    # Neither 'model' nor 'endpoint_name' is supplied, so the handler cannot
    # resolve which model should be used for detection.
    payload = {
        "experiment_id": "exp-123",
        "trace_ids": ["trace-1"],
        "categories": ["correctness"],
        "provider": "openai",
        "secret_id": "secret-123",
    }

    with (
        mock.patch(
            "mlflow.genai.discovery.job._fetch_provider_credentials",
            return_value={"OPENAI_API_KEY": "test-key"},
        ),
        app.test_client() as client,
    ):
        response = client.post("/ajax-api/3.0/mlflow/issues/invoke", json=payload)
        assert response.status_code == 500

        message = response.get_json()["message"]
        assert "Either 'endpoint_name' or both 'provider' and 'model' must be provided" in message
5286  
5287  
def test_get_job_success(mock_job_store):
    """A succeeded job is returned with its JSON result deserialized."""
    finished_job = JobEntity(
        job_id="job-123",
        creation_time=1234567890000,
        job_name="invoke_issue_detection",
        params='{"experiment_id": "exp-123"}',
        timeout=None,
        status=JobStatus.SUCCEEDED,
        result='{"summary": "Found 3 issues", "issues": 3, "total_traces_analyzed": 10}',
        retry_count=0,
        last_update_time=1234567900000,
        status_details=None,
    )

    with (
        mock.patch("mlflow.server.jobs.get_job", return_value=finished_job),
        app.test_client() as client,
    ):
        response = client.get("/ajax-api/3.0/mlflow/jobs/job-123")
        assert response.status_code == 200

        body = response.get_json()
        assert body["status"] == "SUCCEEDED"
        assert body["status_details"] is None

        # The JSON-encoded result string should come back as structured data.
        result = body["result"]
        assert result["summary"] == "Found 3 issues"
        assert result["issues"] == 3
        assert result["total_traces_analyzed"] == 10
5315  
5316  
def test_get_job_pending(mock_job_store):
    """A pending job reports PENDING status with no result or details yet."""
    pending_job = JobEntity(
        job_id="job-pending",
        creation_time=1234567890000,
        job_name="invoke_issue_detection",
        params='{"experiment_id": "exp-123"}',
        timeout=None,
        status=JobStatus.PENDING,
        result=None,
        retry_count=0,
        last_update_time=1234567890000,
        status_details=None,
    )

    with (
        mock.patch("mlflow.server.jobs.get_job", return_value=pending_job),
        app.test_client() as client,
    ):
        response = client.get("/ajax-api/3.0/mlflow/jobs/job-pending")
        assert response.status_code == 200

        body = response.get_json()
        assert body["status"] == "PENDING"
        assert body["result"] is None
        assert body["status_details"] is None
5342  
5343  
def test_cancel_job_success(mock_job_store):
    """Cancelling a job via PATCH delegates to cancel_job and echoes the new status."""
    canceled_job = JobEntity(
        job_id="job-123",
        creation_time=1234567890000,
        job_name="invoke_issue_detection",
        params='{"experiment_id": "exp-123"}',
        timeout=None,
        status=JobStatus.CANCELED,
        result=None,
        retry_count=0,
        last_update_time=1234567900000,
        status_details=None,
    )

    with (
        mock.patch("mlflow.server.jobs.cancel_job", return_value=canceled_job) as cancel_mock,
        app.test_client() as client,
    ):
        response = client.patch("/ajax-api/3.0/mlflow/jobs/cancel/job-123")
        assert response.status_code == 200

        assert response.get_json()["status"] == "CANCELED"
        # The job id from the URL must be forwarded to the cancel call.
        cancel_mock.assert_called_once_with("job-123")
5368  
5369  
def test_get_rest_path_respects_static_prefix(monkeypatch):
    """REST/AJAX path helpers must honor the configured static prefix env var."""
    endpoint = "/mlflow/experiments/search"

    # Default (no prefix configured): paths come back unprefixed.
    assert _get_rest_path(endpoint) == "/api/2.0/mlflow/experiments/search"
    assert _get_ajax_path(endpoint) == "/ajax-api/2.0/mlflow/experiments/search"

    # With a prefix configured, both helpers should prepend it.
    monkeypatch.setenv(STATIC_PREFIX_ENV_VAR, "/myapp")
    assert _get_rest_path(endpoint) == "/myapp/api/2.0/mlflow/experiments/search"
    assert _get_ajax_path(endpoint) == "/myapp/ajax-api/2.0/mlflow/experiments/search"