# tests/tracking/test_rest_tracking.py
   1  """
   2  Integration test which starts a local Tracking Server on an ephemeral port,
   3  and ensures we can use the tracking API to communicate with it.
   4  """
   5  
   6  import json
   7  import logging
   8  import math
   9  import os
  10  import pathlib
  11  import posixpath
  12  import subprocess
  13  import sys
  14  import time
  15  import urllib.parse
  16  from dataclasses import asdict
  17  from io import StringIO
  18  from pathlib import Path
  19  from unittest import mock
  20  
  21  import flask
  22  import pandas as pd
  23  import pytest
  24  import requests
  25  from opentelemetry.sdk.trace import ReadableSpan as OTelReadableSpan
  26  
  27  import mlflow.experiments
  28  import mlflow.pyfunc
  29  from mlflow import MlflowClient
  30  from mlflow.artifacts import download_artifacts
  31  from mlflow.data.pandas_dataset import from_pandas
  32  from mlflow.entities import (
  33      Dataset,
  34      DatasetInput,
  35      FallbackConfig,
  36      FallbackStrategy,
  37      GatewayEndpointModelConfig,
  38      GatewayModelLinkageType,
  39      GatewayResourceType,
  40      InputTag,
  41      IssueSeverity,
  42      IssueStatus,
  43      Metric,
  44      Param,
  45      RoutingStrategy,
  46      RunInputs,
  47      RunTag,
  48      Span,
  49      SpanEvent,
  50      SpanStatusCode,
  51      ViewType,
  52  )
  53  from mlflow.entities.logged_model_input import LoggedModelInput
  54  from mlflow.entities.logged_model_output import LoggedModelOutput
  55  from mlflow.entities.logged_model_status import LoggedModelStatus
  56  from mlflow.entities.span import SpanAttributeKey
  57  from mlflow.entities.trace_data import TraceData
  58  from mlflow.entities.trace_info import TraceInfo
  59  from mlflow.entities.trace_location import TraceLocation
  60  from mlflow.entities.trace_metrics import (
  61      AggregationType,
  62      MetricAggregation,
  63      MetricViewType,
  64  )
  65  from mlflow.entities.trace_state import TraceState
  66  from mlflow.entities.trace_status import TraceStatus
  67  from mlflow.environment_variables import (
  68      _MLFLOW_GO_STORE_TESTING,
  69      MLFLOW_SERVER_GRAPHQL_MAX_ALIASES,
  70      MLFLOW_SERVER_GRAPHQL_MAX_ROOT_FIELDS,
  71      MLFLOW_SUPPRESS_PRINTING_URL_TO_STDOUT,
  72  )
  73  from mlflow.exceptions import MlflowException, RestException
  74  from mlflow.genai.datasets import (
  75      add_dataset_to_experiments,
  76      create_dataset,
  77      remove_dataset_from_experiments,
  78  )
  79  from mlflow.models import Model
  80  from mlflow.protos.databricks_pb2 import RESOURCE_DOES_NOT_EXIST, ErrorCode
  81  from mlflow.server import handlers
  82  from mlflow.server.fastapi_app import app
  83  from mlflow.server.handlers import initialize_backend_stores
  84  from mlflow.store.tracking.sqlalchemy_store import SqlAlchemyStore
  85  from mlflow.tracing.analysis import TraceFilterCorrelationResult
  86  from mlflow.tracing.client import TracingClient
  87  from mlflow.tracing.constant import (
  88      TRACE_SCHEMA_VERSION_KEY,
  89      TraceMetricDimensionKey,
  90      TraceMetricKey,
  91  )
  92  from mlflow.tracing.utils import build_otel_context
  93  from mlflow.utils import mlflow_tags
  94  from mlflow.utils.file_utils import TempDir, path_to_local_file_uri
  95  from mlflow.utils.mlflow_tags import (
  96      MLFLOW_DATASET_CONTEXT,
  97      MLFLOW_GIT_COMMIT,
  98      MLFLOW_PARENT_RUN_ID,
  99      MLFLOW_PROJECT_ENTRY_POINT,
 100      MLFLOW_SOURCE_NAME,
 101      MLFLOW_SOURCE_TYPE,
 102      MLFLOW_USER,
 103  )
 104  from mlflow.utils.os import is_windows
 105  from mlflow.utils.proto_json_utils import message_to_json
 106  from mlflow.utils.time import get_current_time_millis
 107  
 108  from tests.helper_functions import get_safe_port
 109  from tests.integration.utils import invoke_cli_runner
 110  from tests.tracking.integration_test_utils import (
 111      ServerThread,
 112      _init_server,
 113      _send_rest_tracking_post_request,
 114  )
 115  
# Module-level logger for these tests (e.g. used to report created run ids).
_logger = logging.getLogger(__name__)
 117  
 118  
 119  @pytest.fixture(params=["file", "sqlalchemy"])
 120  def store_type(request):
 121      """Provides the store type for parameterized tests."""
 122      if request.param == "file":
 123          pytest.skip("FileStore is no longer supported.")
 124      return request.param
 125  
 126  
 127  @pytest.fixture
 128  def mlflow_client(store_type: str, tmp_path: Path, db_uri: str, monkeypatch):
 129      """Provides an MLflow Tracking API client pointed at the local tracking server."""
 130      # Set passphrase for secrets management (required for encryption)
 131      monkeypatch.setenv(
 132          "MLFLOW_CRYPTO_KEK_PASSPHRASE", "test-passphrase-at-least-32-characters-long"
 133      )
 134  
 135      if store_type == "file":
 136          backend_uri = tmp_path.joinpath("file").as_uri()
 137      elif store_type == "sqlalchemy":
 138          backend_uri = db_uri
 139  
 140      # Force-reset backend stores before each test.
 141      handlers._tracking_store = None
 142      handlers._model_registry_store = None
 143      initialize_backend_stores(backend_uri, default_artifact_root=tmp_path.as_uri())
 144  
 145      with ServerThread(app, get_safe_port()) as url:
 146          yield MlflowClient(url)
 147  
 148  
 149  @pytest.fixture
 150  def mlflow_client_with_secrets(tmp_path: Path, monkeypatch):
 151      """Provides an MLflow Tracking API client with fresh database for secrets management.
 152  
 153      Creates a fresh SQLite database for each test to avoid encryption state pollution.
 154      This is necessary because the KEK encryption state can persist across tests when
 155      using a shared cached database.
 156      """
 157      # Set passphrase for secrets management (required for encryption)
 158      monkeypatch.setenv(
 159          "MLFLOW_CRYPTO_KEK_PASSPHRASE", "test-passphrase-at-least-32-characters-long"
 160      )
 161  
 162      # Create fresh database for this test (not using cached_db)
 163      backend_uri = f"sqlite:///{tmp_path}/mlflow.db"
 164      artifact_uri = (tmp_path / "artifacts").as_uri()
 165  
 166      # Initialize the store (which creates tables)
 167      store = SqlAlchemyStore(backend_uri, artifact_uri)
 168      store.engine.dispose()
 169  
 170      # Force-reset backend stores before each test
 171      handlers._tracking_store = None
 172      handlers._model_registry_store = None
 173      initialize_backend_stores(backend_uri, default_artifact_root=artifact_uri)
 174  
 175      with ServerThread(app, get_safe_port()) as url:
 176          yield MlflowClient(url)
 177  
 178  
 179  @pytest.fixture
 180  def cli_env(mlflow_client):
 181      """Provides an environment for the MLflow CLI pointed at the local tracking server."""
 182      return {
 183          "LC_ALL": "en_US.UTF-8",
 184          "LANG": "en_US.UTF-8",
 185          "MLFLOW_TRACKING_URI": mlflow_client.tracking_uri,
 186      }
 187  
 188  
 189  def create_experiments(client, names):
 190      return [client.create_experiment(n) for n in names]
 191  
 192  
def test_create_get_search_experiment(mlflow_client):
    """End-to-end check of experiment create/get/search/delete over REST."""
    experiment_id = mlflow_client.create_experiment(
        "My Experiment",
        artifact_location="my_location",
        tags={"key1": "val1", "key2": "val2"},
    )
    exp = mlflow_client.get_experiment(experiment_id)
    assert exp.name == "My Experiment"
    # A relative artifact_location is resolved against the server's CWD; on
    # Windows it is normalized to a file URI, elsewhere kept as a plain path.
    if is_windows():
        assert exp.artifact_location == pathlib.Path.cwd().joinpath("my_location").as_uri()
    else:
        assert exp.artifact_location == str(pathlib.Path.cwd().joinpath("my_location"))
    assert len(exp.tags) == 2
    assert exp.tags["key1"] == "val1"
    assert exp.tags["key2"] == "val2"

    # Search should see the new experiment alongside the implicit "Default" one.
    experiments = mlflow_client.search_experiments()
    assert {e.name for e in experiments} == {"My Experiment", "Default"}
    mlflow_client.delete_experiment(experiment_id)
    # After deletion, visibility depends on the requested view type.
    assert {e.name for e in mlflow_client.search_experiments()} == {"Default"}
    assert {e.name for e in mlflow_client.search_experiments(view_type=ViewType.ACTIVE_ONLY)} == {
        "Default"
    }
    assert {e.name for e in mlflow_client.search_experiments(view_type=ViewType.DELETED_ONLY)} == {
        "My Experiment"
    }
    assert {e.name for e in mlflow_client.search_experiments(view_type=ViewType.ALL)} == {
        "My Experiment",
        "Default",
    }
    # With a single active experiment, one page holds everything and no
    # continuation token is returned.
    active_exps_paginated = mlflow_client.search_experiments(max_results=1)
    assert {e.name for e in active_exps_paginated} == {"Default"}
    assert active_exps_paginated.token is None

    # Paginating through ALL experiments one at a time must yield both
    # experiments across the two pages, with no overlap.
    all_exps_paginated = mlflow_client.search_experiments(max_results=1, view_type=ViewType.ALL)
    first_page_names = {e.name for e in all_exps_paginated}
    all_exps_second_page = mlflow_client.search_experiments(
        max_results=1, view_type=ViewType.ALL, page_token=all_exps_paginated.token
    )
    second_page_names = {e.name for e in all_exps_second_page}
    assert len(first_page_names) == 1
    assert len(second_page_names) == 1
    assert first_page_names.union(second_page_names) == {"Default", "My Experiment"}
 236  
 237  
 238  def test_create_experiment_validation(mlflow_client):
 239      def assert_bad_request(payload, expected_error_message):
 240          response = _send_rest_tracking_post_request(
 241              mlflow_client.tracking_uri,
 242              "/api/2.0/mlflow/experiments/create",
 243              payload,
 244          )
 245          assert response.status_code == 400
 246          assert expected_error_message in response.text
 247  
 248      assert_bad_request(
 249          {
 250              "name": 123,
 251          },
 252          "Invalid value 123 for parameter 'name'",
 253      )
 254      assert_bad_request({}, "Missing value for required parameter 'name'.")
 255      assert_bad_request(
 256          {
 257              "name": "experiment name",
 258              "artifact_location": 9.0,
 259              "tags": [{"key": "key", "value": "value"}],
 260          },
 261          "Invalid value 9.0 for parameter 'artifact_location'",
 262      )
 263      assert_bad_request(
 264          {
 265              "name": "experiment name",
 266              "artifact_location": "my_location",
 267              "tags": "5",
 268          },
 269          "Invalid value \\\"5\\\" for parameter 'tags'",
 270      )
 271  
 272  
 273  def test_delete_restore_experiment(mlflow_client):
 274      experiment_id = mlflow_client.create_experiment("Deleterious")
 275      assert mlflow_client.get_experiment(experiment_id).lifecycle_stage == "active"
 276      mlflow_client.delete_experiment(experiment_id)
 277      assert mlflow_client.get_experiment(experiment_id).lifecycle_stage == "deleted"
 278      mlflow_client.restore_experiment(experiment_id)
 279      assert mlflow_client.get_experiment(experiment_id).lifecycle_stage == "active"
 280  
 281  
 282  def test_delete_restore_experiment_cli(mlflow_client, cli_env):
 283      experiment_name = "DeleteriousCLI"
 284      invoke_cli_runner(
 285          mlflow.experiments.commands,
 286          ["create", "--experiment-name", experiment_name],
 287          env=cli_env,
 288      )
 289      experiment_id = mlflow_client.get_experiment_by_name(experiment_name).experiment_id
 290      assert mlflow_client.get_experiment(experiment_id).lifecycle_stage == "active"
 291      invoke_cli_runner(
 292          mlflow.experiments.commands, ["delete", "-x", str(experiment_id)], env=cli_env
 293      )
 294      assert mlflow_client.get_experiment(experiment_id).lifecycle_stage == "deleted"
 295      invoke_cli_runner(
 296          mlflow.experiments.commands, ["restore", "-x", str(experiment_id)], env=cli_env
 297      )
 298      assert mlflow_client.get_experiment(experiment_id).lifecycle_stage == "active"
 299  
 300  
 301  def test_rename_experiment(mlflow_client):
 302      experiment_id = mlflow_client.create_experiment("BadName")
 303      assert mlflow_client.get_experiment(experiment_id).name == "BadName"
 304      mlflow_client.rename_experiment(experiment_id, "GoodName")
 305      assert mlflow_client.get_experiment(experiment_id).name == "GoodName"
 306  
 307  
 308  def test_rename_experiment_cli(mlflow_client, cli_env):
 309      bad_experiment_name = "CLIBadName"
 310      good_experiment_name = "CLIGoodName"
 311  
 312      invoke_cli_runner(
 313          mlflow.experiments.commands, ["create", "-n", bad_experiment_name], env=cli_env
 314      )
 315      experiment_id = mlflow_client.get_experiment_by_name(bad_experiment_name).experiment_id
 316      assert mlflow_client.get_experiment(experiment_id).name == bad_experiment_name
 317      invoke_cli_runner(
 318          mlflow.experiments.commands,
 319          [
 320              "rename",
 321              "--experiment-id",
 322              str(experiment_id),
 323              "--new-name",
 324              good_experiment_name,
 325          ],
 326          env=cli_env,
 327      )
 328      assert mlflow_client.get_experiment(experiment_id).name == good_experiment_name
 329  
 330  
 331  @pytest.mark.parametrize("parent_run_id_kwarg", [None, "my-parent-id"])
 332  def test_create_run_all_args(mlflow_client, parent_run_id_kwarg):
 333      user = "username"
 334      source_name = "Hello"
 335      entry_point = "entry"
 336      source_version = "abc"
 337      create_run_kwargs = {
 338          "start_time": 456,
 339          "run_name": "my name",
 340          "tags": {
 341              MLFLOW_USER: user,
 342              MLFLOW_SOURCE_TYPE: "LOCAL",
 343              MLFLOW_SOURCE_NAME: source_name,
 344              MLFLOW_PROJECT_ENTRY_POINT: entry_point,
 345              MLFLOW_GIT_COMMIT: source_version,
 346              MLFLOW_PARENT_RUN_ID: "7",
 347              "my": "tag",
 348              "other": "tag",
 349          },
 350      }
 351      experiment_id = mlflow_client.create_experiment(
 352          f"Run A Lot (parent_run_id={parent_run_id_kwarg})"
 353      )
 354      created_run = mlflow_client.create_run(experiment_id, **create_run_kwargs)
 355      run_id = created_run.info.run_id
 356      _logger.info(f"Run id={run_id}")
 357      fetched_run = mlflow_client.get_run(run_id)
 358      for run in [created_run, fetched_run]:
 359          assert run.info.run_id == run_id
 360          assert run.info.experiment_id == experiment_id
 361          assert run.info.user_id == user
 362          assert run.info.start_time == create_run_kwargs["start_time"]
 363          assert run.info.run_name == "my name"
 364          for tag in create_run_kwargs["tags"]:
 365              assert tag in run.data.tags
 366          assert run.data.tags.get(MLFLOW_USER) == user
 367          assert run.data.tags.get(MLFLOW_PARENT_RUN_ID) == parent_run_id_kwarg or "7"
 368          assert [run.info for run in mlflow_client.search_runs([experiment_id])] == [run.info]
 369  
 370  
 371  def test_create_run_defaults(mlflow_client):
 372      experiment_id = mlflow_client.create_experiment("Run A Little")
 373      created_run = mlflow_client.create_run(experiment_id)
 374      run_id = created_run.info.run_id
 375      run = mlflow_client.get_run(run_id)
 376      assert run.info.run_id == run_id
 377      assert run.info.experiment_id == experiment_id
 378      assert run.info.user_id == "unknown"
 379  
 380  
def test_log_metrics_params_tags(mlflow_client):
    """Log metrics (including non-finite values), a param, and a tag; verify all round-trip."""
    experiment_id = mlflow_client.create_experiment("Oh My")
    created_run = mlflow_client.create_run(experiment_id)
    run_id = created_run.info.run_id
    mlflow_client.log_metric(run_id, key="metric", value=123.456, timestamp=789, step=2)
    # Non-finite floats must survive REST serialization.
    mlflow_client.log_metric(run_id, key="nan_metric", value=float("nan"))
    mlflow_client.log_metric(run_id, key="inf_metric", value=float("inf"))
    mlflow_client.log_metric(run_id, key="-inf_metric", value=-float("inf"))
    # Omitting `step` should default it to 0 on the server side.
    mlflow_client.log_metric(run_id, key="stepless-metric", value=987.654, timestamp=321)
    mlflow_client.log_param(run_id, "param", "value")
    mlflow_client.set_tag(run_id, "taggity", "do-dah")
    run = mlflow_client.get_run(run_id)
    assert run.data.metrics.get("metric") == 123.456
    assert math.isnan(run.data.metrics.get("nan_metric"))
    # +/-inf may come back clamped to the float64 extremes, so compare with >=/<=.
    assert run.data.metrics.get("inf_metric") >= 1.7976931348623157e308
    assert run.data.metrics.get("-inf_metric") <= -1.7976931348623157e308
    assert run.data.metrics.get("stepless-metric") == 987.654
    assert run.data.params.get("param") == "value"
    assert run.data.tags.get("taggity") == "do-dah"
    # Full metric history preserves key, value, timestamp, and step.
    metric_history0 = mlflow_client.get_metric_history(run_id, "metric")
    assert len(metric_history0) == 1
    metric0 = metric_history0[0]
    assert metric0.key == "metric"
    assert metric0.value == 123.456
    assert metric0.timestamp == 789
    assert metric0.step == 2
    metric_history1 = mlflow_client.get_metric_history(run_id, "stepless-metric")
    assert len(metric_history1) == 1
    metric1 = metric_history1[0]
    assert metric1.key == "stepless-metric"
    assert metric1.value == 987.654
    assert metric1.timestamp == 321
    assert metric1.step == 0

    # History for a metric that was never logged is an empty list, not an error.
    metric_history = mlflow_client.get_metric_history(run_id, "a_test_accuracy")
    assert metric_history == []
 417  
 418  
 419  def test_log_metric_validation(mlflow_client):
 420      experiment_id = mlflow_client.create_experiment("metrics validation")
 421      created_run = mlflow_client.create_run(experiment_id)
 422      run_id = created_run.info.run_id
 423  
 424      def assert_bad_request(payload, expected_error_message):
 425          response = _send_rest_tracking_post_request(
 426              mlflow_client.tracking_uri,
 427              "/api/2.0/mlflow/runs/log-metric",
 428              payload,
 429          )
 430          assert response.status_code == 400
 431          assert expected_error_message in response.text
 432  
 433      assert_bad_request(
 434          {
 435              "run_id": 31,
 436              "key": "metric",
 437              "value": 41,
 438              "timestamp": 59,
 439              "step": 26,
 440          },
 441          "Invalid value 31 for parameter 'run_id' supplied",
 442      )
 443      assert_bad_request(
 444          {
 445              "run_id": run_id,
 446              "key": 31,
 447              "value": 41,
 448              "timestamp": 59,
 449              "step": 26,
 450          },
 451          "Invalid value 31 for parameter 'key' supplied",
 452      )
 453      assert_bad_request(
 454          {
 455              "run_id": run_id,
 456              "key": "foo",
 457              "value": 31,
 458              "timestamp": 59,
 459              "step": "foo",
 460          },
 461          "Invalid value \\\"foo\\\" for parameter 'step' supplied",
 462      )
 463      assert_bad_request(
 464          {
 465              "run_id": run_id,
 466              "key": "foo",
 467              "value": 31,
 468              "timestamp": "foo",
 469              "step": 41,
 470          },
 471          "Invalid value \\\"foo\\\" for parameter 'timestamp' supplied",
 472      )
 473      assert_bad_request(
 474          {
 475              "run_id": None,
 476              "key": "foo",
 477              "value": 31,
 478              "timestamp": 59,
 479              "step": 41,
 480          },
 481          "Missing value for required parameter 'run_id'",
 482      )
 483      assert_bad_request(
 484          {
 485              "run_id": run_id,
 486              # Missing key
 487              "value": 31,
 488              "timestamp": 59,
 489              "step": 41,
 490          },
 491          "Missing value for required parameter 'key'",
 492      )
 493      assert_bad_request(
 494          {
 495              "run_id": run_id,
 496              "key": None,
 497              "value": 31,
 498              "timestamp": 59,
 499              "step": 41,
 500          },
 501          "Missing value for required parameter 'key'",
 502      )
 503  
 504  
 505  def test_log_metric_model(mlflow_client: MlflowClient):
 506      experiment_id = mlflow_client.create_experiment("metrics validation")
 507      run = mlflow_client.create_run(experiment_id)
 508      model = mlflow_client.create_logged_model(experiment_id)
 509      mlflow_client.log_metric(
 510          run.info.run_id,
 511          key="metric",
 512          value=0.5,
 513          timestamp=123456789,
 514          step=1,
 515          dataset_name="name",
 516          dataset_digest="digest",
 517          model_id=model.model_id,
 518      )
 519  
 520      model = mlflow_client.get_logged_model(model.model_id)
 521      assert model.metrics == [
 522          Metric(
 523              key="metric",
 524              value=0.5,
 525              timestamp=123456789,
 526              step=1,
 527              model_id=model.model_id,
 528              dataset_name="name",
 529              dataset_digest="digest",
 530              run_id=run.info.run_id,
 531          )
 532      ]
 533  
 534  
 535  def test_log_param_validation(mlflow_client):
 536      experiment_id = mlflow_client.create_experiment("params validation")
 537      created_run = mlflow_client.create_run(experiment_id)
 538      run_id = created_run.info.run_id
 539  
 540      def assert_bad_request(payload, expected_error_message):
 541          response = _send_rest_tracking_post_request(
 542              mlflow_client.tracking_uri,
 543              "/api/2.0/mlflow/runs/log-parameter",
 544              payload,
 545          )
 546          assert response.status_code == 400
 547          assert expected_error_message in response.text
 548  
 549      assert_bad_request(
 550          {
 551              "run_id": 31,
 552              "key": "param",
 553              "value": 41,
 554          },
 555          "Invalid value 31 for parameter 'run_id' supplied",
 556      )
 557      assert_bad_request(
 558          {
 559              "run_id": run_id,
 560              "key": 31,
 561              "value": 41,
 562          },
 563          "Invalid value 31 for parameter 'key' supplied",
 564      )
 565  
 566  
 567  def test_log_param_with_empty_string_as_value(mlflow_client):
 568      experiment_id = mlflow_client.create_experiment(
 569          test_log_param_with_empty_string_as_value.__name__
 570      )
 571      created_run = mlflow_client.create_run(experiment_id)
 572      run_id = created_run.info.run_id
 573  
 574      mlflow_client.log_param(run_id, "param_key", "")
 575      assert {"param_key": ""}.items() <= mlflow_client.get_run(run_id).data.params.items()
 576  
 577  
 578  def test_set_tag_with_empty_string_as_value(mlflow_client):
 579      experiment_id = mlflow_client.create_experiment(
 580          test_set_tag_with_empty_string_as_value.__name__
 581      )
 582      created_run = mlflow_client.create_run(experiment_id)
 583      run_id = created_run.info.run_id
 584  
 585      mlflow_client.set_tag(run_id, "tag_key", "")
 586      assert {"tag_key": ""}.items() <= mlflow_client.get_run(run_id).data.tags.items()
 587  
 588  
 589  def test_log_batch_containing_params_and_tags_with_empty_string_values(mlflow_client):
 590      experiment_id = mlflow_client.create_experiment(
 591          test_log_batch_containing_params_and_tags_with_empty_string_values.__name__
 592      )
 593      created_run = mlflow_client.create_run(experiment_id)
 594      run_id = created_run.info.run_id
 595  
 596      mlflow_client.log_batch(
 597          run_id=run_id,
 598          params=[Param("param_key", "")],
 599          tags=[RunTag("tag_key", "")],
 600      )
 601      assert {"param_key": ""}.items() <= mlflow_client.get_run(run_id).data.params.items()
 602      assert {"tag_key": ""}.items() <= mlflow_client.get_run(run_id).data.tags.items()
 603  
 604  
 605  def test_set_tag_validation(mlflow_client):
 606      experiment_id = mlflow_client.create_experiment("tags validation")
 607      created_run = mlflow_client.create_run(experiment_id)
 608      run_id = created_run.info.run_id
 609  
 610      def assert_bad_request(payload, expected_error_message):
 611          response = _send_rest_tracking_post_request(
 612              mlflow_client.tracking_uri,
 613              "/api/2.0/mlflow/runs/set-tag",
 614              payload,
 615          )
 616          assert response.status_code == 400
 617          assert expected_error_message in response.text
 618  
 619      assert_bad_request(
 620          {
 621              "run_id": 31,
 622              "key": "tag",
 623              "value": 41,
 624          },
 625          "Invalid value 31 for parameter 'run_id' supplied",
 626      )
 627      assert_bad_request(
 628          {
 629              "run_id": run_id,
 630              "key": "param",
 631              "value": 41,
 632          },
 633          "Invalid value 41 for parameter 'value' supplied",
 634      )
 635      assert_bad_request(
 636          {
 637              "run_id": run_id,
 638              # Missing key
 639              "value": "value",
 640          },
 641          "Missing value for required parameter 'key'",
 642      )
 643  
 644      response = _send_rest_tracking_post_request(
 645          mlflow_client.tracking_uri,
 646          "/api/2.0/mlflow/runs/set-tag",
 647          {
 648              "run_uuid": run_id,
 649              "key": "key",
 650              "value": "value",
 651          },
 652      )
 653      assert response.status_code == 200
 654  
 655  
 656  def test_path_validation(mlflow_client):
 657      experiment_id = mlflow_client.create_experiment("tags validation")
 658      created_run = mlflow_client.create_run(experiment_id)
 659      run_id = created_run.info.run_id
 660      invalid_path = "../path"
 661  
 662      def assert_response(resp):
 663          assert resp.status_code == 400
 664          body = response.json()
 665          assert body["error_code"] == "INVALID_PARAMETER_VALUE"
 666          assert body["message"] == "Invalid path"
 667  
 668      response = requests.get(
 669          f"{mlflow_client.tracking_uri}/api/2.0/mlflow/artifacts/list",
 670          params={"run_id": run_id, "path": invalid_path},
 671      )
 672      assert_response(response)
 673  
 674      response = requests.get(
 675          f"{mlflow_client.tracking_uri}/get-artifact",
 676          params={"run_id": run_id, "path": invalid_path},
 677      )
 678      assert_response(response)
 679  
 680      response = requests.get(
 681          f"{mlflow_client.tracking_uri}//model-versions/get-artifact",
 682          params={"name": "model", "version": 1, "path": invalid_path},
 683      )
 684      assert_response(response)
 685  
 686  
 687  def test_set_experiment_tag(mlflow_client):
 688      experiment_id = mlflow_client.create_experiment("SetExperimentTagTest")
 689      mlflow_client.set_experiment_tag(experiment_id, "dataset", "imagenet1K")
 690      experiment = mlflow_client.get_experiment(experiment_id)
 691      assert "dataset" in experiment.tags
 692      assert experiment.tags["dataset"] == "imagenet1K"
 693      # test that updating a tag works
 694      mlflow_client.set_experiment_tag(experiment_id, "dataset", "birdbike")
 695      experiment = mlflow_client.get_experiment(experiment_id)
 696      assert "dataset" in experiment.tags
 697      assert experiment.tags["dataset"] == "birdbike"
 698      # test that setting a tag on 1 experiment does not impact another experiment.
 699      experiment_id_2 = mlflow_client.create_experiment("SetExperimentTagTest2")
 700      experiment2 = mlflow_client.get_experiment(experiment_id_2)
 701      assert len(experiment2.tags) == 0
 702      # test that setting a tag on different experiments maintain different values across experiments
 703      mlflow_client.set_experiment_tag(experiment_id_2, "dataset", "birds200")
 704      experiment = mlflow_client.get_experiment(experiment_id)
 705      experiment2 = mlflow_client.get_experiment(experiment_id_2)
 706      assert "dataset" in experiment.tags
 707      assert experiment.tags["dataset"] == "birdbike"
 708      assert "dataset" in experiment2.tags
 709      assert experiment2.tags["dataset"] == "birds200"
 710      # test can set multi-line tags
 711      mlflow_client.set_experiment_tag(experiment_id, "multiline tag", "value2\nvalue2\nvalue2")
 712      experiment = mlflow_client.get_experiment(experiment_id)
 713      assert "multiline tag" in experiment.tags
 714      assert experiment.tags["multiline tag"] == "value2\nvalue2\nvalue2"
 715  
 716  
 717  def test_set_experiment_tag_with_empty_string_as_value(mlflow_client):
 718      experiment_id = mlflow_client.create_experiment(
 719          test_set_experiment_tag_with_empty_string_as_value.__name__
 720      )
 721      mlflow_client.set_experiment_tag(experiment_id, "tag_key", "")
 722      assert {"tag_key": ""}.items() <= mlflow_client.get_experiment(experiment_id).tags.items()
 723  
 724  
 725  def test_delete_experiment_tag(mlflow_client):
 726      experiment_id = mlflow_client.create_experiment("DeleteExperimentTagTest")
 727      mlflow_client.set_experiment_tag(experiment_id, "dataset", "imagenet1K")
 728      experiment = mlflow_client.get_experiment(experiment_id)
 729      assert experiment.tags["dataset"] == "imagenet1K"
 730      # test that deleting a tag works
 731      mlflow_client.delete_experiment_tag(experiment_id, "dataset")
 732      experiment = mlflow_client.get_experiment(experiment_id)
 733      assert "dataset" not in experiment.tags
 734  
 735  
 736  def test_delete_tag(mlflow_client):
 737      experiment_id = mlflow_client.create_experiment("DeleteTagExperiment")
 738      created_run = mlflow_client.create_run(experiment_id)
 739      run_id = created_run.info.run_id
 740      mlflow_client.log_metric(run_id, key="metric", value=123.456, timestamp=789, step=2)
 741      mlflow_client.log_metric(run_id, key="stepless-metric", value=987.654, timestamp=321)
 742      mlflow_client.log_param(run_id, "param", "value")
 743      mlflow_client.set_tag(run_id, "taggity", "do-dah")
 744      run = mlflow_client.get_run(run_id)
 745      assert "taggity" in run.data.tags
 746      assert run.data.tags["taggity"] == "do-dah"
 747      mlflow_client.delete_tag(run_id, "taggity")
 748      run = mlflow_client.get_run(run_id)
 749      assert "taggity" not in run.data.tags
 750      with pytest.raises(MlflowException, match=r"Run .+ not found"):
 751          mlflow_client.delete_tag("fake_run_id", "taggity")
 752      with pytest.raises(MlflowException, match="No tag with name: fakeTag"):
 753          mlflow_client.delete_tag(run_id, "fakeTag")
 754      mlflow_client.delete_run(run_id)
 755      with pytest.raises(MlflowException, match=f"The run {run_id} must be in"):
 756          mlflow_client.delete_tag(run_id, "taggity")
 757  
 758  
 759  def test_log_batch(mlflow_client):
 760      experiment_id = mlflow_client.create_experiment("Batch em up")
 761      created_run = mlflow_client.create_run(experiment_id)
 762      run_id = created_run.info.run_id
 763      mlflow_client.log_batch(
 764          run_id=run_id,
 765          metrics=[Metric("metric", 123.456, 789, 3)],
 766          params=[Param("param", "value")],
 767          tags=[RunTag("taggity", "do-dah")],
 768      )
 769      run = mlflow_client.get_run(run_id)
 770      assert run.data.metrics.get("metric") == 123.456
 771      assert run.data.params.get("param") == "value"
 772      assert run.data.tags.get("taggity") == "do-dah"
 773      metric_history = mlflow_client.get_metric_history(run_id, "metric")
 774      assert len(metric_history) == 1
 775      metric = metric_history[0]
 776      assert metric.key == "metric"
 777      assert metric.value == 123.456
 778      assert metric.timestamp == 789
 779      assert metric.step == 3
 780  
 781  
 782  def test_log_batch_validation(mlflow_client):
 783      experiment_id = mlflow_client.create_experiment("log_batch validation")
 784      created_run = mlflow_client.create_run(experiment_id)
 785      run_id = created_run.info.run_id
 786  
 787      def assert_bad_request(payload, expected_error_message):
 788          response = _send_rest_tracking_post_request(
 789              mlflow_client.tracking_uri,
 790              "/api/2.0/mlflow/runs/log-batch",
 791              payload,
 792          )
 793          assert response.status_code == 400
 794          assert expected_error_message in response.text
 795  
 796      for request_parameter in ["metrics", "params", "tags"]:
 797          assert_bad_request(
 798              {
 799                  "run_id": run_id,
 800                  request_parameter: "foo",
 801              },
 802              f"Invalid value \\\"foo\\\" for parameter '{request_parameter}' supplied",
 803          )
 804  
 805      ## Should 400 if missing timestamp
 806      assert_bad_request(
 807          {"run_id": run_id, "metrics": [{"key": "mae", "value": 2.5}]},
 808          "Missing value for required parameter 'metrics[0].timestamp'",
 809      )
 810  
 811      ## Should 200 if timestamp provided but step is not
 812      response = _send_rest_tracking_post_request(
 813          mlflow_client.tracking_uri,
 814          "/api/2.0/mlflow/runs/log-batch",
 815          {
 816              "run_id": run_id,
 817              "metrics": [{"key": "mae", "value": 2.5, "timestamp": 123456789}],
 818          },
 819      )
 820  
 821      assert response.status_code == 200
 822  
 823  
@pytest.mark.xfail(reason="Tracking server does not support logged-model endpoints yet")
@pytest.mark.allow_infer_pip_requirements_fallback
def test_log_model(mlflow_client):
    # Verify that logging pyfunc models via the REST tracking server appends
    # one entry per model to the run's "mlflow.log-model.history" tag, and that
    # a locally re-saved copy of the same model yields matching metadata
    # (modulo the freshly generated model_uuid).
    experiment_id = mlflow_client.create_experiment("Log models")
    with TempDir(chdr=True):
        model_paths = [f"model/path/{i}" for i in range(3)]
        mlflow.set_tracking_uri(mlflow_client.tracking_uri)
        with mlflow.start_run(experiment_id=experiment_id) as run:
            for i, m in enumerate(model_paths):
                # Log through the server, then save an equivalent copy locally
                # so the two metadata dicts can be compared field-by-field.
                mlflow.pyfunc.log_model(name=m, loader_module="mlflow.pyfunc")
                mlflow.pyfunc.save_model(
                    m,
                    mlflow_model=Model(artifact_path=m, run_id=run.info.run_id),
                    loader_module="mlflow.pyfunc",
                )
                model = Model.load(os.path.join(m, "MLmodel"))
                # Re-fetch the run to pick up the updated history tag.
                run = mlflow.get_run(run.info.run_id)
                tag = run.data.tags["mlflow.log-model.history"]
                models = json.loads(tag)
                # Creation timestamps necessarily differ between the logged and
                # locally saved copies; copy the logged one over before comparing.
                model.utc_time_created = models[i]["utc_time_created"]

                history_model_meta = models[i].copy()
                original_model_uuid = history_model_meta.pop("model_uuid")
                model_meta = model.get_tags_dict().copy()
                new_model_uuid = model_meta.pop("model_uuid")
                assert history_model_meta == model_meta
                # Each save generates a distinct uuid.
                assert original_model_uuid != new_model_uuid
                # History grows by one entry per logged model, in logging order.
                assert len(models) == i + 1
                for j in range(0, i + 1):
                    assert models[j]["artifact_path"] == model_paths[j]
 854  
 855  
 856  def test_set_terminated_defaults(mlflow_client):
 857      experiment_id = mlflow_client.create_experiment("Terminator 1")
 858      created_run = mlflow_client.create_run(experiment_id)
 859      run_id = created_run.info.run_id
 860      assert mlflow_client.get_run(run_id).info.status == "RUNNING"
 861      assert mlflow_client.get_run(run_id).info.end_time is None
 862      mlflow_client.set_terminated(run_id)
 863      assert mlflow_client.get_run(run_id).info.status == "FINISHED"
 864      assert mlflow_client.get_run(run_id).info.end_time <= get_current_time_millis()
 865  
 866  
 867  def test_set_terminated_status(mlflow_client):
 868      experiment_id = mlflow_client.create_experiment("Terminator 2")
 869      created_run = mlflow_client.create_run(experiment_id)
 870      run_id = created_run.info.run_id
 871      assert mlflow_client.get_run(run_id).info.status == "RUNNING"
 872      assert mlflow_client.get_run(run_id).info.end_time is None
 873      mlflow_client.set_terminated(run_id, "FAILED")
 874      assert mlflow_client.get_run(run_id).info.status == "FAILED"
 875      assert mlflow_client.get_run(run_id).info.end_time <= get_current_time_millis()
 876  
 877  
 878  def test_artifacts(mlflow_client, tmp_path):
 879      experiment_id = mlflow_client.create_experiment("Art In Fact")
 880      experiment_info = mlflow_client.get_experiment(experiment_id)
 881      assert experiment_info.artifact_location.startswith(path_to_local_file_uri(str(tmp_path)))
 882      artifact_path = urllib.parse.urlparse(experiment_info.artifact_location).path
 883      assert posixpath.split(artifact_path)[-1] == experiment_id
 884  
 885      created_run = mlflow_client.create_run(experiment_id)
 886      assert created_run.info.artifact_uri.startswith(experiment_info.artifact_location)
 887      run_id = created_run.info.run_id
 888      src_dir = tmp_path.joinpath("test_artifacts_src")
 889      src_dir.mkdir()
 890      src_file = os.path.join(src_dir, "my.file")
 891      with open(src_file, "w") as f:
 892          f.write("Hello, World!")
 893      mlflow_client.log_artifact(run_id, src_file, None)
 894      mlflow_client.log_artifacts(run_id, src_dir, "dir")
 895  
 896      root_artifacts_list = mlflow_client.list_artifacts(run_id)
 897      assert {a.path for a in root_artifacts_list} == {"my.file", "dir"}
 898  
 899      dir_artifacts_list = mlflow_client.list_artifacts(run_id, "dir")
 900      assert {a.path for a in dir_artifacts_list} == {"dir/my.file"}
 901  
 902      all_artifacts = download_artifacts(
 903          run_id=run_id, artifact_path=".", tracking_uri=mlflow_client.tracking_uri
 904      )
 905      with open(f"{all_artifacts}/my.file") as f:
 906          assert f.read() == "Hello, World!"
 907      with open(f"{all_artifacts}/dir/my.file") as f:
 908          assert f.read() == "Hello, World!"
 909  
 910      dir_artifacts = download_artifacts(
 911          run_id=run_id, artifact_path="dir", tracking_uri=mlflow_client.tracking_uri
 912      )
 913      with open(f"{dir_artifacts}/my.file") as f:
 914          assert f.read() == "Hello, World!"
 915  
 916  
 917  def test_search_pagination(mlflow_client):
 918      experiment_id = mlflow_client.create_experiment("search_pagination")
 919      runs = [mlflow_client.create_run(experiment_id, start_time=1).info.run_id for _ in range(0, 10)]
 920      runs = sorted(runs)
 921      result = mlflow_client.search_runs([experiment_id], max_results=4, page_token=None)
 922      assert [r.info.run_id for r in result] == runs[0:4]
 923      assert result.token is not None
 924      result = mlflow_client.search_runs([experiment_id], max_results=4, page_token=result.token)
 925      assert [r.info.run_id for r in result] == runs[4:8]
 926      assert result.token is not None
 927      result = mlflow_client.search_runs([experiment_id], max_results=4, page_token=result.token)
 928      assert [r.info.run_id for r in result] == runs[8:]
 929      assert result.token is None
 930  
 931  
 932  def test_search_validation(mlflow_client):
 933      experiment_id = mlflow_client.create_experiment("search_validation")
 934      with pytest.raises(
 935          MlflowException,
 936          match=r"Invalid value 123456789 for parameter 'max_results' supplied",
 937      ):
 938          mlflow_client.search_runs([experiment_id], max_results=123456789)
 939  
 940  
 941  def test_get_experiment_by_name(mlflow_client):
 942      name = "test_get_experiment_by_name"
 943      experiment_id = mlflow_client.create_experiment(name)
 944      res = mlflow_client.get_experiment_by_name(name)
 945      assert res.experiment_id == experiment_id
 946      assert res.name == name
 947      assert mlflow_client.get_experiment_by_name("idontexist") is None
 948  
 949  
 950  def test_get_experiment(mlflow_client):
 951      name = "test_get_experiment"
 952      experiment_id = mlflow_client.create_experiment(name)
 953      res = mlflow_client.get_experiment(experiment_id)
 954      assert res.experiment_id == experiment_id
 955      assert res.name == name
 956  
 957  
def test_search_experiments(mlflow_client):
    """Cover search_experiments: filter_string (attribute/tag, =, !=, LIKE,
    ILIKE), order_by, max_results + page_token pagination, and view_type."""
    # To ensure the default experiment and non-default experiments have different creation_time
    # for deterministic search results, send a request to the server and initialize the tracking
    # store.
    assert mlflow_client.search_experiments()[0].name == "Default"

    experiments = [
        ("a", {"key": "value"}),
        ("ab", {"key": "vaLue"}),
        ("Abc", None),
    ]
    experiment_ids = []
    for name, tags in experiments:
        # sleep for windows file system current_time precision in Python to enforce
        # deterministic ordering based on last_update_time (creation_time due to no
        # mutation of experiment state)
        time.sleep(0.001)
        experiment_ids.append(mlflow_client.create_experiment(name, tags=tags))

    # filter_string
    experiments = mlflow_client.search_experiments(filter_string="attribute.name = 'a'")
    assert [e.name for e in experiments] == ["a"]
    experiments = mlflow_client.search_experiments(filter_string="attribute.name != 'a'")
    assert [e.name for e in experiments] == ["Abc", "ab", "Default"]
    # LIKE is case-sensitive: matches "ab"/"a" but not "Abc".
    experiments = mlflow_client.search_experiments(filter_string="name LIKE 'a%'")
    assert [e.name for e in experiments] == ["ab", "a"]
    experiments = mlflow_client.search_experiments(filter_string="tag.key = 'value'")
    assert [e.name for e in experiments] == ["a"]
    experiments = mlflow_client.search_experiments(filter_string="tag.key != 'value'")
    assert [e.name for e in experiments] == ["ab"]
    # ILIKE is case-insensitive: matches both "value" and "vaLue".
    experiments = mlflow_client.search_experiments(filter_string="tag.key ILIKE '%alu%'")
    assert [e.name for e in experiments] == ["ab", "a"]

    # order_by
    experiments = mlflow_client.search_experiments(order_by=["name DESC"])
    assert [e.name for e in experiments] == ["ab", "a", "Default", "Abc"]

    # max_results
    experiments = mlflow_client.search_experiments(max_results=2)
    assert [e.name for e in experiments] == ["Abc", "ab"]
    # page_token
    experiments = mlflow_client.search_experiments(page_token=experiments.token)
    assert [e.name for e in experiments] == ["a", "Default"]

    # view_type
    time.sleep(0.001)
    # Delete "ab" so the ACTIVE_ONLY / DELETED_ONLY / ALL views differ.
    mlflow_client.delete_experiment(experiment_ids[1])
    experiments = mlflow_client.search_experiments(view_type=ViewType.ACTIVE_ONLY)
    assert [e.name for e in experiments] == ["Abc", "a", "Default"]
    experiments = mlflow_client.search_experiments(view_type=ViewType.DELETED_ONLY)
    assert [e.name for e in experiments] == ["ab"]
    experiments = mlflow_client.search_experiments(view_type=ViewType.ALL)
    assert [e.name for e in experiments] == ["Abc", "ab", "a", "Default"]
1011  
1012  
def test_get_metric_history_bulk_rejects_invalid_requests(mlflow_client):
    """Malformed get-history-bulk requests return 400 with
    INVALID_PARAMETER_VALUE and a descriptive message."""
    url = f"{mlflow_client.tracking_uri}/ajax-api/2.0/mlflow/metrics/get-history-bulk"
    # (query params, fragment expected in the error message)
    cases = [
        (
            {"metric_key": "key"},
            "GetMetricHistoryBulk request must specify at least one run_id",
        ),
        (
            {"run_id": [], "metric_key": "key"},
            "GetMetricHistoryBulk request must specify at least one run_id",
        ),
        (
            {"run_id": [f"id_{i}" for i in range(1000)], "metric_key": "key"},
            "GetMetricHistoryBulk request cannot specify more than",
        ),
        (
            {"run_id": ["123"]},
            "GetMetricHistoryBulk request must specify a metric_key",
        ),
    ]
    for params, message_part in cases:
        resp = requests.get(url, params=params)
        assert resp.status_code == 400
        body = resp.json()
        assert body.get("error_code") == "INVALID_PARAMETER_VALUE"
        assert message_part in body.get("message", "")
1055  
1056  
def test_get_metric_history_bulk_returns_expected_metrics_in_expected_order(
    mlflow_client,
):
    """Verify get-history-bulk returns each run's full history for the
    requested metric key, annotated with run_id and grouped by run_id order.
    Runs 1 and 2 share the same histories (run 2's values offset by +1.0);
    run 3 has no metrics and should contribute nothing."""
    experiment_id = mlflow_client.create_experiment("get metric history bulk")
    created_run1 = mlflow_client.create_run(experiment_id)
    run_id1 = created_run1.info.run_id
    created_run2 = mlflow_client.create_run(experiment_id)
    run_id2 = created_run2.info.run_id
    created_run3 = mlflow_client.create_run(experiment_id)
    run_id3 = created_run3.info.run_id

    metricA_history = [
        {"key": "metricA", "timestamp": 1, "step": 2, "value": 10.0},
        {"key": "metricA", "timestamp": 1, "step": 3, "value": 11.0},
        {"key": "metricA", "timestamp": 1, "step": 3, "value": 12.0},
        {"key": "metricA", "timestamp": 2, "step": 3, "value": 12.0},
    ]
    for metric in metricA_history:
        mlflow_client.log_metric(run_id1, **metric)
        # Run 2 gets the same history with every value offset by +1.0.
        metric_for_run2 = dict(metric)
        metric_for_run2["value"] += 1.0
        mlflow_client.log_metric(run_id2, **metric_for_run2)

    metricB_history = [
        {"key": "metricB", "timestamp": 7, "step": -2, "value": -100.0},
        {"key": "metricB", "timestamp": 8, "step": 0, "value": 0.0},
        {"key": "metricB", "timestamp": 8, "step": 0, "value": 1.0},
        {"key": "metricB", "timestamp": 9, "step": 1, "value": 12.0},
    ]
    for metric in metricB_history:
        mlflow_client.log_metric(run_id1, **metric)
        metric_for_run2 = dict(metric)
        metric_for_run2["value"] += 1.0
        mlflow_client.log_metric(run_id2, **metric_for_run2)

    # Single run, single metric: history comes back annotated with run_id.
    response_run1_metricA = requests.get(
        f"{mlflow_client.tracking_uri}/ajax-api/2.0/mlflow/metrics/get-history-bulk",
        params={"run_id": [run_id1], "metric_key": "metricA"},
    )
    assert response_run1_metricA.status_code == 200
    assert response_run1_metricA.json().get("metrics") == [
        {**metric, "run_id": run_id1} for metric in metricA_history
    ]

    response_run2_metricB = requests.get(
        f"{mlflow_client.tracking_uri}/ajax-api/2.0/mlflow/metrics/get-history-bulk",
        params={"run_id": [run_id2], "metric_key": "metricB"},
    )
    assert response_run2_metricB.status_code == 200
    assert response_run2_metricB.json().get("metrics") == [
        {**metric, "run_id": run_id2, "value": metric["value"] + 1.0} for metric in metricB_history
    ]

    # Two runs: results are grouped by run_id (sorted), each run's history in order.
    response_run1_run2_metricA = requests.get(
        f"{mlflow_client.tracking_uri}/ajax-api/2.0/mlflow/metrics/get-history-bulk",
        params={"run_id": [run_id1, run_id2], "metric_key": "metricA"},
    )
    assert response_run1_run2_metricA.status_code == 200
    assert response_run1_run2_metricA.json().get("metrics") == sorted(
        [{**metric, "run_id": run_id1} for metric in metricA_history]
        + [
            {**metric, "run_id": run_id2, "value": metric["value"] + 1.0}
            for metric in metricA_history
        ],
        key=lambda metric: metric["run_id"],
    )

    # Including run 3, which has no metrics, does not change the result.
    response_run1_run2_run_3_metricB = requests.get(
        f"{mlflow_client.tracking_uri}/ajax-api/2.0/mlflow/metrics/get-history-bulk",
        params={"run_id": [run_id1, run_id2, run_id3], "metric_key": "metricB"},
    )
    assert response_run1_run2_run_3_metricB.status_code == 200
    assert response_run1_run2_run_3_metricB.json().get("metrics") == sorted(
        [{**metric, "run_id": run_id1} for metric in metricB_history]
        + [
            {**metric, "run_id": run_id2, "value": metric["value"] + 1.0}
            for metric in metricB_history
        ],
        key=lambda metric: metric["run_id"],
    )
1137  
1138  
def test_get_metric_history_bulk_respects_max_results(mlflow_client):
    """get-history-bulk truncates the returned history at max_results."""
    experiment_id = mlflow_client.create_experiment("get metric history bulk")
    run_id = mlflow_client.create_run(experiment_id).info.run_id
    max_results = 2

    history = [
        {"key": "metricA", "timestamp": 1, "step": 2, "value": 10.0},
        {"key": "metricA", "timestamp": 1, "step": 3, "value": 11.0},
        {"key": "metricA", "timestamp": 1, "step": 3, "value": 12.0},
        {"key": "metricA", "timestamp": 2, "step": 3, "value": 12.0},
    ]
    for entry in history:
        mlflow_client.log_metric(run_id, **entry)

    response = requests.get(
        f"{mlflow_client.tracking_uri}/ajax-api/2.0/mlflow/metrics/get-history-bulk",
        params={
            "run_id": [run_id],
            "metric_key": "metricA",
            "max_results": max_results,
        },
    )
    assert response.status_code == 200
    # Only the first `max_results` entries come back, annotated with the run id.
    expected = [dict(entry, run_id=run_id) for entry in history[:max_results]]
    assert response.json().get("metrics") == expected
1165  
1166  
def test_get_metric_history_bulk_calls_optimized_impl_when_expected(tmp_path):
    """The bulk handler should delegate to the store's optimized
    get_metric_history_bulk implementation with the default max_results."""
    from mlflow.server.handlers import get_metric_history_bulk_handler

    db_path = path_to_local_file_uri(str(tmp_path.joinpath("sqlalchemy.db")))
    # Windows file URIs keep the drive letter, hence the shorter scheme prefix.
    uri = ("sqlite://" if sys.platform == "win32" else "sqlite:////") + db_path[len("file://") :]
    store = mock.Mock(wraps=SqlAlchemyStore(uri, str(tmp_path)))

    flask_app = flask.Flask("test_flask_app")

    class FakeArgs:
        """Minimal stand-in for flask's request.args."""

        def __init__(self, values):
            self._values = values

        def to_dict(self, flat):
            return self._values

        def get(self, key, default=None):
            return self._values.get(key, default)

    with (
        mock.patch("mlflow.server.handlers._get_tracking_store", return_value=store),
        flask_app.test_request_context() as ctx,
    ):
        run_ids = [str(i) for i in range(10)]
        ctx.request.args = FakeArgs({"run_id": run_ids, "metric_key": "mock_key"})

        get_metric_history_bulk_handler()

        store.get_metric_history_bulk.assert_called_once_with(
            run_ids=run_ids,
            metric_key="mock_key",
            max_results=25000,
        )
1206  
1207  
def test_get_metric_history_respects_max_results(mlflow_client):
    """get-history honors max_results and returns the earliest entries."""
    experiment_id = mlflow_client.create_experiment("test max_results")
    run_id = mlflow_client.create_run(experiment_id).info.run_id

    for i in range(5):
        mlflow_client.log_metric(
            run_id, key="test_metric", value=float(i), step=i, timestamp=1000 + i
        )

    # Without max_results the full history is returned.
    assert len(mlflow_client.get_metric_history(run_id, "test_metric")) == 5

    # With max_results=3 only three entries come back.
    response = requests.get(
        f"{mlflow_client.tracking_uri}/ajax-api/2.0/mlflow/metrics/get-history",
        params={"run_id": run_id, "metric_key": "test_metric", "max_results": 3},
    )
    assert response.status_code == 200
    returned_metrics = response.json()["metrics"]
    assert len(returned_metrics) == 3

    for i, metric in enumerate(returned_metrics):
        assert metric["key"] == "test_metric"
        assert metric["value"] == float(i)
        # The Go store serializes int64 fields as strings; coerce before comparing.
        if _MLFLOW_GO_STORE_TESTING.get():
            assert int(metric["step"]) == i
        else:
            assert metric["step"] == i
1241  
1242  
def test_get_metric_history_with_page_token(mlflow_client):
    """Paginate metric history via page_token (pages of 4 + 4 + 2 entries),
    then verify an invalid token is rejected with a 400."""
    experiment_id = mlflow_client.create_experiment("test page_token")
    run_id = mlflow_client.create_run(experiment_id).info.run_id

    for i in range(10):
        mlflow_client.log_metric(
            run_id, key="test_metric", value=float(i), step=i, timestamp=1000 + i
        )

    endpoint = f"{mlflow_client.tracking_uri}/ajax-api/2.0/mlflow/metrics/get-history"
    page_size = 4

    def fetch_page(token):
        # Request one page of history; omit page_token entirely on the first page.
        params = {
            "run_id": run_id,
            "metric_key": "test_metric",
            "max_results": page_size,
        }
        if token is not None:
            params["page_token"] = token
        resp = requests.get(endpoint, params=params)
        assert resp.status_code == 200
        data = resp.json()
        return data["metrics"], data.get("next_page_token")

    first_metrics, first_token = fetch_page(None)
    assert first_token is not None
    assert len(first_metrics) == 4

    second_metrics, second_token = fetch_page(first_token)
    assert second_token is not None
    assert len(second_metrics) == 4

    # Final page holds the remaining two entries and no continuation token.
    third_metrics, third_token = fetch_page(second_token)
    assert third_token is None
    assert len(third_metrics) == 2

    all_paginated_metrics = first_metrics + second_metrics + third_metrics
    assert len(all_paginated_metrics) == 10

    for i, metric in enumerate(all_paginated_metrics):
        assert metric["key"] == "test_metric"
        assert metric["value"] == float(i)
        # The Go store serializes int64 fields as strings; coerce before comparing.
        if _MLFLOW_GO_STORE_TESTING.get():
            assert int(metric["step"]) == i
        else:
            assert metric["step"] == i
        if _MLFLOW_GO_STORE_TESTING.get():
            assert int(metric["timestamp"]) == 1000 + i
        else:
            assert metric["timestamp"] == 1000 + i

    # An unparseable page_token yields a 400 INVALID_PARAMETER_VALUE error.
    response = requests.get(
        endpoint,
        params={
            "run_id": run_id,
            "metric_key": "test_metric",
            "page_token": "invalid_token",
        },
    )
    assert response.status_code == 400
    assert "INVALID_PARAMETER_VALUE" in response.json().get("error_code", "")
1334  
1335  
def test_get_metric_history_bulk_interval_rejects_invalid_requests(mlflow_client):
    """Invalid get-history-bulk-interval requests return 400 with
    INVALID_PARAMETER_VALUE and a descriptive message."""
    url = f"{mlflow_client.tracking_uri}/ajax-api/2.0/mlflow/metrics/get-history-bulk-interval"
    # (query params, fragment expected in the error message)
    cases = [
        (
            {"metric_key": "key"},
            "Missing value for required parameter 'run_ids'.",
        ),
        (
            {"run_ids": [], "metric_key": "key"},
            "Missing value for required parameter 'run_ids'.",
        ),
        (
            {"run_ids": [f"id_{i}" for i in range(1000)], "metric_key": "key"},
            "GetMetricHistoryBulkInterval request must specify at most 100 run_ids.",
        ),
        (
            {"run_ids": ["123"], "metric_key": "key", "max_results": 0},
            "max_results must be between 1 and 2500",
        ),
        (
            {"run_ids": ["123"], "metric_key": ""},
            "Missing value for required parameter 'metric_key'",
        ),
        (
            {"run_ids": ["123"], "max_results": 5},
            "Missing value for required parameter 'metric_key'",
        ),
        (
            {
                "run_ids": ["123"],
                "metric_key": "key",
                "start_step": 1,
                "end_step": 0,
                "max_results": 5,
            },
            "end_step must be greater than start_step. ",
        ),
        (
            {
                "run_ids": ["123"],
                "metric_key": "key",
                "start_step": 1,
                "max_results": 5,
            },
            "If either start step or end step are specified, both must be specified.",
        ),
    ]
    for params, message_part in cases:
        resp = requests.get(url, params=params)
        assert resp.status_code == 400
        body = resp.json()
        assert body.get("error_code") == "INVALID_PARAMETER_VALUE"
        assert message_part in body.get("message", "")
1404  
1405  
def test_get_metric_history_bulk_interval_respects_max_results(mlflow_client):
    """Verify get-history-bulk-interval down-samples step ranges to honor
    max_results, for a single run, an explicit start/end step window,
    multiple runs, and duplicate steps logged at different timestamps."""
    experiment_id = mlflow_client.create_experiment("get metric history bulk")
    run_id1 = mlflow_client.create_run(experiment_id).info.run_id
    metric_history = [
        {"key": "metricA", "timestamp": 1, "step": i, "value": 10.0} for i in range(10)
    ]
    for metric in metric_history:
        mlflow_client.log_metric(run_id1, **metric)

    url = f"{mlflow_client.tracking_uri}/ajax-api/2.0/mlflow/metrics/get-history-bulk-interval"
    response_limited = requests.get(
        url,
        params={"run_ids": [run_id1], "metric_key": "metricA", "max_results": 5},
    )
    assert response_limited.status_code == 200
    # Steps 0..9 sampled down to ~5: every other step, plus the final step.
    expected_steps = [0, 2, 4, 6, 8, 9]
    expected_metrics = [
        {**metric, "run_id": run_id1}
        for metric in metric_history
        if metric["step"] in expected_steps
    ]
    assert response_limited.json().get("metrics") == expected_metrics

    # with start_step and end_step
    response_limited = requests.get(
        url,
        params={
            "run_ids": [run_id1],
            "metric_key": "metricA",
            "start_step": 0,
            "end_step": 4,
            "max_results": 5,
        },
    )
    assert response_limited.status_code == 200
    # The window [0, 4] holds exactly 5 steps, so no sampling is needed.
    assert response_limited.json().get("metrics") == [
        {**metric, "run_id": run_id1} for metric in metric_history[:5]
    ]

    # multiple runs
    run_id2 = mlflow_client.create_run(experiment_id).info.run_id
    metric_history2 = [
        {"key": "metricA", "timestamp": 1, "step": i, "value": 10.0} for i in range(20)
    ]
    for metric in metric_history2:
        mlflow_client.log_metric(run_id2, **metric)
    response_limited = requests.get(
        url,
        params={
            "run_ids": [run_id1, run_id2],
            "metric_key": "metricA",
            "max_results": 5,
        },
    )
    # Steps are sampled across the combined range 0..19; each run contributes
    # only the sampled steps it actually has.
    expected_steps = [0, 4, 8, 9, 12, 16, 19]
    expected_metrics = []
    # NOTE: the loop variable deliberately shadows `metric_history` above.
    for run_id, metric_history in [
        (run_id1, metric_history),
        (run_id2, metric_history2),
    ]:
        expected_metrics.extend([
            {**metric, "run_id": run_id}
            for metric in metric_history
            if metric["step"] in expected_steps
        ])
    assert response_limited.json().get("metrics") == expected_metrics

    # test metrics with same steps
    metric_history_timestamp2 = [
        {"key": "metricA", "timestamp": 2, "step": i, "value": 10.0} for i in range(10)
    ]
    for metric in metric_history_timestamp2:
        mlflow_client.log_metric(run_id1, **metric)

    response_limited = requests.get(
        url,
        params={"run_ids": [run_id1], "metric_key": "metricA", "max_results": 5},
    )
    assert response_limited.status_code == 200
    expected_steps = [0, 2, 4, 6, 8, 9]
    # Both entries for a sampled step are returned, ordered by timestamp.
    expected_metrics = [
        {"key": "metricA", "timestamp": j, "step": i, "value": 10.0, "run_id": run_id1}
        for i in expected_steps
        for j in [1, 2]
    ]
    assert response_limited.json().get("metrics") == expected_metrics
1492  
1493  
def test_search_dataset_handler_rejects_invalid_requests(mlflow_client):
    """Invalid search-datasets payloads must yield 400 INVALID_PARAMETER_VALUE."""
    endpoint = f"{mlflow_client.tracking_uri}/ajax-api/2.0/mlflow/experiments/search-datasets"

    def post_and_expect_error(payload, message_part):
        # Every invalid payload should be rejected with a 400 and a message
        # containing the expected fragment.
        resp = requests.post(endpoint, json=payload)
        assert resp.status_code == 400
        body = resp.json()
        assert body.get("error_code") == "INVALID_PARAMETER_VALUE"
        assert message_part in body.get("message", "")

    # experiment_ids field missing entirely
    post_and_expect_error(
        {},
        "SearchDatasets request must specify at least one experiment_id.",
    )
    # experiment_ids present but empty
    post_and_expect_error(
        {"experiment_ids": []},
        "SearchDatasets request must specify at least one experiment_id.",
    )
    # too many experiment ids in one request
    post_and_expect_error(
        {"experiment_ids": [f"id_{i}" for i in range(1000)]},
        "SearchDatasets request cannot specify more than",
    )
1528  
def test_search_dataset_handler_returns_expected_results(mlflow_client):
    """A logged dataset input is surfaced by the search-datasets endpoint."""
    experiment_id = mlflow_client.create_experiment("log inputs test")
    run_id = mlflow_client.create_run(experiment_id).info.run_id

    dataset = Dataset(
        name="name1",
        digest="digest1",
        source_type="source_type1",
        source="source1",
    )
    context_tags = [InputTag(key=MLFLOW_DATASET_CONTEXT, value="training")]
    mlflow_client.log_inputs(run_id, [DatasetInput(dataset=dataset, tags=context_tags)])

    response = requests.post(
        f"{mlflow_client.tracking_uri}/ajax-api/2.0/mlflow/experiments/search-datasets",
        json={"experiment_ids": [experiment_id]},
    )

    assert response.status_code == 200
    # The dataset-context tag value is reported as the summary's "context".
    assert response.json().get("dataset_summaries") == [
        {
            "experiment_id": experiment_id,
            "name": "name1",
            "digest": "digest1",
            "context": "training",
        }
    ]
1561  
1562  
def test_create_model_version_with_path_source(mlflow_client):
    """Plain local paths are accepted only inside the run's artifact directory."""
    name = "model"
    mlflow_client.create_registered_model(name)
    exp_id = mlflow_client.create_experiment("test")
    run = mlflow_client.create_run(experiment_id=exp_id)

    endpoint = f"{mlflow_client.tracking_uri}/api/2.0/mlflow/model-versions/create"
    local_artifact_path = run.info.artifact_uri[len("file://") :]

    # Path inside the run's artifact directory with a matching run_id: accepted.
    response = requests.post(
        endpoint,
        json={"name": name, "source": local_artifact_path, "run_id": run.info.run_id},
    )
    assert response.status_code == 200

    # Same path but no run_id supplied: rejected.
    response = requests.post(endpoint, json={"name": name, "source": local_artifact_path})
    assert response.status_code == 400
    assert "To use a local path as a model version" in response.json()["message"]

    # run_id supplied, but the path lies outside the run's artifact directory: rejected.
    response = requests.post(
        endpoint,
        json={"name": name, "source": "/tmp", "run_id": run.info.run_id},
    )
    assert response.status_code == 400
    assert "To use a local path as a model version" in response.json()["message"]
1602  
def test_create_model_version_with_non_local_source(mlflow_client):
    """Remote model-version sources are validated against path traversal.

    Well-formed remote URIs (including odd but safe path shapes) are accepted;
    any URI that could escape the artifact root via "..", percent-encoded
    "..", or embedded null bytes is rejected with a 400.
    """
    name = "model"
    mlflow_client.create_registered_model(name)
    exp_id = mlflow_client.create_experiment("test")
    run = mlflow_client.create_run(experiment_id=exp_id)
    endpoint = f"{mlflow_client.tracking_uri}/api/2.0/mlflow/model-versions/create"

    def create_model_version(source, **extra):
        # POST a model-versions/create request for `source` and return the response.
        return requests.post(endpoint, json={"name": name, "source": source, **extra})

    # Baseline: a local path inside the run's artifact directory is accepted.
    response = create_model_version(
        run.info.artifact_uri[len("file://") :], run_id=run.info.run_id
    )
    assert response.status_code == 200

    # Remote URIs with unusual but safe path shapes are all accepted.
    accepted_sources = [
        "mlflow-artifacts:/models",
        "mlflow-artifacts:/models/",  # single trailing slash
        "mlflow-artifacts:/models///",  # multiple trailing slashes
        "mlflow-artifacts:/models/foo///bar",  # repeated internal slashes
        "mlflow-artifacts://host:9000/models",
        "mlflow-artifacts://host:9000/models/artifact/..../",  # dots that are not ".."
    ]
    for source in accepted_sources:
        response = create_model_version(source, run_id=run.info.run_id)
        assert response.status_code == 200, source

    # Traversal / null-byte attempts in remote URIs are rejected.
    rejected_sources = [
        "mlflow-artifacts://host:9000/models/../../../",
        "http://host:9000/models/../../../",
        "https://host/api/2.0/mlflow-artifacts/artifacts/../../../",
        "s3a://my_bucket/api/2.0/mlflow-artifacts/artifacts/../../../",
        "ftp://host:8888/api/2.0/mlflow-artifacts/artifacts/../../../",
        "mlflow-artifacts://host:9000/models/..%2f..%2fartifacts",  # percent-encoded ".."
        "mlflow-artifacts://host:9000/models/artifact%00",  # embedded null byte
    ]
    for source in rejected_sources:
        response = create_model_version(source, run_id=run.info.run_id)
        assert response.status_code == 400, source
        assert "If supplying a source as an http, https," in response.json()["message"]

    # dbfs traversal is rejected with a different error message.
    response = create_model_version(
        f"dbfs:/{run.info.run_id}/artifacts/a%3f/../../../../../../../../../../",
        run_id=run.info.run_id,
    )
    assert response.status_code == 400
    assert "Invalid model version source" in response.json()["message"]

    # Logged-model sources: the model's artifact location and models:/ URI are
    # accepted when paired with the model_id...
    model = mlflow_client.create_logged_model(experiment_id=exp_id)
    response = create_model_version(model.artifact_location, model_id=model.model_id)
    assert response.status_code == 200
    response = create_model_version(model.model_uri, model_id=model.model_id)
    assert response.status_code == 200
    # ...but an unrelated local file URI with a model_id is rejected.
    response = create_model_version("file:///path/to/model", model_id=model.model_id)
    assert response.status_code == 400
1803  
1804  
def test_create_model_version_with_file_uri(mlflow_client):
    """file:// sources are accepted only within the run's artifact directory."""
    name = "test"
    mlflow_client.create_registered_model(name)
    exp_id = mlflow_client.create_experiment("test")
    run = mlflow_client.create_run(experiment_id=exp_id)
    assert run.info.artifact_uri.startswith("file://")

    endpoint = f"{mlflow_client.tracking_uri}/api/2.0/mlflow/model-versions/create"

    def create_model_version(payload):
        # POST a model-versions/create request with `name` plus `payload` fields.
        return requests.post(endpoint, json={"name": name, **payload})

    # file:// URIs that resolve inside the artifact directory (including "."
    # and "model/.." forms) are accepted when a run_id is supplied.
    for source in [
        run.info.artifact_uri,
        f"{run.info.artifact_uri}/model",
        f"{run.info.artifact_uri}/.",
        f"{run.info.artifact_uri}/model/..",
    ]:
        response = create_model_version({"source": source, "run_id": run.info.run_id})
        assert response.status_code == 200, source

    # Without a run_id, a local file URI is rejected.
    response = create_model_version({"source": run.info.artifact_uri})
    assert response.status_code == 400
    assert "To use a local path as a model version" in response.json()["message"]

    # A file URI outside any run's artifact directory (and with no run_id) is
    # rejected. (The original comment claimed run_id was specified here, but
    # the payload deliberately omits it.)
    response = create_model_version({"source": "file:///tmp"})
    assert response.status_code == 400
    assert "To use a local path as a model version" in response.json()["message"]

    # A file URI carrying a remote host component is not a valid remote URI.
    response = create_model_version(
        {"source": "file://123.456.789.123/path/to/source", "run_id": run.info.run_id}
    )
    assert response.status_code == 500, response.json()
    assert "is not a valid remote uri" in response.json()["message"]
1883  
1884  
def test_create_model_version_with_validation_regex(db_uri: str):
    """The server-side validation regex rejects non-matching model version sources.

    Starts a real MLflow server subprocess with
    MLFLOW_CREATE_MODEL_VERSION_SOURCE_VALIDATION_REGEX set so that only
    mlflow-artifacts:/ sources are allowed.
    """
    port = get_safe_port()
    with subprocess.Popen(
        [
            sys.executable,
            "-m",
            "mlflow",
            "server",
            "--port",
            str(port),
            "--backend-store-uri",
            db_uri,
        ],
        env=(
            os.environ.copy()
            | {
                "MLFLOW_CREATE_MODEL_VERSION_SOURCE_VALIDATION_REGEX": r"^mlflow-artifacts:/.*$",
                "MLFLOW_SERVER_ENABLE_JOB_EXECUTION": "false",
            }
        ),
    ) as proc:
        try:
            # Wait (up to ~10s) for the server to come up. Sleep after every
            # failed attempt — the original only slept on ConnectionError, so a
            # non-OK /health response would burn through all retries in a tight
            # loop without ever waiting.
            for _ in range(10):
                try:
                    if requests.get(f"http://localhost:{port}/health").ok:
                        break
                except requests.ConnectionError:
                    pass
                time.sleep(1)
            else:
                raise RuntimeError("Failed to connect to the MLflow server")

            # Test that the validation regex works as expected
            client = MlflowClient(f"http://localhost:{port}")
            name = "test"
            client.create_registered_model(name)
            # Invalid source: does not match the configured regex.
            with pytest.raises(MlflowException, match="Invalid model version source"):
                client.create_model_version(name, source="s3://path/to/model")
            # Valid source: run artifacts live under mlflow-artifacts:/.
            experiment_id = client.create_experiment("test")
            run = client.create_run(experiment_id=experiment_id)
            assert run.info.artifact_uri.startswith("mlflow-artifacts:/")
            client.create_model_version(
                name, source=f"{run.info.artifact_uri}/model", run_id=run.info.run_id
            )
        finally:
            # Always shut the server subprocess down, even on test failure.
            proc.terminate()
            proc.wait()
1934  
1935  
@pytest.mark.xfail(reason="Tracking server does not support logged-model endpoints yet")
def test_logging_model_with_local_artifact_uri(mlflow_client):
    """Log a registered sklearn model to a file://-backed run and load it back."""
    from sklearn.linear_model import LogisticRegression

    mlflow.set_tracking_uri(mlflow_client.tracking_uri)
    with mlflow.start_run() as run:
        # The fixture's backend stores artifacts on the local filesystem.
        assert run.info.artifact_uri.startswith("file://")
        mlflow.sklearn.log_model(
            LogisticRegression(), name="model", registered_model_name="rmn"
        )
        mlflow.pyfunc.load_model("models:/rmn/1")
1945  
1946  
def test_log_input(mlflow_client, tmp_path):
    """End-to-end: log_input records dataset metadata and tags on the run."""
    df = pd.DataFrame([[1, 2, 3], [1, 2, 3]], columns=["a", "b", "c"])
    path = tmp_path / "temp.csv"
    df.to_csv(path)
    dataset = from_pandas(df, source=path)

    mlflow.set_tracking_uri(mlflow_client.tracking_uri)

    with mlflow.start_run() as run:
        mlflow.log_input(dataset, "train", {"foo": "baz"})

    dataset_inputs = mlflow_client.get_run(run.info.run_id).inputs.dataset_inputs
    assert len(dataset_inputs) == 1
    logged = dataset_inputs[0]

    assert logged.dataset.name == "dataset"
    assert logged.dataset.digest == "f0f3e026"
    assert logged.dataset.source_type == "local"
    assert json.loads(logged.dataset.source) == {"uri": str(path)}
    # Schema is inferred from the DataFrame's integer columns.
    assert json.loads(logged.dataset.schema) == {
        "mlflow_colspec": [
            {"name": "a", "type": "long", "required": True},
            {"name": "b", "type": "long", "required": True},
            {"name": "c", "type": "long", "required": True},
        ]
    }
    assert json.loads(logged.dataset.profile) == {"num_rows": 2, "num_elements": 6}

    # Custom tags come first, followed by the reserved dataset-context tag.
    assert [(tag.key, tag.value) for tag in logged.tags] == [
        ("foo", "baz"),
        (mlflow_tags.MLFLOW_DATASET_CONTEXT, "train"),
    ]
1982  
1983  
def test_create_model_version_model_id(mlflow_client):
    """Each version created from a logged model is appended to its versions tag."""
    name = "model"
    mlflow_client.create_registered_model(name)
    exp_id = mlflow_client.create_experiment("test")
    model = mlflow_client.create_logged_model(experiment_id=exp_id)

    expected_versions = []
    for version in (1, 2):
        mlflow_client.create_model_version(
            name=name,
            source=model.artifact_location,
            model_id=model.model_id,
        )
        model = mlflow_client.get_logged_model(model.model_id)
        expected_versions.append({"name": "model", "version": version})
        # The tag holds a JSON list of every version created so far.
        assert model.tags["mlflow.modelVersions"] == json.dumps(expected_versions)
2006  
2007  
def test_log_inputs(mlflow_client):
    """log_inputs round-trips a dataset input together with its tags."""
    experiment_id = mlflow_client.create_experiment("log inputs test")
    run_id = mlflow_client.create_run(experiment_id).info.run_id

    dataset = Dataset(
        name="name1",
        digest="digest1",
        source_type="source_type1",
        source="source1",
    )
    mlflow_client.log_inputs(
        run_id,
        [DatasetInput(dataset=dataset, tags=[InputTag(key="tag1", value="value1")])],
    )

    run = mlflow_client.get_run(run_id)
    assert isinstance(run.inputs, RunInputs)
    assert len(run.inputs.dataset_inputs) == 1

    logged = run.inputs.dataset_inputs[0]
    assert isinstance(logged, DatasetInput)
    assert isinstance(logged.dataset, Dataset)
    assert logged.dataset.name == "name1"
    assert logged.dataset.digest == "digest1"
    assert logged.dataset.source_type == "source_type1"
    assert logged.dataset.source == "source1"
    assert len(logged.tags) == 1
    assert logged.tags[0].key == "tag1"
    assert logged.tags[0].value == "value1"
2035  
2036  
def test_log_inputs_validation(mlflow_client):
    """The log-inputs endpoint rejects payloads that omit the required run_id."""
    dataset = Dataset(
        name="name1",
        digest="digest1",
        source_type="source_type1",
        source="source1",
    )
    dataset_input = DatasetInput(
        dataset=dataset, tags=[InputTag(key="tag1", value="value1")]
    )
    # Serialize through the proto representation, as a real client would.
    payload = {"datasets": [json.loads(message_to_json(dataset_input.to_proto()))]}

    response = _send_rest_tracking_post_request(
        mlflow_client.tracking_uri,
        "/api/2.0/mlflow/runs/log-inputs",
        payload,
    )
    assert response.status_code == 400
    assert "Missing value for required parameter 'run_id'" in response.text
2063  
2064  
def test_log_inputs_model(mlflow_client):
    """Model inputs logged alongside datasets appear in the run's inputs."""
    experiment_id = mlflow_client.create_experiment("log inputs test")
    run = mlflow_client.create_run(experiment_id)
    model = mlflow_client.create_logged_model(experiment_id=experiment_id)

    training_dataset = DatasetInput(
        dataset=Dataset(
            name="name1",
            digest="digest1",
            source_type="source_type1",
            source="source1",
        ),
        tags=[InputTag(key=MLFLOW_DATASET_CONTEXT, value="training")],
    )
    mlflow_client.log_inputs(
        run.info.run_id,
        models=[LoggedModelInput(model_id=model.model_id)],
        datasets=[training_dataset],
    )

    run = mlflow_client.get_run(run.info.run_id)
    assert len(run.inputs.model_inputs) == 1
2088  
2089  
def test_update_run_name_without_changing_status(mlflow_client):
    """Renaming a finished run must not reset its terminal status."""
    experiment_id = mlflow_client.create_experiment("update run name")
    run_id = mlflow_client.create_run(experiment_id).info.run_id
    mlflow_client.set_terminated(run_id, "FINISHED")

    mlflow_client.update_run(run_id, name="name_abc")

    updated = mlflow_client.get_run(run_id).info
    assert updated.run_name == "name_abc"
    assert updated.status == "FINISHED"
2099  
2100  
def test_create_promptlab_run_handler_rejects_invalid_requests(mlflow_client):
    """Each missing required field in a create-promptlab-run payload yields a 400.

    Required fields are added one at a time; before each addition the server
    must reject the payload and name the first missing field.
    """
    endpoint = f"{mlflow_client.tracking_uri}/ajax-api/2.0/mlflow/runs/create-promptlab-run"

    def assert_missing_field_response(resp, field):
        # A payload missing `field` must be rejected with a targeted message.
        assert resp.status_code == 400
        body = resp.json()
        assert body.get("error_code") == "INVALID_PARAMETER_VALUE"
        assert f"CreatePromptlabRun request must specify {field}." in body.get("message", "")

    # Fields in the order the handler validates them, with the values the
    # original test supplied.
    required_fields = [
        ("experiment_id", "123"),
        ("prompt_template", "my_prompt_template"),
        ("prompt_parameters", [{"key": "my_key", "value": "my_value"}]),
        ("model_route", "my_route"),
        ("model_input", "my_input"),
        ("mlflow_version", "1.0.0"),
    ]
    payload = {}
    for field, value in required_fields:
        response = requests.post(endpoint, json=payload)
        assert_missing_field_response(response, field)
        payload[field] = value

    # With every required field present the payload must pass field validation.
    # (The original test sent this final request without checking the response
    # at all.) The request may still fail for unrelated reasons — experiment
    # "123" does not exist — so only assert there is no missing-field error.
    response = requests.post(endpoint, json=payload)
    assert "must specify" not in response.json().get("message", "")
2188  
2189  
def test_create_promptlab_run_handler_returns_expected_results(mlflow_client):
    """A fully-specified promptlab run is created with the expected params/tags."""
    experiment_id = mlflow_client.create_experiment("log inputs test")

    response = requests.post(
        f"{mlflow_client.tracking_uri}/ajax-api/2.0/mlflow/runs/create-promptlab-run",
        json={
            "experiment_id": experiment_id,
            "run_name": "my_run_name",
            "prompt_template": "my_prompt_template",
            "prompt_parameters": [{"key": "my_key", "value": "my_value"}],
            "model_route": "my_route",
            "model_parameters": [{"key": "temperature", "value": "0.1"}],
            "model_input": "my_input",
            "model_output": "my_output",
            "model_output_parameters": [{"key": "latency", "value": "100"}],
            "mlflow_version": "1.0.0",
            "user_id": "username",
            "start_time": 456,
        },
    )
    assert response.status_code == 200

    run = response.json()["run"]
    info = run["info"]
    data = run["data"]

    assert info["run_name"] == "my_run_name"
    assert info["experiment_id"] == experiment_id
    assert info["user_id"] == "username"
    assert info["status"] == "FINISHED"
    assert info["start_time"] == 456

    # Prompt-engineering metadata is stored as run params...
    for expected_param in [
        {"key": "model_route", "value": "my_route"},
        {"key": "prompt_template", "value": "my_prompt_template"},
        {"key": "temperature", "value": "0.1"},
    ]:
        assert expected_param in data["params"]

    # ...and the eval-results artifact plus run source type are recorded as tags.
    assert {
        "key": "mlflow.loggedArtifacts",
        "value": '[{"path": "eval_results_table.json", "type": "table"}]',
    } in data["tags"]
    assert {"key": "mlflow.runSourceType", "value": "PROMPT_ENGINEERING"} in data["tags"]
2231  
2232  
def test_gateway_proxy_handler_rejects_invalid_requests(mlflow_client):
    """Malformed gateway-proxy requests are rejected with INVALID_PARAMETER_VALUE."""
    # Spin up a flask server with a deployments target configured so the
    # gateway-proxy endpoint is active.
    with _init_server(
        backend_uri=mlflow_client.tracking_uri,
        root_artifact_uri=mlflow_client.tracking_uri,
        extra_env={"MLFLOW_DEPLOYMENTS_TARGET": "http://localhost:5001"},
        server_type="flask",
    ) as url:
        patched_client = MlflowClient(url)
        proxy_endpoint = f"{patched_client.tracking_uri}/ajax-api/2.0/mlflow/gateway-proxy"

        def assert_invalid(resp, message_part):
            assert resp.status_code == 400
            body = resp.json()
            assert body.get("error_code") == "INVALID_PARAMETER_VALUE"
            assert message_part in body.get("message", "")

        # Missing gateway_path
        assert_invalid(
            requests.post(proxy_endpoint, json={}),
            "Deployments proxy request must specify a gateway_path.",
        )
        # Unknown POST paths
        assert_invalid(
            requests.post(proxy_endpoint, json={"gateway_path": "foo/bar"}),
            "Invalid gateway_path: foo/bar for method: POST",
        )
        assert_invalid(
            requests.post(proxy_endpoint, json={"gateway_path": "foo/bar/baz"}),
            "Invalid gateway_path: foo/bar/baz for method: POST",
        )
        # Unknown GET path
        assert_invalid(
            requests.get(proxy_endpoint, params={"gateway_path": "hello/world"}),
            "Invalid gateway_path: hello/world for method: GET",
        )
        # DELETE is not a supported method at all
        assert requests.delete(proxy_endpoint).status_code == 405
2289  
2290  
def test_upload_artifact_handler_rejects_invalid_requests(mlflow_client):
    """upload-artifact rejects requests missing run_uuid/path/data or using unsafe paths."""

    def assert_response(resp, message_part):
        assert resp.status_code == 400
        body = resp.json()
        assert body.get("error_code") == "INVALID_PARAMETER_VALUE"
        assert message_part in body.get("message", "")

    experiment_id = mlflow_client.create_experiment("upload_artifacts_test")
    created_run = mlflow_client.create_run(experiment_id)

    endpoint = f"{mlflow_client.tracking_uri}/ajax-api/2.0/mlflow/upload-artifact"
    run_id = created_run.info.run_id

    # (query params, expected error fragment), ordered to walk the server's
    # validation chain: run_uuid -> path -> path safety -> data.
    cases = [
        ({}, "Request must specify run_uuid."),
        ({"run_uuid": run_id}, "Request must specify path."),
        ({"run_uuid": run_id, "path": ""}, "Request must specify path."),
        # Path traversal outside the artifact root must be rejected.
        ({"run_uuid": run_id, "path": "../test.txt"}, "Invalid path"),
        ({"run_uuid": run_id, "path": "test.txt"}, "Request must specify data."),
    ]
    for params, expected in cases:
        assert_response(requests.post(endpoint, params=params), expected)
2334  
2335  
def test_upload_artifact_handler(mlflow_client):
    """Round-trip: upload an artifact via the handler, then fetch it via get-artifact."""
    # Use a name unique to this test — the previous name ("upload_artifacts_test")
    # duplicated the one in test_upload_artifact_handler_rejects_invalid_requests,
    # which fails with "experiment already exists" if the backing store is ever
    # shared across tests (e.g. a broader-scoped fixture).
    experiment_id = mlflow_client.create_experiment("upload_artifact_handler_test")
    created_run = mlflow_client.create_run(experiment_id)

    response = requests.post(
        f"{mlflow_client.tracking_uri}/ajax-api/2.0/mlflow/upload-artifact",
        params={
            "run_uuid": created_run.info.run_id,
            "path": "test.txt",
        },
        data="hello world",
    )
    assert response.status_code == 200

    # The uploaded content should be served back verbatim by get-artifact.
    response = requests.get(
        f"{mlflow_client.tracking_uri}/get-artifact",
        params={
            "run_uuid": created_run.info.run_id,
            "path": "test.txt",
        },
    )
    assert response.status_code == 200
    assert response.text == "hello world"
2359  
2360  
def test_graphql_handler(mlflow_client):
    """The /graphql endpoint accepts a minimal well-formed query."""
    payload = {
        "query": 'query testQuery {test(inputString: "abc") { output }}',
        "operationName": "testQuery",
    }
    resp = requests.post(
        f"{mlflow_client.tracking_uri}/graphql",
        json=payload,
        headers={"content-type": "application/json; charset=utf-8"},
    )
    assert resp.status_code == 200
2371  
2372  
def test_graphql_handler_batching_raise_error(mlflow_client):
    """The GraphQL endpoint enforces anti-abuse limits on query shape.

    Four limits are probed, each with a query just past its threshold: root
    field count, alias count, nesting depth, and total selection count. In
    every case the server still responds HTTP 200 but reports the violation
    in the ``errors`` array.
    """
    # Test max root fields limit: MAX_ROOT_FIELDS + 2 aliased root fields.
    batch_query = (
        "query testQuery {"
        + " ".join([
            f"key_{i}: " + 'test(inputString: "abc") { output }'
            for i in range(int(MLFLOW_SERVER_GRAPHQL_MAX_ROOT_FIELDS.get()) + 2)
        ])
        + "}"
    )
    response = requests.post(
        f"{mlflow_client.tracking_uri}/graphql",
        json={
            "query": batch_query,
            "operationName": "testQuery",
        },
        headers={"content-type": "application/json; charset=utf-8"},
    )
    assert response.status_code == 200
    assert (
        f"GraphQL queries should have at most {MLFLOW_SERVER_GRAPHQL_MAX_ROOT_FIELDS.get()}"
        in response.json()["errors"][0]
    )

    # Test max aliases limit: MAX_ALIASES + 2 aliases on a nested field.
    batch_query = (
        'query testQuery {mlflowGetExperiment(input: {experimentId: "123"}) {'
        + " ".join(
            f"experiment_{i}: " + "experiment { name }"
            for i in range(int(MLFLOW_SERVER_GRAPHQL_MAX_ALIASES.get()) + 2)
        )
        + "}}"
    )
    response = requests.post(
        f"{mlflow_client.tracking_uri}/graphql",
        json={
            "query": batch_query,
            "operationName": "testQuery",
        },
    )
    assert response.status_code == 200
    assert (
        f"queries should have at most {MLFLOW_SERVER_GRAPHQL_MAX_ALIASES.get()} aliases"
        in response.json()["errors"][0]
    )

    # Test max depth limit: build a 12-level nested selection (limit is 10).
    inner = "name"
    for _ in range(12):
        inner = f"name {{ {inner} }}"
    deep_query = (
        'query testQuery { mlflowGetExperiment(input: {experimentId: "123"}) { experiment { '
        + inner
        + " } } }"
    )
    response = requests.post(
        f"{mlflow_client.tracking_uri}/graphql",
        json={
            "query": deep_query,
            "operationName": "testQuery",
        },
    )
    assert response.status_code == 200
    assert "Query exceeds maximum depth of 10" in response.json()["errors"][0]

    # Test max selections limit
    # Exceed the 1000 selection limit with 1002 sibling selections.
    selections = [f"field_{i} {{ name }}" for i in range(1002)]
    selections_query = (
        'query testQuery { mlflowGetExperiment(input: {experimentId: "123"}) { experiment { '
        + " ".join(selections)
        + " } } }"
    )
    response = requests.post(
        f"{mlflow_client.tracking_uri}/graphql",
        json={
            "query": selections_query,
            "operationName": "testQuery",
        },
    )
    assert response.status_code == 200
    assert "Query exceeds maximum total selections of 1000" in response.json()["errors"][0]
2455  
2456  
def test_get_experiment_graphql(mlflow_client):
    """mlflowGetExperiment resolves an experiment created via the REST client."""
    experiment_id = mlflow_client.create_experiment("GraphqlTest")
    response = requests.post(
        f"{mlflow_client.tracking_uri}/graphql",
        json={
            "query": 'query testQuery {mlflowGetExperiment(input: {experimentId: "'
            + experiment_id
            + '"}) { experiment { name } }}',
            "operationName": "testQuery",
        },
        headers={"content-type": "application/json; charset=utf-8"},
    )
    assert response.status_code == 200
    # Bind the parsed body to `response_json` rather than `json`, which
    # shadowed the module-level `json` import.
    response_json = response.json()
    assert response_json["data"]["mlflowGetExperiment"]["experiment"]["name"] == "GraphqlTest"
2472  
2473  
def test_get_run_and_experiment_graphql(mlflow_client):
    """mlflowGetRun resolves nested experiment and modelVersions fields for a run."""
    name = "GraphqlTest"
    mlflow_client.create_registered_model(name)
    experiment_id = mlflow_client.create_experiment(name)
    created_run = mlflow_client.create_run(experiment_id)
    run_id = created_run.info.run_id
    # Attach a model version to the run so the modelVersions resolver has data.
    mlflow_client.create_model_version("GraphqlTest", "runs:/graphql_test/model", run_id)
    response = requests.post(
        f"{mlflow_client.tracking_uri}/graphql",
        json={
            "query": f"""
                query testQuery @component(name: "Test") {{
                    mlflowGetRun(input: {{runId: "{run_id}"}}) {{
                        run {{
                            info {{
                                status
                            }}
                            experiment {{
                                name
                            }}
                            modelVersions {{
                                name
                            }}
                        }}
                    }}
                }}
            """,
            "operationName": "testQuery",
        },
        headers={"content-type": "application/json; charset=utf-8"},
    )
    assert response.status_code == 200
    # Bind the parsed body to `response_json` rather than `json`, which
    # shadowed the module-level `json` import.
    response_json = response.json()
    assert response_json["errors"] is None
    assert response_json["data"]["mlflowGetRun"]["run"]["info"]["status"] == created_run.info.status
    assert response_json["data"]["mlflowGetRun"]["run"]["experiment"]["name"] == name
    assert response_json["data"]["mlflowGetRun"]["run"]["modelVersions"][0]["name"] == name
2511  
2512  
def test_legacy_start_and_end_trace_v2(mlflow_client):
    """Exercise the deprecated v2 start/end trace store APIs end-to-end.

    Starts a trace with initial metadata/tags, then ends it and verifies that
    end-time metadata/tags are merged into (not replacing) the initial ones.
    """
    experiment_id = mlflow_client.create_experiment("start end trace")

    # Trace CRUD APIs are not directly exposed as public API of MlflowClient,
    # so we use the underlying tracking client to test them.
    store = mlflow_client._tracing_client.store

    # Helper function to remove auto-added system tags (mlflow.xxx) from testing
    def _exclude_system_tags(tags: dict[str, str]):
        return {k: v for k, v in tags.items() if not k.startswith("mlflow.")}

    trace_info = store.deprecated_start_trace_v2(
        experiment_id=experiment_id,
        timestamp_ms=1000,
        request_metadata={
            "meta1": "apple",
            "meta2": "grape",
        },
        tags={
            "tag1": "football",
            "tag2": "basketball",
        },
    )
    assert trace_info.request_id is not None
    assert trace_info.experiment_id == experiment_id
    assert trace_info.timestamp_ms == 1000
    # A just-started trace has no elapsed time and is IN_PROGRESS.
    assert trace_info.execution_time_ms == 0
    assert trace_info.status == TraceStatus.IN_PROGRESS
    assert trace_info.request_metadata == {
        "meta1": "apple",
        "meta2": "grape",
    }
    assert _exclude_system_tags(trace_info.tags) == {
        "tag1": "football",
        "tag2": "basketball",
    }

    trace_info = store.deprecated_end_trace_v2(
        request_id=trace_info.request_id,
        timestamp_ms=3000,
        status=TraceStatus.OK,
        request_metadata={
            "meta1": "orange",
            "meta3": "banana",
        },
        tags={
            "tag1": "soccer",
            "tag3": "tennis",
        },
    )
    assert trace_info.request_id is not None
    assert trace_info.experiment_id == experiment_id
    assert trace_info.timestamp_ms == 1000
    # Execution time is derived from start/end timestamps: 3000 - 1000.
    assert trace_info.execution_time_ms == 2000
    assert trace_info.status == TraceStatus.OK
    # End-time metadata/tags are merged with start-time values: overlapping
    # keys are overwritten, keys present only at start time are preserved.
    assert trace_info.request_metadata == {
        "meta1": "orange",
        "meta2": "grape",
        "meta3": "banana",
    }
    assert _exclude_system_tags(trace_info.tags) == {
        "tag1": "soccer",
        "tag2": "basketball",
        "tag3": "tennis",
    }
2578  
2579  
def test_start_trace(mlflow_client):
    """Tags/metadata set via update_current_trace are persisted on the logged trace."""
    mlflow.set_tracking_uri(mlflow_client.tracking_uri)
    experiment_id = mlflow.set_experiment("start end trace").experiment_id

    # Helper function to remove auto-added system tags (mlflow.xxx) from testing
    def _exclude_system_keys(d: dict[str, str]):
        return {k: v for k, v in d.items() if not k.startswith("mlflow.")}

    # Capture exporter warnings so we can assert no span-logging failure occurred.
    with mock.patch("mlflow.tracing.export.mlflow_v3._logger.warning") as mock_warning:
        with mlflow.start_span(name="test") as span:
            mlflow.update_current_trace(
                tags={
                    "tag1": "football",
                    "tag2": "basketball",
                },
                metadata={
                    "meta1": "apple",
                    "meta2": "grape",
                },
            )

    trace = mlflow_client.get_trace(span.trace_id, flush=True)
    assert trace.info.trace_id == span.trace_id
    assert trace.info.experiment_id == experiment_id
    assert trace.info.request_time > 0
    assert trace.info.execution_duration is not None
    assert trace.info.state == TraceState.OK
    assert _exclude_system_keys(trace.info.trace_metadata) == {
        "meta1": "apple",
        "meta2": "grape",
    }
    # Traces logged through this export path carry trace schema version "3".
    assert trace.info.trace_metadata[TRACE_SCHEMA_VERSION_KEY] == "3"
    assert _exclude_system_keys(trace.info.tags) == {
        "tag1": "football",
        "tag2": "basketball",
    }

    # No "Failed to log span to MLflow backend" warning should be issued
    for call in mock_warning.call_args_list:
        assert "Failed to log span to MLflow backend" not in str(call)
2620  
2621  
def test_get_trace(mlflow_client):
    """A trace started and ended via the client can be fetched back intact."""
    mlflow.set_tracking_uri(mlflow_client.tracking_uri)
    experiment_id = mlflow_client.create_experiment("get trace")

    root_span = mlflow_client.start_trace(name="test", experiment_id=experiment_id)
    mlflow_client.end_trace(request_id=root_span.request_id, status=TraceStatus.OK)

    trace = mlflow_client.get_trace(root_span.request_id, flush=True)
    assert trace is not None
    assert trace.info.request_id == root_span.request_id
    assert trace.info.experiment_id == experiment_id
    assert trace.info.state == TraceState.OK

    # Exactly one span: the root, finished OK with an empty description.
    assert len(trace.data.spans) == 1
    only_span = trace.data.spans[0]
    assert only_span.name == "test"
    assert only_span.status.status_code == SpanStatusCode.OK
    assert only_span.status.description == ""
2636  
2637  
def test_search_traces(mlflow_client):
    """Search traces with filters, ordering, and pagination against a live server."""
    mlflow.set_tracking_uri(mlflow_client.tracking_uri)
    experiment_id = mlflow_client.create_experiment("search traces")

    # Create test traces
    def _create_trace(name, status):
        span = mlflow_client.start_trace(name=name, experiment_id=experiment_id)
        mlflow_client.end_trace(request_id=span.request_id, status=status)
        return span.request_id

    # Flush between creations to ensure distinct timestamps. Without this, all three traces
    # can land in the same millisecond on a fast local server, making max_results ordering
    # non-deterministic.
    request_id_1 = _create_trace(name="trace1", status=TraceStatus.OK)
    mlflow.flush_trace_async_logging()
    request_id_2 = _create_trace(name="trace2", status=TraceStatus.OK)
    mlflow.flush_trace_async_logging()
    request_id_3 = _create_trace(name="trace3", status=TraceStatus.ERROR)
    mlflow.flush_trace_async_logging()

    def _get_request_ids(traces):
        return [t.info.request_id for t in traces]

    # Unfiltered search returns all three traces in a single page (no token).
    traces = mlflow_client.search_traces(locations=[experiment_id])
    assert set(_get_request_ids(traces)) == {request_id_3, request_id_2, request_id_1}
    assert traces.token is None

    # Status filter excludes the ERROR trace.
    traces = mlflow_client.search_traces(
        locations=[experiment_id],
        filter_string="status = 'OK'",
        order_by=["timestamp ASC"],
    )
    assert set(_get_request_ids(traces)) == {request_id_1, request_id_2}
    assert traces.token is None

    # max_results=2 returns the two most recent traces plus a page token;
    # the next page yields the remaining (oldest) trace and ends pagination.
    traces = mlflow_client.search_traces(
        locations=[experiment_id],
        max_results=2,
    )
    assert set(_get_request_ids(traces)) == {request_id_3, request_id_2}
    assert traces.token is not None
    traces = mlflow_client.search_traces(
        locations=[experiment_id],
        page_token=traces.token,
    )
    assert _get_request_ids(traces) == [request_id_1]
    assert traces.token is None
2686  
2687  
def test_search_traces_parameter_validation(mlflow_client):
    """Non-experiment-ID locations (e.g. UC schema names) are rejected up front."""
    expected = "Locations must be a list of experiment IDs"
    with pytest.raises(MlflowException, match=expected):
        mlflow_client.search_traces(locations=["catalog.schema"])
2694  
2695  
def test_search_traces_match_text(mlflow_client, store_type):
    """Full-text search over traces via the trace.text LIKE filter (SQL store only)."""
    if store_type == "file":
        pytest.skip("File store doesn't support full text search")

    mlflow.set_tracking_uri(mlflow_client.tracking_uri)
    experiment_id = mlflow_client.create_experiment("search traces full text")

    def _make_trace(name, attributes):
        # Attach attributes so their values become part of the searchable text.
        root = mlflow_client.start_trace(name=name, experiment_id=experiment_id)
        root.set_attributes(attributes)
        mlflow_client.end_trace(request_id=root.trace_id, status=TraceStatus.OK)
        return root.trace_id

    trace_id_1 = _make_trace(name="trace1", attributes={"test": "value1"})
    trace_id_2 = _make_trace(name="trace2", attributes={"test": "value2"})
    trace_id_3 = _make_trace(name="trace3", attributes={"test3": "I like it"})

    all_traces = mlflow_client.search_traces(locations=[experiment_id], flush=True)
    assert len([t.info.trace_id for t in all_traces]) == 3
    assert all_traces.token is None

    # Matching on the span names hits every trace.
    by_name = mlflow_client.search_traces(
        locations=[experiment_id], filter_string="trace.text LIKE '%trace%'"
    )
    assert len([t.info.trace_id for t in by_name]) == 3
    assert by_name.token is None

    # Matching on attribute values hits only the traces carrying them.
    by_value = mlflow_client.search_traces(
        locations=[experiment_id], filter_string="trace.text LIKE '%value%'"
    )
    assert {t.info.trace_id for t in by_value} == {trace_id_1, trace_id_2}

    by_phrase = mlflow_client.search_traces(
        locations=[experiment_id], filter_string="trace.text LIKE '%I like it%'"
    )
    assert [t.info.trace_id for t in by_phrase] == [trace_id_3]
2733  
2734  
def test_delete_traces(mlflow_client):
    """Delete traces by experiment, by max_traces limit, and by explicit trace ID."""
    mlflow.set_tracking_uri(mlflow_client.tracking_uri)
    experiment_id = mlflow_client.create_experiment("delete traces")

    def _create_trace(name, status):
        span = mlflow_client.start_trace(name=name, experiment_id=experiment_id)
        mlflow_client.end_trace(request_id=span.request_id, status=status)
        return span.request_id

    def _is_trace_exists(request_id):
        # get_trace_info raises RESOURCE_DOES_NOT_EXIST once the trace is deleted.
        try:
            trace_info = mlflow_client._tracing_client.get_trace_info(request_id)
            return trace_info is not None
        except RestException as e:
            if e.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST):
                return False
            raise

    # Case 1: Delete all traces under experiment ID
    request_id_1 = _create_trace(name="trace1", status=TraceStatus.OK)
    request_id_2 = _create_trace(name="trace2", status=TraceStatus.OK)
    mlflow.flush_trace_async_logging()
    assert _is_trace_exists(request_id_1)
    assert _is_trace_exists(request_id_2)

    # A max_timestamp_millis far in the future matches every trace.
    deleted_count = mlflow_client.delete_traces(experiment_id, max_timestamp_millis=int(1e15))
    assert deleted_count == 2
    assert not _is_trace_exists(request_id_1)
    assert not _is_trace_exists(request_id_2)

    # Case 2: Delete with max_traces limit
    request_id_1 = _create_trace(name="trace1", status=TraceStatus.OK)
    time.sleep(0.1)  # Add some time gap to avoid timestamp collision
    request_id_2 = _create_trace(name="trace2", status=TraceStatus.OK)
    mlflow.flush_trace_async_logging()

    deleted_count = mlflow_client.delete_traces(
        experiment_id, max_traces=1, max_timestamp_millis=int(1e15)
    )
    assert deleted_count == 1
    # TODO: Currently the deletion order in the file store is random (based on
    # the order of the trace files in the directory), so we don't validate which
    # one is deleted. Uncomment the following lines once the deletion order is fixed.
    # assert not _is_trace_exists(request_id_1)  # Old created trace should be deleted
    # assert _is_trace_exists(request_id_2)

    # Case 3: Delete with explicit request ID
    request_id_1 = _create_trace(name="trace1", status=TraceStatus.OK)
    request_id_2 = _create_trace(name="trace2", status=TraceStatus.OK)
    mlflow.flush_trace_async_logging()

    deleted_count = mlflow_client.delete_traces(experiment_id, trace_ids=[request_id_1])
    assert deleted_count == 1
    assert not _is_trace_exists(request_id_1)
    assert _is_trace_exists(request_id_2)
2790  
2791  
def test_calculate_trace_filter_correlation(mlflow_client, store_type):
    """NPMI correlation between two trace filters (SQL store only).

    Fixture: 6 traces tagged env=prod/span_type=TOOL, plus 4 tagged env=dev
    of which one (i=0) is TOOL and three (i>=1) are LLM — so TOOL appears on
    7 of 10 traces and LLM on 3.
    """
    if store_type == "file":
        pytest.skip("File store doesn't support calculate_trace_filter_correlation")

    mlflow.set_tracking_uri(mlflow_client.tracking_uri)
    experiment_id = mlflow_client.create_experiment("correlation test")

    def _create_trace(name, tags):
        span = mlflow_client.start_trace(name=name, experiment_id=experiment_id, tags=tags)
        mlflow_client.end_trace(request_id=span.request_id, status=TraceStatus.OK)
        return span.request_id

    for i in range(6):
        _create_trace(f"trace-prod-tool-{i}", {"env": "prod", "span_type": "TOOL"})

    for i in range(4):
        _create_trace(f"trace-dev-{i}", {"env": "dev", "span_type": "LLM" if i >= 1 else "TOOL"})

    client = TracingClient(tracking_uri=mlflow_client.tracking_uri)

    # Ensure all traces are persisted before querying correlations.
    mlflow.flush_trace_async_logging()

    # prod (6) vs TOOL (7): every prod trace is TOOL, so correlation is strong.
    result = client.calculate_trace_filter_correlation(
        experiment_ids=[experiment_id],
        filter_string1="tags.env = 'prod'",
        filter_string2="tags.span_type = 'TOOL'",
    )

    assert isinstance(result, TraceFilterCorrelationResult)
    assert result.total_count == 10
    assert result.filter1_count == 6
    assert result.filter2_count == 7
    assert result.joint_count == 6
    assert 0.6 < result.npmi < 0.8
    assert result.npmi_smoothed is not None

    # dev (4) vs LLM (3): every LLM trace is dev, again positively correlated.
    result2 = client.calculate_trace_filter_correlation(
        experiment_ids=[experiment_id],
        filter_string1="tags.env = 'dev'",
        filter_string2="tags.span_type = 'LLM'",
    )

    assert result2.total_count == 10
    assert result2.filter1_count == 4
    assert result2.filter2_count == 3
    assert result2.joint_count == 3
    assert result2.npmi > 0.5

    # staging matches nothing, so the NPMI is undefined (NaN).
    result3 = client.calculate_trace_filter_correlation(
        experiment_ids=[experiment_id],
        filter_string1="tags.env = 'staging'",
        filter_string2="tags.span_type = 'TOOL'",
    )

    assert result3.total_count == 10
    assert result3.filter1_count == 0
    assert result3.filter2_count == 7
    assert result3.joint_count == 0
    assert math.isnan(result3.npmi)

    # Malformed filter strings surface as an MlflowException.
    with pytest.raises(MlflowException, match="Invalid"):
        client.calculate_trace_filter_correlation(
            experiment_ids=[experiment_id],
            filter_string1="invalid.filter = 'test'",
            filter_string2="tags.span_type = 'TOOL'",
        )
2858  
2859  
def test_set_and_delete_trace_tag(mlflow_client):
    """Tags on an existing trace can be overwritten and removed."""
    mlflow.set_tracking_uri(mlflow_client.tracking_uri)
    experiment_id = mlflow_client.create_experiment("set delete tag")

    # Seed a trace directly through the tracing client with two tags.
    seeded = TraceInfo(
        trace_id="tr-1234",
        trace_location=TraceLocation.from_experiment_id(experiment_id),
        request_time=1000,
        execution_duration=2000,
        state=TraceState.OK,
        tags={
            "tag1": "red",
            "tag2": "blue",
        },
    )
    trace_info = mlflow_client._tracing_client.start_trace(seeded)

    # Overwriting an existing tag takes effect on the stored trace.
    mlflow_client.set_trace_tag(trace_info.request_id, "tag1", "green")
    trace_info = mlflow_client._tracing_client.get_trace_info(trace_info.request_id)
    assert trace_info.tags["tag1"] == "green"

    # Deleting removes the tag entirely.
    mlflow_client.delete_trace_tag(trace_info.request_id, "tag2")
    trace_info = mlflow_client._tracing_client.get_trace_info(trace_info.request_id)
    assert "tag2" not in trace_info.tags
2888  
2889  
def test_query_trace_metrics(mlflow_client, store_type):
    """query_trace_metrics aggregates trace counts grouped by status (SQL store only)."""
    if store_type == "file":
        pytest.skip("File store doesn't support query trace metrics")

    mlflow.set_tracking_uri(mlflow_client.tracking_uri)
    experiment_id = mlflow_client.create_experiment("query trace metrics")

    # Create test traces
    def _create_trace(name, status):
        span = mlflow_client.start_trace(name=name, experiment_id=experiment_id)
        mlflow_client.end_trace(request_id=span.request_id, status=status)
        return span.request_id

    # Two OK traces and one ERROR trace -> two status groups asserted below.
    _create_trace(name="trace1", status=TraceStatus.OK)
    _create_trace(name="trace2", status=TraceStatus.OK)
    _create_trace(name="trace3", status=TraceStatus.ERROR)

    mlflow.flush_trace_async_logging()

    metrics = mlflow_client._tracing_client.store.query_trace_metrics(
        experiment_ids=[experiment_id],
        view_type=MetricViewType.TRACES,
        metric_name=TraceMetricKey.TRACE_COUNT,
        aggregations=[MetricAggregation(aggregation_type=AggregationType.COUNT)],
        dimensions=[TraceMetricDimensionKey.TRACE_STATUS],
    )
    # One result per distinct status; results come back ERROR then OK here.
    assert len(metrics) == 2
    assert asdict(metrics[0]) == {
        "metric_name": TraceMetricKey.TRACE_COUNT,
        "dimensions": {TraceMetricDimensionKey.TRACE_STATUS: "ERROR"},
        "values": {"COUNT": 1},
    }

    assert asdict(metrics[1]) == {
        "metric_name": TraceMetricKey.TRACE_COUNT,
        "dimensions": {TraceMetricDimensionKey.TRACE_STATUS: "OK"},
        "values": {"COUNT": 2},
    }
2928  
2929  
@pytest.mark.parametrize("allow_partial", [True, False])
def test_get_trace_handler(mlflow_client, allow_partial: bool, store_type):
    """The v3 traces/get handler returns the trace with its spans and attributes."""
    if store_type == "file":
        pytest.skip("File store doesn't support get trace handler")

    mlflow.set_tracking_uri(mlflow_client.tracking_uri)

    with mlflow.start_span(name="test") as span:
        span.set_attributes({"fruit": "apple"})

    mlflow.flush_trace_async_logging()

    resp = requests.get(
        f"{mlflow_client.tracking_uri}/ajax-api/3.0/mlflow/traces/get",
        params={"trace_id": span.trace_id, "allow_partial": allow_partial},
    )
    assert resp.status_code == 200

    trace = resp.json()["trace"]
    assert trace["trace_info"]["trace_id"] == span.trace_id

    # Exactly one span, carrying the attribute set above in proto-JSON form.
    assert len(trace["spans"]) == 1
    span_json = trace["spans"][0]
    assert span_json["name"] == "test"
    assert {"key": "fruit", "value": {"string_value": "apple"}} in span_json["attributes"]
2955  
2956  
def test_get_trace_artifact_handler(mlflow_client):
    """get-trace-artifact serves the trace's span data as a downloadable JSON file."""
    mlflow.set_tracking_uri(mlflow_client.tracking_uri)

    with mlflow.start_span(name="test") as span:
        span.set_attributes({"fruit": "apple"})
        span.add_event(SpanEvent("test_event", timestamp=99999, attributes={"foo": "bar"}))

    mlflow.flush_trace_async_logging()

    resp = requests.get(
        f"{mlflow_client.tracking_uri}/ajax-api/2.0/mlflow/get-trace-artifact",
        params={"request_id": span.trace_id},
    )
    assert resp.status_code == 200
    # Served as an attachment with a fixed filename.
    assert resp.headers["Content-Disposition"] == "attachment; filename=traces.json"

    # The payload deserializes back to span data identical to what was logged.
    downloaded = TraceData.from_dict(json.loads(resp.text))
    assert downloaded.spans[0].to_dict() == span.to_dict()
2976  
2977  
def test_link_traces_to_run_and_search_traces(mlflow_client, store_type):
    """Traces become searchable by run_id when created under or linked to a run."""
    # Skip file store because it doesn't support linking traces to runs
    if store_type == "file":
        pytest.skip("File store doesn't support linking traces to runs")

    mlflow.set_tracking_uri(mlflow_client.tracking_uri)
    experiment_id = mlflow.set_experiment("link traces to run test").experiment_id

    run = mlflow_client.create_run(experiment_id)
    run_id = run.info.run_id

    # 1. Trace created under a run (implicitly associated via the active run)
    with mlflow.start_run(run_id=run_id):
        with mlflow.start_span(name="trace1") as span1:
            span1.set_attributes({"test": "value1"})
        trace_id_1 = span1.trace_id

    # 2. Trace associated with a run after the fact via link_traces_to_run
    with mlflow.start_span(name="trace2") as span2:
        span2.set_attributes({"test": "value2"})
    trace_id_2 = span2.trace_id
    mlflow_client.link_traces_to_run(trace_ids=[trace_id_2], run_id=run_id)

    # 3. Trace not associated with a run
    with mlflow.start_span(name="trace3") as span3:
        span3.set_attributes({"test": "value3"})
    trace_id_3 = span3.trace_id

    # Search traces without run_id filter - should return all traces in experiment
    all_traces = mlflow_client.search_traces(locations=[experiment_id], flush=True)
    assert {t.info.trace_id for t in all_traces} == {trace_id_1, trace_id_2, trace_id_3}

    # Search traces with run_id filter - should return only the two linked traces
    linked_traces = mlflow_client.search_traces(
        locations=[experiment_id], filter_string=f"attribute.run_id = '{run_id}'"
    )
    linked_trace_ids = [t.info.trace_id for t in linked_traces]
    assert len(linked_trace_ids) == 2
    assert set(linked_trace_ids) == {trace_id_1, trace_id_2}
3017  
3018  
def test_get_metric_history_bulk_interval_graphql(mlflow_client):
    """GraphQL `mlflowGetMetricHistoryBulkInterval` returns every logged step of a metric.

    Logs ten values of a single metric on one run, queries them back over the
    GraphQL endpoint, and compares key/value pairs. Timestamps are assigned
    server-side, so they are matched with ``mock.ANY``.
    """
    name = "GraphqlTest"
    mlflow_client.create_registered_model(name)
    experiment_id = mlflow_client.create_experiment(name)
    created_run = mlflow_client.create_run(experiment_id)

    metric_name = "metric_0"
    for i in range(10):
        mlflow_client.log_metric(created_run.info.run_id, metric_name, i, step=i)

    response = requests.post(
        f"{mlflow_client.tracking_uri}/graphql",
        json={
            "query": f"""
                query testQuery {{
                    mlflowGetMetricHistoryBulkInterval(input: {{
                        runIds: ["{created_run.info.run_id}"],
                        metricKey: "{metric_name}",
                    }}) {{
                        metrics {{
                            key
                            timestamp
                            value
                        }}
                    }}
                }}
            """,
            "operationName": "testQuery",
        },
        headers={"content-type": "application/json; charset=utf-8"},
    )

    assert response.status_code == 200
    # Fix: renamed from `json` so the parsed payload no longer shadows the
    # module-level `json` import.
    response_json = response.json()
    expected = [{"key": metric_name, "timestamp": mock.ANY, "value": i} for i in range(10)]
    assert response_json["data"]["mlflowGetMetricHistoryBulkInterval"]["metrics"] == expected
3055  
3056  
def test_search_runs_graphql(mlflow_client):
    """GraphQL `mlflowSearchRuns` returns runs of an experiment, newest first.

    Two runs are created; the server is expected to list the most recently
    created run before the older one.
    """
    name = "GraphqlTest"
    mlflow_client.create_registered_model(name)
    experiment_id = mlflow_client.create_experiment(name)
    created_run_1 = mlflow_client.create_run(experiment_id)
    created_run_2 = mlflow_client.create_run(experiment_id)

    response = requests.post(
        f"{mlflow_client.tracking_uri}/graphql",
        json={
            "query": f"""
                mutation testMutation {{
                    mlflowSearchRuns(input: {{ experimentIds: ["{experiment_id}"] }}) {{
                        runs {{
                            info {{
                                runId
                            }}
                        }}
                    }}
                }}
            """,
            "operationName": "testMutation",
        },
        headers={"content-type": "application/json; charset=utf-8"},
    )

    assert response.status_code == 200
    # Fix: renamed from `json` so the parsed payload no longer shadows the
    # module-level `json` import.
    response_json = response.json()
    expected = [
        {"info": {"runId": created_run_2.info.run_id}},
        {"info": {"runId": created_run_1.info.run_id}},
    ]
    assert response_json["data"]["mlflowSearchRuns"]["runs"] == expected
3090  
3091  
def test_list_artifacts_graphql(mlflow_client, tmp_path):
    """GraphQL `mlflowListArtifacts` lists root artifacts and subdirectory contents.

    Logs the same file at the artifact root and under ``testDir``, then queries
    both the root listing and the subdirectory listing.
    """
    name = "GraphqlTest"
    experiment_id = mlflow_client.create_experiment(name)
    created_run_id = mlflow_client.create_run(experiment_id).info.run_id
    file_path = tmp_path / "test.txt"
    file_path.write_text("hello world")
    mlflow_client.log_artifact(created_run_id, file_path.absolute().as_posix())
    mlflow_client.log_artifact(created_run_id, file_path.absolute().as_posix(), "testDir")

    response = requests.post(
        f"{mlflow_client.tracking_uri}/graphql",
        json={
            "query": f"""
                query testQuery {{
                    files: mlflowListArtifacts(input: {{
                        runId: "{created_run_id}",
                    }}) {{
                            files {{
                            path
                            isDir
                            fileSize
                        }}
                    }}
                }}
            """,
            "operationName": "testQuery",
        },
        headers={"content-type": "application/json; charset=utf-8"},
    )

    assert response.status_code == 200
    # Fix: renamed from `json` so the parsed payload no longer shadows the
    # module-level `json` import.
    root_json = response.json()
    # fileSize is serialized as a string; directories report size "0".
    file_expected = [
        {"path": "test.txt", "isDir": False, "fileSize": "11"},
        {"path": "testDir", "isDir": True, "fileSize": "0"},
    ]
    assert root_json["data"]["files"]["files"] == file_expected

    response = requests.post(
        f"{mlflow_client.tracking_uri}/graphql",
        json={
            "query": f"""
                query testQuery {{
                    subdir: mlflowListArtifacts(input: {{
                        runId: "{created_run_id}",
                        path: "testDir",
                    }}) {{
                            files {{
                            path
                            isDir
                            fileSize
                        }}
                    }}
                }}
            """,
            "operationName": "testQuery",
        },
        headers={"content-type": "application/json; charset=utf-8"},
    )

    assert response.status_code == 200
    subdir_json = response.json()
    subdir_expected = [
        {"path": "testDir/test.txt", "isDir": False, "fileSize": "11"},
    ]
    assert subdir_json["data"]["subdir"]["files"] == subdir_expected
3158  
3159  
def test_search_datasets_graphql(mlflow_client):
    """GraphQL `mlflowSearchDatasets` returns dataset summaries with their context tag.

    Logs two dataset inputs — one with an ``MLFLOW_DATASET_CONTEXT`` tag and
    one without — and checks both summaries, including the ``context`` field.
    """
    name = "GraphqlTest"
    experiment_id = mlflow_client.create_experiment(name)
    created_run_id = mlflow_client.create_run(experiment_id).info.run_id
    dataset1 = Dataset(
        name="test-dataset-1",
        digest="12345",
        source_type="script",
        source="test",
    )
    dataset_input1 = DatasetInput(dataset=dataset1, tags=[])
    dataset2 = Dataset(
        name="test-dataset-2",
        digest="12346",
        source_type="script",
        source="test",
    )
    # The context tag value ("training") should surface as the summary's `context`.
    dataset_input2 = DatasetInput(
        dataset=dataset2, tags=[InputTag(key=MLFLOW_DATASET_CONTEXT, value="training")]
    )
    mlflow_client.log_inputs(created_run_id, [dataset_input1, dataset_input2])

    response = requests.post(
        f"{mlflow_client.tracking_uri}/graphql",
        json={
            "query": f"""
                mutation testMutation {{
                    mlflowSearchDatasets(input:{{experimentIds: ["{experiment_id}"]}}) {{
                        datasetSummaries {{
                            experimentId
                            name
                            digest
                            context
                        }}
                    }}
                }}
            """,
            "operationName": "testMutation",
        },
        headers={"content-type": "application/json; charset=utf-8"},
    )

    assert response.status_code == 200
    # Fix: renamed from `json` so the parsed payload no longer shadows the
    # module-level `json` import.
    response_json = response.json()

    def sort_dataset_summaries(summaries):
        # Result order is unspecified; sort by digest for a deterministic comparison.
        return sorted(summaries, key=lambda x: x["digest"])

    expected = sort_dataset_summaries([
        {
            "experimentId": experiment_id,
            "name": "test-dataset-2",
            "digest": "12346",
            "context": "training",
        },
        {
            "experimentId": experiment_id,
            "name": "test-dataset-1",
            "digest": "12345",
            "context": "",
        },
    ])
    assert (
        sort_dataset_summaries(response_json["data"]["mlflowSearchDatasets"]["datasetSummaries"])
        == expected
    )
3225  
3226  
def test_create_logged_model(mlflow_client: MlflowClient):
    """Each `create_logged_model` argument round-trips through `get_logged_model`.

    Fix: the assertions previously checked the locally-returned ``model``
    object, so the server round-trip via ``get_logged_model`` was never
    actually verified. They now assert on the re-fetched ``loaded_model``.
    """
    exp_id = mlflow_client.create_experiment("create_logged_model")
    model = mlflow_client.create_logged_model(exp_id)
    loaded_model = mlflow_client.get_logged_model(model.model_id)
    assert loaded_model.model_id == model.model_id

    model = mlflow_client.create_logged_model(exp_id, name="my_model")
    loaded_model = mlflow_client.get_logged_model(model.model_id)
    assert loaded_model.name == "my_model"

    model = mlflow_client.create_logged_model(exp_id, model_type="LLM")
    loaded_model = mlflow_client.get_logged_model(model.model_id)
    assert loaded_model.model_type == "LLM"

    model = mlflow_client.create_logged_model(exp_id, source_run_id="123")
    loaded_model = mlflow_client.get_logged_model(model.model_id)
    assert loaded_model.source_run_id == "123"

    model = mlflow_client.create_logged_model(exp_id, params={"param": "value"})
    loaded_model = mlflow_client.get_logged_model(model.model_id)
    assert loaded_model.params == {"param": "value"}

    model = mlflow_client.create_logged_model(exp_id, tags={"tag": "value"})
    loaded_model = mlflow_client.get_logged_model(model.model_id)
    assert loaded_model.tags == {"tag": "value"}
3252  
3253  
def test_log_logged_model_params(mlflow_client: MlflowClient):
    """Params logged via `log_model_params` are visible when the model is fetched back."""
    experiment_id = mlflow_client.create_experiment("create_logged_model")
    logged_model = mlflow_client.create_logged_model(experiment_id)
    mlflow_client.log_model_params(logged_model.model_id, {"param": "value"})
    fetched = mlflow_client.get_logged_model(logged_model.model_id)
    assert fetched.params == {"param": "value"}
3260  
3261  
def test_finalize_logged_model(mlflow_client: MlflowClient):
    """`finalize_logged_model` transitions a model to each requested terminal status."""
    experiment_id = mlflow_client.create_experiment("create_logged_model")
    logged_model = mlflow_client.create_logged_model(experiment_id)
    # READY first, then FAILED — same order as the original two-step check.
    for target_status in (LoggedModelStatus.READY, LoggedModelStatus.FAILED):
        finalized = mlflow_client.finalize_logged_model(logged_model.model_id, target_status)
        assert finalized.status == target_status
3270  
3271  
def test_delete_logged_model(mlflow_client: MlflowClient):
    """A deleted logged model is neither retrievable nor returned by search."""
    experiment_id = mlflow_client.create_experiment("delete_logged_model")
    logged_model = mlflow_client.create_logged_model(experiment_id=experiment_id)
    mlflow_client.delete_logged_model(logged_model.model_id)

    with pytest.raises(MlflowException, match="not found"):
        mlflow_client.get_logged_model(logged_model.model_id)

    remaining = mlflow_client.search_logged_models(experiment_ids=[experiment_id])
    assert len(remaining) == 0
3281  
3282  
def test_set_logged_model_tags(mlflow_client: MlflowClient):
    """`set_logged_model_tags` merges new tags and overwrites existing keys."""
    experiment_id = mlflow_client.create_experiment("create_logged_model")
    logged_model = mlflow_client.create_logged_model(experiment_id)

    # Initial write: both tags land on the model.
    mlflow_client.set_logged_model_tags(logged_model.model_id, {"tag1": "value1", "tag2": "value2"})
    assert mlflow_client.get_logged_model(logged_model.model_id).tags == {
        "tag1": "value1",
        "tag2": "value2",
    }

    # Second write: tag1 is overwritten while tag2 is left untouched.
    mlflow_client.set_logged_model_tags(logged_model.model_id, {"tag1": "value3"})
    assert mlflow_client.get_logged_model(logged_model.model_id).tags == {
        "tag1": "value3",
        "tag2": "value2",
    }
3293  
3294  
def test_delete_logged_model_tag(mlflow_client: MlflowClient):
    """`delete_logged_model_tag` removes a tag and errors when the key is absent."""
    experiment_id = mlflow_client.create_experiment("create_logged_model")
    logged_model = mlflow_client.create_logged_model(experiment_id)
    mlflow_client.set_logged_model_tags(logged_model.model_id, {"tag1": "value1", "tag2": "value2"})

    mlflow_client.delete_logged_model_tag(logged_model.model_id, "tag1")
    assert mlflow_client.get_logged_model(logged_model.model_id).tags == {"tag2": "value2"}

    # Deleting the same key again must fail: it no longer exists.
    with pytest.raises(MlflowException, match="No tag with key"):
        mlflow_client.delete_logged_model_tag(logged_model.model_id, "tag1")
3305  
3306  
def test_search_logged_models(mlflow_client: MlflowClient):
    """Exercise `search_logged_models`: pagination, name filter, dataset filter, ordering."""
    experiment_id = mlflow_client.create_experiment("create_logged_model")
    first_model = mlflow_client.create_logged_model(experiment_id)
    time.sleep(0.001)  # guarantee distinct creation timestamps for ordering
    results = mlflow_client.search_logged_models(experiment_ids=[experiment_id])
    assert [m.name for m in results] == [first_model.name]

    # max_results=1 returns the newest model plus a continuation token.
    second_model = mlflow_client.create_logged_model(experiment_id)
    first_page = mlflow_client.search_logged_models(experiment_ids=[experiment_id], max_results=1)
    assert [m.name for m in first_page] == [second_model.name]
    assert first_page.token is not None

    # The token fetches the remaining model and exhausts the result set.
    second_page = mlflow_client.search_logged_models(
        experiment_ids=[experiment_id], max_results=1, page_token=first_page.token
    )
    assert [m.name for m in second_page] == [first_model.name]
    assert second_page.token is None

    # filter_string narrows by exact name.
    results = mlflow_client.search_logged_models(
        experiment_ids=[experiment_id], filter_string=f"name = {first_model.name!r}"
    )
    assert [m.name for m in results] == [first_model.name]

    # A metric logged against a dataset makes the model match dataset filters.
    run = mlflow_client.create_run(experiment_id)
    mlflow_client.log_metric(
        run.info.run_id,
        key="metric",
        value=1,
        dataset_name="dataset",
        dataset_digest="123",
        model_id=first_model.model_id,
    )
    results = mlflow_client.search_logged_models(
        experiment_ids=[experiment_id],
        datasets=[{"dataset_name": "dataset", "dataset_digest": "123"}],
    )
    assert [m.name for m in results] == [first_model.name]

    # Descending creation-time ordering puts the newest model first.
    results = mlflow_client.search_logged_models(
        experiment_ids=[experiment_id],
        order_by=[{"field_name": "creation_timestamp", "ascending": False}],
    )
    assert [m.name for m in results] == [second_model.name, first_model.name]
3356  
3357  
def test_log_outputs(mlflow_client: MlflowClient):
    """Model outputs logged on a run round-trip through `get_run`."""
    experiment_id = mlflow_client.create_experiment("log_outputs")
    created_run = mlflow_client.create_run(experiment_id=experiment_id)
    logged_model = mlflow_client.create_logged_model(experiment_id=experiment_id)
    outputs = [LoggedModelOutput(logged_model.model_id, 1)]
    mlflow_client.log_outputs(created_run.info.run_id, outputs)
    fetched_run = mlflow_client.get_run(created_run.info.run_id)
    assert fetched_run.outputs.model_outputs == outputs
3366  
3367  
def test_list_logged_model_artifacts(mlflow_client: MlflowClient):
    """The logged-model artifact-directory endpoint lists the MLmodel file."""

    class EchoModel(mlflow.pyfunc.PythonModel):
        def predict(self, context, model_input):
            return model_input

    mlflow.set_tracking_uri(mlflow_client.tracking_uri)
    model_info = mlflow.pyfunc.log_model(name="model", python_model=EchoModel())
    listing = requests.get(
        f"{mlflow_client.tracking_uri}/ajax-api/2.0/mlflow/logged-models/{model_info.model_id}/artifacts/directories"
    )
    assert listing.status_code == 200
    artifact_paths = [entry["path"] for entry in listing.json()["files"]]
    assert "MLmodel" in artifact_paths
3382  
3383  
def test_get_logged_model_artifact(mlflow_client: MlflowClient):
    """The logged-model artifact-file endpoint serves the MLmodel contents."""

    class EchoModel(mlflow.pyfunc.PythonModel):
        def predict(self, context, model_input):
            return model_input

    mlflow.set_tracking_uri(mlflow_client.tracking_uri)
    model_info = mlflow.pyfunc.log_model(name="model", python_model=EchoModel())
    artifact_response = requests.get(
        f"{mlflow_client.tracking_uri}/ajax-api/2.0/mlflow/logged-models/{model_info.model_id}/artifacts/files",
        params={"artifact_file_path": "MLmodel"},
    )
    assert artifact_response.status_code == 200
    # The MLmodel file embeds the model id, so its presence proves we got the right file.
    assert model_info.model_id in artifact_response.text
3397  
3398  
def test_suppress_url_printing(mlflow_client: MlflowClient, monkeypatch):
    """`_log_url` prints nothing when MLFLOW_SUPPRESS_PRINTING_URL_TO_STDOUT is enabled."""
    monkeypatch.setenv(MLFLOW_SUPPRESS_PRINTING_URL_TO_STDOUT.name, "true")
    experiment_id = mlflow_client.create_experiment("test_suppress_url_printing")
    created_run = mlflow_client.create_run(experiment_id=experiment_id)
    # Capture stdout to verify nothing is written.
    fake_stdout = StringIO()
    monkeypatch.setattr(sys, "stdout", fake_stdout)
    mlflow_client._tracking_client._log_url(created_run.info.run_id)
    assert fake_stdout.getvalue() == ""
3407  
3408  
def test_log_url_includes_workspace_when_set(mlflow_client: MlflowClient, monkeypatch):
    """`_log_url` appends a `?workspace=...` query param when a request workspace is active."""
    experiment_id = mlflow_client.create_experiment("test_log_url_workspace")
    created_run = mlflow_client.create_run(experiment_id=experiment_id)
    fake_stdout = StringIO()
    monkeypatch.setattr(sys, "stdout", fake_stdout)
    # Pretend a workspace-aware server is configured and a workspace is selected.
    monkeypatch.setattr(
        "mlflow.tracking._tracking_service.client.get_workspace_url", lambda: "http://localhost"
    )
    monkeypatch.setattr(
        "mlflow.tracking._tracking_service.client.get_request_workspace", lambda: "team-space"
    )

    mlflow_client._tracking_client._log_url(created_run.info.run_id)

    printed = fake_stdout.getvalue()
    assert (
        f"/#/experiments/{experiment_id}/runs/{created_run.info.run_id}?workspace=team-space"
        in printed
    )
3426  
3427  
def test_assessments_end_to_end(mlflow_client):
    """End-to-end CRUD test of the v3 trace-assessment REST API.

    Walks the full lifecycle against a live server: create a feedback
    assessment, get it, update it with a field mask, override it (which
    invalidates the original), delete the override (which restores the
    original), create an expectation-type assessment, and finally delete
    everything. Steps are order-dependent: each stage relies on server state
    left by the previous one.
    """
    mlflow.set_tracking_uri(mlflow_client.tracking_uri)

    # Set up experiment and trace
    experiment_id = mlflow_client.create_experiment("assessment_crud_test")
    trace_info = mlflow_client.start_trace(name="test_trace", experiment_id=experiment_id)
    mlflow_client.end_trace(request_id=trace_info.request_id)
    # Ensure the async-logged trace is persisted before issuing REST calls against it.
    mlflow.flush_trace_async_logging()

    # CREATE initial feedback assessment
    feedback_payload = {
        "assessment": {
            "assessment_name": "quality_score",
            "feedback": {"value": {"rating": 4, "comments": "Good response"}},
            "source": {"source_type": "HUMAN", "source_id": "evaluator@company.com"},
            "rationale": "Response was accurate and helpful",
            "metadata": {"model": "gpt-4", "version": "1.0"},
        }
    }

    # CREATE assessment
    create_response = requests.post(
        f"{mlflow_client.tracking_uri}/api/3.0/mlflow/traces/{trace_info.request_id}/assessments",
        json=feedback_payload,
    )
    assert create_response.status_code == 200
    assessment = create_response.json()["assessment"]
    assessment_id = assessment["assessment_id"]

    # Verify creation
    assert assessment["assessment_name"] == "quality_score"
    assert assessment["feedback"]["value"]["rating"] == 4
    assert assessment["source"]["source_type"] == "HUMAN"
    assert assessment["valid"] is True

    # GET assessment
    get_response = requests.get(
        f"{mlflow_client.tracking_uri}/api/3.0/mlflow/traces/{trace_info.request_id}/assessments/{assessment_id}"
    )
    assert get_response.status_code == 200
    retrieved = get_response.json()["assessment"]
    assert retrieved["assessment_id"] == assessment_id
    assert retrieved["feedback"]["value"]["rating"] == 4

    # UPDATE assessment
    update_payload = {
        "assessment": {
            "assessment_id": assessment_id,
            "trace_id": trace_info.request_id,
            "assessment_name": "updated_quality_score",
            "feedback": {"value": {"rating": 5, "comments": "Excellent response"}},
            "rationale": "Actually, the response was excellent",
            "metadata": {"model": "gpt-4", "version": "2.0"},
        },
        # NOTE: the update mask uses camelCase field names, unlike the
        # snake_case keys in the assessment body.
        "update_mask": "assessmentName,feedback,rationale,metadata",
    }

    update_response = requests.patch(
        f"{mlflow_client.tracking_uri}/api/3.0/mlflow/traces/{trace_info.request_id}/assessments/{assessment_id}",
        json=update_payload,
    )
    assert update_response.status_code == 200
    updated = update_response.json()["assessment"]
    assert updated["assessment_name"] == "updated_quality_score"
    assert updated["feedback"]["value"]["rating"] == 5
    assert updated["rationale"] == "Actually, the response was excellent"

    # CREATE override assessment
    override_payload = {
        "assessment": {
            "assessment_name": "corrected_quality_score",
            "feedback": {"value": {"rating": 3, "comments": "Actually needs improvement"}},
            "source": {"source_type": "HUMAN", "source_id": "senior_evaluator@company.com"},
            # Pointing `overrides` at an existing assessment invalidates it.
            "overrides": assessment_id,
        }
    }

    override_response = requests.post(
        f"{mlflow_client.tracking_uri}/api/3.0/mlflow/traces/{trace_info.request_id}/assessments",
        json=override_payload,
    )
    assert override_response.status_code == 200
    override_assessment = override_response.json()["assessment"]
    override_id = override_assessment["assessment_id"]

    # Verify original is now invalid
    get_original = requests.get(
        f"{mlflow_client.tracking_uri}/api/3.0/mlflow/traces/{trace_info.request_id}/assessments/{assessment_id}"
    )
    assert get_original.status_code == 200
    assert get_original.json()["assessment"]["valid"] is False

    # Verify override is valid
    get_override = requests.get(
        f"{mlflow_client.tracking_uri}/api/3.0/mlflow/traces/{trace_info.request_id}/assessments/{override_id}"
    )
    assert get_override.status_code == 200
    assert get_override.json()["assessment"]["valid"] is True
    assert get_override.json()["assessment"]["overrides"] == assessment_id

    # DELETE override assessment (should restore original)
    delete_response = requests.delete(
        f"{mlflow_client.tracking_uri}/api/3.0/mlflow/traces/{trace_info.request_id}/assessments/{override_id}"
    )
    assert delete_response.status_code == 200

    # Verify override is deleted
    get_deleted = requests.get(
        f"{mlflow_client.tracking_uri}/api/3.0/mlflow/traces/{trace_info.request_id}/assessments/{override_id}"
    )
    assert get_deleted.status_code == 404

    # Verify original is restored to valid
    get_restored = requests.get(
        f"{mlflow_client.tracking_uri}/api/3.0/mlflow/traces/{trace_info.request_id}/assessments/{assessment_id}"
    )
    assert get_restored.status_code == 200
    assert get_restored.json()["assessment"]["valid"] is True

    # CREATE expectation assessment to test different type
    expectation_payload = {
        "assessment": {
            "assessment_name": "response_time_check",
            "expectation": {"value": {"threshold_ms": 1000, "actual_ms": 750, "passed": True}},
            "source": {"source_type": "CODE", "source_id": "automated_test"},
        }
    }

    expectation_response = requests.post(
        f"{mlflow_client.tracking_uri}/api/3.0/mlflow/traces/{trace_info.request_id}/assessments",
        json=expectation_payload,
    )
    assert expectation_response.status_code == 200
    expectation = expectation_response.json()["assessment"]
    expectation_id = expectation["assessment_id"]

    # Verify expectation was created correctly
    # Expectation values come back JSON-serialized under `serialized_value`, so
    # decode before comparing.
    expectation_value = json.loads(expectation["expectation"]["serialized_value"]["value"])
    assert expectation_value["passed"] is True
    assert expectation_value["threshold_ms"] == 1000
    assert expectation_value["actual_ms"] == 750
    assert expectation["source"]["source_type"] == "CODE"

    # Clean up - delete remaining assessments
    for aid in [assessment_id, expectation_id]:
        delete_resp = requests.delete(
            f"{mlflow_client.tracking_uri}/api/3.0/mlflow/traces/{trace_info.request_id}/assessments/{aid}"
        )
        assert delete_resp.status_code == 200
3577  
3578  
def test_graphql_nan_metric_handling(mlflow_client):
    """NaN metric values serialize as null over GraphQL; numeric values pass through."""
    experiment_id = mlflow_client.create_experiment("test_graphql_nan_metrics")
    run_id = mlflow_client.create_run(experiment_id).info.run_id

    # One ordinary metric and one NaN metric on the same run.
    mlflow_client.log_metric(run_id, key="normal_metric", value=123, timestamp=1, step=1)
    mlflow_client.log_metric(run_id, key="nan_metric", value=math.nan, timestamp=2, step=2)

    response = requests.post(
        f"{mlflow_client.tracking_uri}/graphql",
        json={
            "query": f"""
                query testQuery {{
                    mlflowGetRun(input: {{runId: "{run_id}"}}) {{
                        run {{
                            data {{
                                metrics {{
                                    key
                                    value
                                    timestamp
                                    step
                                }}
                            }}
                        }}
                    }}
                }}
            """,
            "operationName": "testQuery",
        },
        headers={"content-type": "application/json; charset=utf-8"},
    )

    assert response.status_code == 200
    payload = response.json()
    assert payload["errors"] is None

    metrics = payload["data"]["mlflowGetRun"]["run"]["data"]["metrics"]
    metrics_by_key = {m["key"]: m for m in metrics}

    # The ordinary metric keeps its numeric value (timestamp/step are strings).
    normal_metric = metrics_by_key.get("normal_metric")
    assert normal_metric is not None
    assert normal_metric["key"] == "normal_metric"
    assert normal_metric["value"] == 123
    assert normal_metric["timestamp"] == "1"
    assert normal_metric["step"] == "1"

    # The NaN metric is serialized with a null value.
    nan_metric = metrics_by_key.get("nan_metric")
    assert nan_metric is not None
    assert nan_metric["key"] == "nan_metric"
    assert nan_metric["value"] is None
    assert nan_metric["timestamp"] == "2"
    assert nan_metric["step"] == "2"
3640  
3641  
def test_create_and_get_evaluation_dataset(mlflow_client, store_type):
    """An evaluation dataset can be created with tags and fetched back intact."""
    if store_type == "file":
        pytest.skip("Evaluation datasets not supported for FileStore")

    experiment_id = mlflow_client.create_experiment("eval_dataset_test")

    created = mlflow_client.create_dataset(
        name="test_eval_dataset",
        experiment_id=experiment_id,
        tags={"environment": "test", "version": "1.0"},
    )

    # The creation response already carries name, experiment linkage, tags, and id.
    assert created.name == "test_eval_dataset"
    assert created.experiment_ids == [experiment_id]
    assert created.tags["environment"] == "test"
    assert created.tags["version"] == "1.0"
    assert created.dataset_id is not None

    # Fetching by id returns the same dataset.
    fetched = mlflow_client.get_dataset(created.dataset_id)
    assert fetched.name == created.name
    assert fetched.dataset_id == created.dataset_id
    assert fetched.tags == created.tags
3664  
3665  
def test_search_evaluation_datasets(mlflow_client, store_type):
    """`search_datasets` supports experiment scoping, tag filters, and ordering."""
    if store_type == "file":
        pytest.skip("Evaluation datasets not supported for FileStore")

    first_exp = mlflow_client.create_experiment("eval_search_exp1")
    second_exp = mlflow_client.create_experiment("eval_search_exp2")

    mlflow_client.create_dataset(
        name="search_dataset_1", experiment_id=first_exp, tags={"team": "ml", "status": "active"}
    )
    # This dataset is shared across both experiments.
    mlflow_client.create_dataset(
        name="search_dataset_2",
        experiment_id=[first_exp, second_exp],
        tags={"team": "data", "status": "active"},
    )
    mlflow_client.create_dataset(
        name="search_dataset_3", experiment_id=second_exp, tags={"team": "ml", "status": "archived"}
    )

    # Unscoped search sees at least the three datasets created above.
    assert len(mlflow_client.search_datasets()) >= 3

    # Scoping to the first experiment includes the shared dataset too.
    names_in_first = [d.name for d in mlflow_client.search_datasets(experiment_ids=first_exp)]
    assert "search_dataset_1" in names_in_first
    assert "search_dataset_2" in names_in_first

    # Tag filter matches exactly the two "ml" datasets.
    ml_names = [d.name for d in mlflow_client.search_datasets(filter_string="tags.team = 'ml'")]
    assert "search_dataset_1" in ml_names
    assert "search_dataset_3" in ml_names
    assert "search_dataset_2" not in ml_names

    # order_by returns datasets alphabetically.
    ordered_names = [d.name for d in mlflow_client.search_datasets(order_by=["name ASC"])]
    assert ordered_names == sorted(ordered_names)
3704  
3705  
def test_evaluation_dataset_tag_operations(mlflow_client, store_type):
    """Dataset tags can be merged, overwritten, and individually deleted."""
    if store_type == "file":
        pytest.skip("Evaluation datasets not supported for FileStore")

    experiment_id = mlflow_client.create_experiment("eval_tags_test")

    dataset = mlflow_client.create_dataset(
        name="tag_test_dataset",
        experiment_id=experiment_id,
        tags={"initial": "value", "env": "dev"},
    )

    # Merge semantics: existing keys updated, unknown keys added, others preserved.
    mlflow_client.set_dataset_tags(dataset.dataset_id, {"env": "staging", "new_tag": "new_value"})
    merged_tags = mlflow_client.get_dataset(dataset.dataset_id).tags
    assert merged_tags["initial"] == "value"
    assert merged_tags["env"] == "staging"
    assert merged_tags["new_tag"] == "new_value"

    # Deleting one tag leaves the rest untouched.
    mlflow_client.delete_dataset_tag(dataset.dataset_id, "new_tag")
    remaining_tags = mlflow_client.get_dataset(dataset.dataset_id).tags
    assert "new_tag" not in remaining_tags
    assert remaining_tags["env"] == "staging"
3730  
3731  
def test_evaluation_dataset_delete(mlflow_client, store_type):
    """A deleted evaluation dataset can no longer be fetched."""
    if store_type == "file":
        pytest.skip("Evaluation datasets not supported for FileStore")

    experiment_id = mlflow_client.create_experiment("eval_delete_test")
    dataset = mlflow_client.create_dataset(
        name="delete_test_dataset", experiment_id=experiment_id, tags={"to_delete": "yes"}
    )

    # Sanity check: the dataset exists before deletion.
    assert mlflow_client.get_dataset(dataset.dataset_id).name == "delete_test_dataset"

    mlflow_client.delete_dataset(dataset.dataset_id)

    with pytest.raises(MlflowException, match="not found"):
        mlflow_client.get_dataset(dataset.dataset_id)
3749  
3750  
def test_evaluation_dataset_upsert_records(mlflow_client, store_type):
    """Exercise the dataset record upsert REST handler end-to-end over HTTP."""
    if store_type == "file":
        pytest.skip("Evaluation datasets not supported for FileStore")

    exp_id = mlflow_client.create_experiment("upsert_records_test")

    dataset = mlflow_client.create_dataset(
        name="test_upsert_dataset",
        experiment_id=exp_id,
        tags={"test": "upsert"},
    )

    # NB: MlflowClient doesn't have upsert_dataset_records method - merge_records() calls
    # the store directly. We make HTTP requests here to test the REST API handler end-to-end.
    def upsert(dataset_id, records):
        return requests.post(
            f"{mlflow_client.tracking_uri}/api/3.0/mlflow/datasets/{dataset_id}/records",
            json={"records": json.dumps(records)},
        )

    initial_records = [
        {
            "inputs": {"question": "What is MLflow?"},
            "expectations": {"answer": "MLflow is an ML platform"},
            "tags": {"difficulty": "easy"},
        },
        {
            "inputs": {"question": "What is Python?"},
            "expectations": {"answer": "Python is a programming language"},
            "tags": {"difficulty": "easy"},
        },
    ]

    # First upsert: both records are new, so both count as inserts.
    response = upsert(dataset.dataset_id, initial_records)
    assert response.status_code == 200
    body = response.json()
    assert body["inserted_count"] == 2
    assert body["updated_count"] == 0

    update_records = [
        {
            "inputs": {"question": "What is MLflow?"},
            "expectations": {"answer": "MLflow is an open-source ML platform"},
            "tags": {"difficulty": "easy", "updated": "true"},
        },
        {
            "inputs": {"question": "What is Docker?"},
            "expectations": {"answer": "Docker is a containerization platform"},
            "tags": {"difficulty": "medium"},
        },
    ]

    # Second upsert: one record matches existing inputs (update), one is new (insert).
    response = upsert(dataset.dataset_id, update_records)
    assert response.status_code == 200
    body = response.json()
    assert body["inserted_count"] == 1
    assert body["updated_count"] == 1

    # A nonexistent dataset id must be rejected by the handler.
    response = upsert("invalid-id", initial_records)
    assert response.status_code != 200
3814  
3815  
def test_add_dataset_to_experiments_rest_tracking(mlflow_client, store_type):
    """Adding experiments to a dataset shows up in the returned entity and on re-fetch."""
    if store_type == "file":
        pytest.skip("File store doesn't support dataset operations")
    exp1, exp2, exp3 = (
        mlflow_client.create_experiment(f"dataset_exp_{i}") for i in (1, 2, 3)
    )

    dataset = create_dataset(
        name="test_multi_exp_dataset",
        experiment_id=[exp1],
        tags={"test": "multi_exp"},
    )

    # Initially associated only with the experiment it was created in.
    assert len(dataset.experiment_ids) == 1
    assert exp1 in dataset.experiment_ids

    updated = add_dataset_to_experiments(
        dataset_id=dataset.dataset_id,
        experiment_ids=[exp2, exp3],
    )

    # Both the returned entity and a fresh GET must list all three experiments.
    for entity in (updated, mlflow_client.get_dataset(dataset.dataset_id)):
        assert len(entity.experiment_ids) == 3
        for exp in (exp1, exp2, exp3):
            assert exp in entity.experiment_ids
3847  
3848  
def test_remove_dataset_from_experiments_rest_tracking(mlflow_client, store_type):
    """Removing associations shrinks the dataset's experiment list, down to empty."""
    if store_type == "file":
        pytest.skip("File store doesn't support dataset operations")
    exp1, exp2, exp3 = (
        mlflow_client.create_experiment(f"dataset_remove_exp_{i}") for i in (1, 2, 3)
    )

    dataset = create_dataset(
        name="test_remove_exp_dataset",
        experiment_id=[exp1, exp2, exp3],
        tags={"test": "remove_exp"},
    )
    assert len(dataset.experiment_ids) == 3

    # Remove only the middle experiment.
    after_one = remove_dataset_from_experiments(
        dataset_id=dataset.dataset_id,
        experiment_ids=[exp2],
    )
    assert len(after_one.experiment_ids) == 2
    assert exp2 not in after_one.experiment_ids
    assert exp1 in after_one.experiment_ids
    assert exp3 in after_one.experiment_ids

    # The removal is persisted server-side.
    assert len(mlflow_client.get_dataset(dataset.dataset_id).experiment_ids) == 2

    # Removing the remaining experiments leaves the dataset with none.
    after_all = remove_dataset_from_experiments(
        dataset_id=dataset.dataset_id,
        experiment_ids=[exp1, exp3],
    )
    assert len(after_all.experiment_ids) == 0
    assert len(mlflow_client.get_dataset(dataset.dataset_id).experiment_ids) == 0
3886  
3887  
def test_add_multiple_experiments_at_once_rest_tracking(mlflow_client, store_type):
    """A single call can associate many experiments with a dataset at once."""
    if store_type == "file":
        pytest.skip("File store doesn't support dataset operations")
    experiment_ids = [mlflow_client.create_experiment(f"bulk_add_exp_{i}") for i in range(5)]

    dataset = create_dataset(
        name="test_bulk_add_dataset",
        experiment_id=experiment_ids[:1],
        tags={"test": "bulk_add"},
    )

    # Add the remaining four experiments in one bulk call.
    updated = add_dataset_to_experiments(
        dataset_id=dataset.dataset_id,
        experiment_ids=experiment_ids[1:],
    )

    assert len(updated.experiment_ids) == 5
    assert all(exp_id in updated.experiment_ids for exp_id in experiment_ids)
3907  
3908  
def test_dataset_experiment_association_error_cases_rest_tracking(mlflow_client, store_type):
    """Both association APIs reject a dataset id that does not exist."""
    if store_type == "file":
        pytest.skip("File store doesn't support dataset operations")
    exp1 = mlflow_client.create_experiment("error_test_exp")

    missing_dataset_id = "d-nonexistent1234567890abcdef1234"
    # Add and remove share the same not-found behavior for unknown datasets.
    for api_call in (add_dataset_to_experiments, remove_dataset_from_experiments):
        with pytest.raises(MlflowException, match="not found"):
            api_call(dataset_id=missing_dataset_id, experiment_ids=[exp1])
3925  
3926  
def test_idempotent_add_experiments_rest_tracking(mlflow_client, store_type):
    """Re-adding an already-associated experiment is a no-op (no duplicates)."""
    if store_type == "file":
        pytest.skip("File store doesn't support dataset operations")
    exp1 = mlflow_client.create_experiment("idempotent_test_exp_1")
    exp2 = mlflow_client.create_experiment("idempotent_test_exp_2")

    dataset = create_dataset(
        name="test_idempotent_dataset",
        experiment_id=[exp1, exp2],
        tags={"test": "idempotent"},
    )
    assert len(dataset.experiment_ids) == 2

    # exp1 is already associated; adding it again must not duplicate it.
    updated = add_dataset_to_experiments(
        dataset_id=dataset.dataset_id,
        experiment_ids=[exp1],
    )

    assert len(updated.experiment_ids) == 2
    assert exp1 in updated.experiment_ids
    assert exp2 in updated.experiment_ids
3949  
3950  
def test_idempotent_remove_experiments_rest_tracking(mlflow_client, store_type):
    """Removing an experiment that was never associated is a no-op."""
    if store_type == "file":
        pytest.skip("File store doesn't support dataset operations")
    exp1 = mlflow_client.create_experiment("remove_idempotent_test_exp_1")
    exp2 = mlflow_client.create_experiment("remove_idempotent_test_exp_2")

    dataset = create_dataset(
        name="test_remove_idempotent_dataset",
        experiment_id=[exp1],
        tags={"test": "remove_idempotent"},
    )
    assert len(dataset.experiment_ids) == 1

    # exp2 was never associated; removing it must leave the dataset untouched.
    updated = remove_dataset_from_experiments(
        dataset_id=dataset.dataset_id,
        experiment_ids=[exp2],
    )

    assert len(updated.experiment_ids) == 1
    assert exp1 in updated.experiment_ids
3972  
3973  
def test_client_api_add_remove_experiments_rest_tracking(mlflow_client, store_type):
    """The MlflowClient wrappers for add/remove experiment associations work end-to-end."""
    if store_type == "file":
        pytest.skip("File store doesn't support dataset operations")
    exp1, exp2, exp3 = (
        mlflow_client.create_experiment(f"client_api_exp_{i}") for i in (1, 2, 3)
    )

    dataset = mlflow_client.create_dataset(
        name="test_client_api_dataset",
        experiment_id=[exp1],
        tags={"test": "client_api"},
    )

    after_add = mlflow_client.add_dataset_to_experiments(
        dataset_id=dataset.dataset_id,
        experiment_ids=[exp2, exp3],
    )
    assert len(after_add.experiment_ids) == 3

    after_remove = mlflow_client.remove_dataset_from_experiments(
        dataset_id=dataset.dataset_id,
        experiment_ids=[exp2],
    )
    # Only the removed experiment disappears; the other two remain.
    assert len(after_remove.experiment_ids) == 2
    assert exp1 in after_remove.experiment_ids
    assert exp2 not in after_remove.experiment_ids
    assert exp3 in after_remove.experiment_ids
4003  
4004  
def test_scorer_CRUD(mlflow_client, store_type):
    """End-to-end register/list/get/delete lifecycle for scorers via the RestStore.

    NOTE: ``create_experiment`` already returns the experiment id as a string,
    so the id is passed as-is throughout. (The original test mixed
    ``experiment_id`` and ``str(experiment_id)`` interchangeably; they are
    equivalent, and this version uses the raw id consistently.)
    """
    if store_type == "file":
        pytest.skip("File store doesn't support scorer CRUD operations")
    experiment_id = mlflow_client.create_experiment("test_scorer_api_experiment")

    # Talk to the RestStore object directly: MlflowClient has no scorer API.
    store = mlflow_client._tracking_client.store

    # Register the first version of the scorer.
    scorer_data = {"name": "test_scorer", "original_func_name": "test_func"}
    serialized_scorer = json.dumps(scorer_data)

    version = store.register_scorer(experiment_id, "test_scorer", serialized_scorer)
    assert version.scorer_version == 1

    # list_scorers returns one entry per scorer name (its latest version).
    scorers = store.list_scorers(experiment_id)
    assert len(scorers) == 1
    assert scorers[0].scorer_name == "test_scorer"
    assert scorers[0].scorer_version == 1

    # list_scorer_versions returns every version of the named scorer.
    versions = store.list_scorer_versions(experiment_id, "test_scorer")
    assert len(versions) == 1
    assert versions[0].scorer_name == "test_scorer"
    assert versions[0].scorer_version == 1

    # get_scorer without a version returns the latest.
    scorer = store.get_scorer(experiment_id, "test_scorer")
    assert scorer.scorer_name == "test_scorer"
    assert scorer.scorer_version == 1

    # get_scorer with an explicit version returns that version.
    scorer_v1 = store.get_scorer(experiment_id, "test_scorer", version=1)
    assert scorer_v1.scorer_name == "test_scorer"
    assert scorer_v1.scorer_version == 1

    # Registering again under the same name bumps the version number.
    scorer_data_v2 = {
        "name": "test_scorer_v2",
        "original_func_name": "test_func_v2",
    }
    serialized_scorer_v2 = json.dumps(scorer_data_v2)

    version_v2 = store.register_scorer(experiment_id, "test_scorer", serialized_scorer_v2)
    assert version_v2.scorer_version == 2

    # list_scorers now reports only the latest (v2) entry...
    scorers_after_v2 = store.list_scorers(experiment_id)
    assert len(scorers_after_v2) == 1
    assert scorers_after_v2[0].scorer_version == 2

    # ...while both versions remain listable via list_scorer_versions.
    versions_after_v2 = store.list_scorer_versions(experiment_id, "test_scorer")
    assert len(versions_after_v2) == 2

    # Deleting a specific version leaves the other versions intact.
    store.delete_scorer(experiment_id, "test_scorer", version=1)

    versions_after_delete = store.list_scorer_versions(experiment_id, "test_scorer")
    assert len(versions_after_delete) == 1
    assert versions_after_delete[0].scorer_version == 2

    # Deleting without a version removes every remaining version.
    store.delete_scorer(experiment_id, "test_scorer")

    scorers_after_delete_all = store.list_scorers(experiment_id)
    assert len(scorers_after_delete_all) == 0

    # Clean up
    mlflow_client.delete_experiment(experiment_id)
4078  
4079  
@pytest.mark.parametrize(
    "filter_string",
    [
        "status = 'OK'",
        None,
    ],
)
def test_online_scoring_config(mlflow_client_with_secrets, filter_string):
    """
    Smoke test for online scoring configuration REST APIs.
    Tests upsert_online_scoring_config and get_online_scoring_configs with both
    string and None filter values (None is sent by UI when filter field is blank).
    """
    experiment_id = mlflow_client_with_secrets.create_experiment("test_online_scoring")
    store = mlflow_client_with_secrets._tracking_client.store

    # Gateway setup: a credential secret, a model definition that uses it, and
    # an endpoint — the scorer below references the endpoint via a gateway URI.
    secret = store.create_gateway_secret(
        secret_name="test-secret", secret_value={"api_key": "sk-test"}, provider="openai"
    )
    model_def = store.create_gateway_model_definition(
        name="test-model", secret_id=secret.secret_id, provider="openai", model_name="gpt-4"
    )
    endpoint = store.create_gateway_endpoint(
        name="test-endpoint",
        model_configs=[
            GatewayEndpointModelConfig(
                model_definition_id=model_def.model_definition_id,
                linkage_type=GatewayModelLinkageType.PRIMARY,
            )
        ],
    )

    # Register a scorer whose judge model points at the gateway endpoint.
    scorer_data = {"instructions_judge_pydantic_data": {"model": f"gateway:/{endpoint.name}"}}
    serialized_scorer = json.dumps(scorer_data)
    scorer_version = store.register_scorer(experiment_id, "my_scorer", serialized_scorer)
    scorer_id = scorer_version.scorer_id

    # First upsert creates the config; the response echoes the stored values
    # (including filter_string, which may be None per parametrization).
    config = store.upsert_online_scoring_config(
        experiment_id=experiment_id,
        scorer_name="my_scorer",
        sample_rate=0.5,
        filter_string=filter_string,
    )
    assert config.scorer_id == scorer_id
    assert config.sample_rate == 0.5
    assert config.filter_string == filter_string
    assert config.experiment_id == experiment_id

    # The config is retrievable by scorer id.
    configs = store.get_online_scoring_configs([scorer_id])
    assert len(configs) == 1
    assert configs[0].scorer_id == scorer_id
    assert configs[0].sample_rate == 0.5
    assert configs[0].filter_string == filter_string

    # Update with different filter string to test update functionality
    updated_filter = "status = 'COMPLETED'"
    updated_config = store.upsert_online_scoring_config(
        experiment_id=experiment_id,
        scorer_name="my_scorer",
        sample_rate=0.8,
        filter_string=updated_filter,
    )
    assert updated_config.scorer_id == scorer_id
    assert updated_config.sample_rate == 0.8
    assert updated_config.filter_string == updated_filter

    # Second upsert updated in place: still exactly one config for the scorer.
    configs_after_update = store.get_online_scoring_configs([scorer_id])
    assert len(configs_after_update) == 1
    assert configs_after_update[0].sample_rate == 0.8
    assert configs_after_update[0].filter_string == updated_filter
4150  
4151  
@pytest.mark.parametrize("use_async", [False, True])
@pytest.mark.asyncio
async def test_rest_store_logs_spans_via_otel_endpoint(mlflow_client, store_type, use_async):
    """
    End-to-end test that verifies RestStore can log spans to a running server via OTLP endpoint.

    This test:
    1. Creates spans using MLflow's span entities
    2. Uses RestStore.log_spans or log_spans_async to send them via OTLP protocol
    3. Verifies the spans were stored and can be retrieved
    """
    if store_type == "file":
        pytest.skip("FileStore does not support OTLP span logging")

    experiment_id = mlflow_client.create_experiment(f"rest_store_otel_test_{use_async}")
    # Start a trace so the hand-built span below has a real trace id to attach to.
    root_span = mlflow_client.start_trace(
        f"rest_store_otel_trace_{use_async}", experiment_id=experiment_id
    )
    # Build a raw OTel ReadableSpan tied to that trace. build_otel_context is a
    # helper defined elsewhere in this file. The start/end times are raw ints —
    # presumably nanosecond epoch timestamps per OTel convention (not shown here).
    otel_span = OTelReadableSpan(
        name=f"test-rest-store-span-{use_async}",
        context=build_otel_context(
            trace_id=int(root_span.trace_id[3:], 16),  # Remove 'tr-' prefix and convert to int
            span_id=0x1234567890ABCDEF,
        ),
        parent=None,
        start_time=1000000000,
        end_time=2000000000,
        attributes={
            # REQUEST_ID links the OTel span back to the MLflow trace.
            SpanAttributeKey.REQUEST_ID: root_span.trace_id,
            "test.attribute": json.dumps(f"test-value-{use_async}"),  # JSON-encoded string value
        },
        resource=None,
    )
    # Wrap the OTel span in MLflow's Span entity for the store API.
    mlflow_span_to_log = Span(otel_span)
    # Call either sync or async version based on parametrization
    if use_async:
        # Use await to execute the async method
        result_spans = await mlflow_client._tracking_client.store.log_spans_async(
            location=experiment_id, spans=[mlflow_span_to_log]
        )
    else:
        result_spans = mlflow_client._tracking_client.store.log_spans(
            location=experiment_id, spans=[mlflow_span_to_log]
        )

    # Verify the spans were returned (indicates successful logging)
    assert len(result_spans) == 1
    assert result_spans[0].name == f"test-rest-store-span-{use_async}"
4200  
4201  
4202  # =============================================================================
4203  # Secrets and Endpoints E2E Tests
4204  # =============================================================================
4205  
4206  
def test_create_and_get_secret(mlflow_client_with_secrets):
    """A created gateway secret can be fetched back by id with matching metadata."""
    store = mlflow_client_with_secrets._tracking_client.store

    created = store.create_gateway_secret(
        secret_name="test-api-key",
        secret_value={"api_key": "sk-test-12345"},
        provider="openai",
    )

    assert created.secret_id is not None
    assert (created.secret_name, created.provider) == ("test-api-key", "openai")

    # Fetch by id and confirm the metadata round-trips.
    fetched = store.get_secret_info(created.secret_id)
    assert fetched.secret_id == created.secret_id
    assert (fetched.secret_name, fetched.provider) == ("test-api-key", "openai")
4224  
4225  
def test_update_secret(mlflow_client_with_secrets):
    """Rotating a secret's value keeps its id and name stable."""
    store = mlflow_client_with_secrets._tracking_client.store

    original = store.create_gateway_secret(
        secret_name="test-key",
        secret_value={"api_key": "initial-value"},
        provider="anthropic",
    )

    rotated = store.update_gateway_secret(
        secret_id=original.secret_id,
        secret_value={"api_key": "updated-value"},
    )

    assert rotated.secret_id == original.secret_id
    assert rotated.secret_name == "test-key"
4242  
4243  
def test_list_secret_infos(mlflow_client_with_secrets):
    """list_secret_infos returns all secrets and supports provider filtering."""
    store = mlflow_client_with_secrets._tracking_client.store

    openai_secret = store.create_gateway_secret(
        secret_name="openai-key",
        secret_value={"api_key": "sk-openai"},
        provider="openai",
    )
    store.create_gateway_secret(
        secret_name="anthropic-key",
        secret_value={"api_key": "sk-ant"},
        provider="anthropic",
    )

    # Unfiltered listing includes at least the two secrets created above.
    assert len(store.list_secret_infos()) >= 2

    # Provider filter narrows the listing but still contains our secret.
    filtered = store.list_secret_infos(provider="openai")
    assert len(filtered) >= 1
    assert openai_secret.secret_id in {s.secret_id for s in filtered}
4264  
4265  
def test_delete_secret(mlflow_client_with_secrets):
    """A deleted secret disappears from the listing."""
    store = mlflow_client_with_secrets._tracking_client.store

    doomed = store.create_gateway_secret(
        secret_name="temp-key",
        secret_value={"api_key": "temp-value"},
    )

    store.delete_gateway_secret(doomed.secret_id)

    remaining_ids = {s.secret_id for s in store.list_secret_infos()}
    assert doomed.secret_id not in remaining_ids
4278  
4279  
def test_create_secret_with_dict_value(mlflow_client_with_secrets):
    """Multi-field secret values come back masked field-by-field."""
    store = mlflow_client_with_secrets._tracking_client.store

    secret = store.create_gateway_secret(
        secret_name="aws-creds",
        secret_value={"aws_access_key_id": "AKIATEST1234", "aws_secret_access_key": "secret123abc"},
        provider="bedrock",
    )

    assert (secret.secret_name, secret.provider) == ("aws-creds", "bedrock")
    assert secret.secret_id is not None
    # Each field is masked independently to its leading/trailing characters.
    assert isinstance(secret.masked_values, dict)
    assert secret.masked_values == {
        "aws_access_key_id": "AKI...1234",
        "aws_secret_access_key": "sec...3abc",
    }
4297  
4298  
def test_update_secret_with_dict_value(mlflow_client_with_secrets):
    """Updating a secret may change its value schema; masking follows the new keys."""
    store = mlflow_client_with_secrets._tracking_client.store

    secret = store.create_gateway_secret(
        secret_name="aws-creds-update",
        secret_value={"api_key": "initial-value-1234"},
        provider="bedrock",
    )

    assert isinstance(secret.masked_values, dict)
    assert secret.masked_values == {"api_key": "ini...1234"}

    # Replace the single-field value with a two-field AWS credential pair.
    new_value = {
        "aws_access_key_id": "NEWKEY123456",
        "aws_secret_access_key": "newsecret1234",
    }
    updated = store.update_gateway_secret(secret_id=secret.secret_id, secret_value=new_value)

    assert updated.secret_id == secret.secret_id
    assert updated.secret_name == "aws-creds-update"
    assert isinstance(updated.masked_values, dict)
    assert updated.masked_values == {
        "aws_access_key_id": "NEW...3456",
        "aws_secret_access_key": "new...1234",
    }
4326  
4327  
def test_create_and_update_compound_secret_via_rest(mlflow_client_with_secrets):
    """Compound (multi-field) secrets round-trip through create/get/update."""
    store = mlflow_client_with_secrets._tracking_client.store

    created = store.create_gateway_secret(
        secret_name="bedrock-aws-creds",
        secret_value={
            "aws_access_key_id": "AKIAORIGINAL1234",
            "aws_secret_access_key": "original-secret-key-1234",
        },
        provider="bedrock",
        auth_config={"auth_mode": "access_keys", "aws_region_name": "us-east-1"},
    )

    assert (created.secret_name, created.provider) == ("bedrock-aws-creds", "bedrock")
    assert isinstance(created.masked_values, dict)
    assert created.masked_values == {
        "aws_access_key_id": "AKI...1234",
        "aws_secret_access_key": "ori...1234",
    }

    # get_secret_info returns the same masked representation as create.
    fetched = store.get_secret_info(secret_id=created.secret_id)
    assert fetched.secret_id == created.secret_id
    assert isinstance(fetched.masked_values, dict)
    assert fetched.masked_values == created.masked_values

    # Rotate both credential fields.
    rotated = store.update_gateway_secret(
        secret_id=created.secret_id,
        secret_value={
            "aws_access_key_id": "AKIAROTATED5678",
            "aws_secret_access_key": "rotated-secret-key-5678",
        },
    )

    assert rotated.secret_id == created.secret_id
    assert rotated.last_updated_at > created.created_at
    assert isinstance(rotated.masked_values, dict)
    assert rotated.masked_values == {
        "aws_access_key_id": "AKI...5678",
        "aws_secret_access_key": "rot...5678",
    }
4369  
4370  
def test_create_and_get_endpoint(mlflow_client_with_secrets):
    """Create a gateway endpoint with primary and fallback models, then fetch it back.

    Checks that model mappings, routing strategy, and fallback configuration all
    survive the create/get round trip through the REST store.
    """
    store = mlflow_client_with_secrets._tracking_client.store

    # Provider credentials: one secret per model definition.
    secret = store.create_gateway_secret(
        secret_name="test-api-key",
        secret_value={"api_key": "sk-test-12345"},
        provider="openai",
    )
    secret2 = store.create_gateway_secret(
        secret_name="test-api-key-fallback",
        secret_value={"api_key": "sk-test-67890"},
        provider="anthropic",
    )

    # Model definitions bind a provider model name to a credential secret.
    model_def = store.create_gateway_model_definition(
        name="test-model-def",
        secret_id=secret.secret_id,
        provider="openai",
        model_name="gpt-4",
    )
    model_def_fallback = store.create_gateway_model_definition(
        name="test-model-def-fallback",
        secret_id=secret2.secret_id,
        provider="anthropic",
        model_name="claude-3-5-sonnet",
    )

    # Wire both models into one endpoint: one PRIMARY, one FALLBACK
    # (fallback_order=0 — presumably the first fallback tried; confirm in entity docs).
    endpoint = store.create_gateway_endpoint(
        name="test-endpoint",
        model_configs=[
            GatewayEndpointModelConfig(
                model_definition_id=model_def.model_definition_id,
                linkage_type=GatewayModelLinkageType.PRIMARY,
                weight=1.0,
            ),
            GatewayEndpointModelConfig(
                model_definition_id=model_def_fallback.model_definition_id,
                linkage_type=GatewayModelLinkageType.FALLBACK,
                weight=1.0,
                fallback_order=0,
            ),
        ],
        routing_strategy=RoutingStrategy.REQUEST_BASED_TRAFFIC_SPLIT,
        fallback_config=FallbackConfig(
            strategy=FallbackStrategy.SEQUENTIAL,
            max_attempts=2,
        ),
    )

    # The create response reflects everything that was configured.
    assert endpoint.name == "test-endpoint"
    assert endpoint.endpoint_id is not None
    assert len(endpoint.model_mappings) == 2
    assert endpoint.model_mappings[0].model_definition.model_name == "gpt-4"
    assert endpoint.routing_strategy == RoutingStrategy.REQUEST_BASED_TRAFFIC_SPLIT
    assert endpoint.fallback_config is not None
    assert endpoint.fallback_config.strategy == FallbackStrategy.SEQUENTIAL
    assert endpoint.fallback_config.max_attempts == 2

    # A fresh GET returns the same configuration.
    fetched = store.get_gateway_endpoint(endpoint.endpoint_id)
    assert fetched.name == "test-endpoint"
    assert fetched.endpoint_id == endpoint.endpoint_id
    assert len(fetched.model_mappings) == 2
    assert fetched.routing_strategy == RoutingStrategy.REQUEST_BASED_TRAFFIC_SPLIT
    assert fetched.fallback_config is not None
    assert fetched.fallback_config.strategy == FallbackStrategy.SEQUENTIAL
    assert fetched.fallback_config.max_attempts == 2
4437  
4438  
def test_create_endpoint_with_usage_tracking(mlflow_client_with_secrets):
    """Creating an endpoint with usage_tracking=True auto-creates a gateway experiment."""
    store = mlflow_client_with_secrets._tracking_client.store

    secret = store.create_gateway_secret(
        secret_name="usage-tracking-test-key",
        secret_value={"api_key": "sk-usage-tracking-test"},
        provider="openai",
    )
    model_def = store.create_gateway_model_definition(
        name="usage-tracking-model-def",
        secret_id=secret.secret_id,
        provider="openai",
        model_name="gpt-4",
    )

    primary_config = GatewayEndpointModelConfig(
        model_definition_id=model_def.model_definition_id,
        linkage_type=GatewayModelLinkageType.PRIMARY,
        weight=1.0,
    )
    endpoint = store.create_gateway_endpoint(
        name="usage-tracking-endpoint",
        model_configs=[primary_config],
        usage_tracking=True,
    )

    assert endpoint.usage_tracking is True

    # Experiment is automatically created with usage tracking enabled
    experiment = mlflow_client_with_secrets.get_experiment(endpoint.experiment_id)
    assert experiment.name == "gateway/usage-tracking-endpoint"
4473  
4474  
def test_update_endpoint(mlflow_client_with_secrets):
    """Update a gateway endpoint's name, models, routing, and fallback config in one call.

    Starts from a minimal single-model endpoint and verifies the update response
    reflects the new name, the added fallback model, and the new routing/fallback
    settings while the endpoint id stays stable.
    """
    store = mlflow_client_with_secrets._tracking_client.store

    # Provider credentials: one secret per model definition.
    secret = store.create_gateway_secret(
        secret_name="test-api-key-2",
        secret_value={"api_key": "sk-test-67890"},
        provider="anthropic",
    )
    secret2 = store.create_gateway_secret(
        secret_name="test-api-key-2-fallback",
        secret_value={"api_key": "sk-test-99999"},
        provider="openai",
    )

    # Model definitions bind a provider model name to a credential secret.
    model_def = store.create_gateway_model_definition(
        name="test-model-def-2",
        secret_id=secret.secret_id,
        provider="anthropic",
        model_name="claude-3-5-sonnet",
    )
    model_def_fallback = store.create_gateway_model_definition(
        name="test-model-def-2-fallback",
        secret_id=secret2.secret_id,
        provider="openai",
        model_name="gpt-4",
    )

    # Baseline: a minimal endpoint with a single primary model and defaults.
    endpoint = store.create_gateway_endpoint(
        name="initial-name",
        model_configs=[
            GatewayEndpointModelConfig(
                model_definition_id=model_def.model_definition_id,
                linkage_type=GatewayModelLinkageType.PRIMARY,
                weight=1.0,
            ),
        ],
    )

    # The update replaces the model config list wholesale and sets name,
    # routing strategy, and fallback behavior at the same time.
    updated = store.update_gateway_endpoint(
        endpoint_id=endpoint.endpoint_id,
        name="updated-name",
        model_configs=[
            GatewayEndpointModelConfig(
                model_definition_id=model_def.model_definition_id,
                linkage_type=GatewayModelLinkageType.PRIMARY,
                weight=1.0,
            ),
            GatewayEndpointModelConfig(
                model_definition_id=model_def_fallback.model_definition_id,
                linkage_type=GatewayModelLinkageType.FALLBACK,
                weight=1.0,
                fallback_order=0,
            ),
        ],
        routing_strategy=RoutingStrategy.REQUEST_BASED_TRAFFIC_SPLIT,
        fallback_config=FallbackConfig(
            strategy=FallbackStrategy.SEQUENTIAL,
            max_attempts=3,
        ),
    )

    # The id is stable; everything else reflects the update.
    assert updated.endpoint_id == endpoint.endpoint_id
    assert updated.name == "updated-name"
    assert updated.routing_strategy == RoutingStrategy.REQUEST_BASED_TRAFFIC_SPLIT
    assert updated.fallback_config is not None
    assert updated.fallback_config.strategy == FallbackStrategy.SEQUENTIAL
    assert updated.fallback_config.max_attempts == 3
    assert len(updated.model_mappings) == 2
4543  
4544  
def test_list_endpoints(mlflow_client_with_secrets):
    """list_gateway_endpoints returns all endpoints, preserving each endpoint's
    routing strategy and fallback config (or None when they were never set).
    """
    store = mlflow_client_with_secrets._tracking_client.store

    secret1 = store.create_gateway_secret(
        secret_name="test-api-key-3",
        secret_value={"api_key": "sk-test-11111"},
        provider="openai",
    )
    secret2 = store.create_gateway_secret(
        secret_name="test-api-key-4",
        secret_value={"api_key": "sk-test-22222"},
        provider="openai",
    )
    secret3 = store.create_gateway_secret(
        secret_name="test-api-key-fallback-3",
        secret_value={"api_key": "sk-test-44444"},
        provider="anthropic",
    )

    model_def1 = store.create_gateway_model_definition(
        name="test-model-def-3",
        secret_id=secret1.secret_id,
        provider="openai",
        model_name="gpt-4",
    )
    model_def2 = store.create_gateway_model_definition(
        name="test-model-def-4",
        secret_id=secret2.secret_id,
        provider="openai",
        model_name="gpt-3.5-turbo",
    )
    model_def3 = store.create_gateway_model_definition(
        name="test-model-def-fallback-3",
        secret_id=secret3.secret_id,
        provider="anthropic",
        model_name="claude-3-5-sonnet",
    )

    # Create endpoint without fallback
    endpoint1 = store.create_gateway_endpoint(
        name="endpoint-1",
        model_configs=[
            GatewayEndpointModelConfig(
                model_definition_id=model_def1.model_definition_id,
                linkage_type=GatewayModelLinkageType.PRIMARY,
                weight=1.0,
            ),
        ],
    )
    # Create endpoint with fallback
    endpoint2 = store.create_gateway_endpoint(
        name="endpoint-2",
        model_configs=[
            GatewayEndpointModelConfig(
                model_definition_id=model_def2.model_definition_id,
                linkage_type=GatewayModelLinkageType.PRIMARY,
                weight=1.0,
            ),
            GatewayEndpointModelConfig(
                model_definition_id=model_def3.model_definition_id,
                linkage_type=GatewayModelLinkageType.FALLBACK,
                weight=1.0,
                fallback_order=0,
            ),
        ],
        routing_strategy=RoutingStrategy.REQUEST_BASED_TRAFFIC_SPLIT,
        fallback_config=FallbackConfig(
            strategy=FallbackStrategy.SEQUENTIAL,
            max_attempts=2,
        ),
    )

    # >= 2 (not == 2) because other tests sharing the fixture may have
    # created endpoints of their own.
    all_endpoints = store.list_gateway_endpoints()
    assert len(all_endpoints) >= 2
    endpoint_ids = {e.endpoint_id for e in all_endpoints}
    assert endpoint1.endpoint_id in endpoint_ids
    assert endpoint2.endpoint_id in endpoint_ids

    # Find and verify endpoints
    found_ep1 = next(e for e in all_endpoints if e.endpoint_id == endpoint1.endpoint_id)
    found_ep2 = next(e for e in all_endpoints if e.endpoint_id == endpoint2.endpoint_id)

    assert found_ep1.routing_strategy is None
    assert found_ep1.fallback_config is None

    assert found_ep2.routing_strategy == RoutingStrategy.REQUEST_BASED_TRAFFIC_SPLIT
    assert found_ep2.fallback_config is not None
    assert found_ep2.fallback_config.strategy == FallbackStrategy.SEQUENTIAL
    assert found_ep2.fallback_config.max_attempts == 2
4634  
4635  
def test_delete_endpoint(mlflow_client_with_secrets):
    """Deleting a gateway endpoint removes it from the endpoint listing."""
    gateway_store = mlflow_client_with_secrets._tracking_client.store

    api_secret = gateway_store.create_gateway_secret(
        secret_name="test-api-key-5",
        secret_value={"api_key": "sk-test-33333"},
        provider="openai",
    )
    definition = gateway_store.create_gateway_model_definition(
        name="test-model-def-5",
        secret_id=api_secret.secret_id,
        provider="openai",
        model_name="gpt-4",
    )

    primary_config = GatewayEndpointModelConfig(
        model_definition_id=definition.model_definition_id,
        linkage_type=GatewayModelLinkageType.PRIMARY,
        weight=1.0,
    )
    created = gateway_store.create_gateway_endpoint(
        name="temp-endpoint",
        model_configs=[primary_config],
    )

    gateway_store.delete_gateway_endpoint(created.endpoint_id)

    # The deleted endpoint must no longer appear in the listing.
    remaining_ids = {e.endpoint_id for e in gateway_store.list_gateway_endpoints()}
    assert created.endpoint_id not in remaining_ids
4667  
4668  
def test_model_definitions(mlflow_client_with_secrets):
    """Full CRUD lifecycle of a gateway model definition:
    create -> get -> update -> list -> delete -> list (gone).
    """
    store = mlflow_client_with_secrets._tracking_client.store

    secret = store.create_gateway_secret(
        secret_name="model-secret",
        secret_value={"api_key": "sk-test"},
        provider="openai",
    )

    model_def = store.create_gateway_model_definition(
        name="test-model-def",
        secret_id=secret.secret_id,
        provider="openai",
        model_name="gpt-4",
    )

    # Create echoes back every supplied field plus a generated id.
    assert model_def.name == "test-model-def"
    assert model_def.secret_id == secret.secret_id
    assert model_def.provider == "openai"
    assert model_def.model_name == "gpt-4"
    assert model_def.model_definition_id is not None

    fetched = store.get_gateway_model_definition(model_def.model_definition_id)
    assert fetched.model_definition_id == model_def.model_definition_id
    assert fetched.name == "test-model-def"

    # Partial update: only model_name changes, id is preserved.
    updated = store.update_gateway_model_definition(
        model_definition_id=model_def.model_definition_id,
        model_name="gpt-4-turbo",
    )
    assert updated.model_definition_id == model_def.model_definition_id
    assert updated.model_name == "gpt-4-turbo"

    all_defs = store.list_gateway_model_definitions()
    assert any(d.model_definition_id == model_def.model_definition_id for d in all_defs)

    store.delete_gateway_model_definition(model_def.model_definition_id)

    # After deletion the definition no longer appears in the listing.
    all_defs_after = store.list_gateway_model_definitions()
    assert not any(d.model_definition_id == model_def.model_definition_id for d in all_defs_after)
4709  
4710  
def test_attach_detach_model_to_endpoint(mlflow_client_with_secrets):
    """A second model can be attached to an existing endpoint and later
    detached, with the endpoint's model_mappings reflecting each change.
    """
    store = mlflow_client_with_secrets._tracking_client.store

    secret = store.create_gateway_secret(
        secret_name="attach-detach-secret",
        secret_value={"api_key": "sk-test-attach"},
        provider="openai",
    )

    model_def1 = store.create_gateway_model_definition(
        name="attach-model-def-1",
        secret_id=secret.secret_id,
        provider="openai",
        model_name="gpt-4",
    )

    model_def2 = store.create_gateway_model_definition(
        name="attach-model-def-2",
        secret_id=secret.secret_id,
        provider="openai",
        model_name="gpt-3.5-turbo",
    )

    # Endpoint starts with exactly one attached model (gpt-4).
    endpoint = store.create_gateway_endpoint(
        name="attach-test-endpoint",
        model_configs=[
            GatewayEndpointModelConfig(
                model_definition_id=model_def1.model_definition_id,
                linkage_type=GatewayModelLinkageType.PRIMARY,
                weight=1.0,
            ),
        ],
    )

    assert len(endpoint.model_mappings) == 1
    assert endpoint.model_mappings[0].model_definition.model_name == "gpt-4"

    # Attach a second primary model; the returned mapping links both ids.
    mapping = store.attach_model_to_endpoint(
        endpoint_id=endpoint.endpoint_id,
        model_config=GatewayEndpointModelConfig(
            model_definition_id=model_def2.model_definition_id,
            linkage_type=GatewayModelLinkageType.PRIMARY,
            weight=1.0,
        ),
    )

    assert mapping.endpoint_id == endpoint.endpoint_id
    assert mapping.model_definition_id == model_def2.model_definition_id

    # Re-fetch to confirm the attachment is persisted server-side.
    fetched_endpoint = store.get_gateway_endpoint(endpoint.endpoint_id)
    assert len(fetched_endpoint.model_mappings) == 2

    store.detach_model_from_endpoint(
        endpoint_id=endpoint.endpoint_id,
        model_definition_id=model_def2.model_definition_id,
    )

    # After detach, only the original mapping remains.
    fetched_endpoint_after = store.get_gateway_endpoint(endpoint.endpoint_id)
    assert len(fetched_endpoint_after.model_mappings) == 1
4770  
4771  
def test_endpoint_bindings(mlflow_client_with_secrets):
    """Endpoint bindings can be created, filtered by endpoint_id /
    resource_type / resource_id (individually and combined), and deleted.
    """
    store = mlflow_client_with_secrets._tracking_client.store

    secret = store.create_gateway_secret(
        secret_name="binding-secret",
        secret_value={"api_key": "sk-test-44444"},
        provider="openai",
    )

    model_def1 = store.create_gateway_model_definition(
        name="binding-model-def-1",
        secret_id=secret.secret_id,
        provider="openai",
        model_name="gpt-4",
    )

    model_def2 = store.create_gateway_model_definition(
        name="binding-model-def-2",
        secret_id=secret.secret_id,
        provider="openai",
        model_name="gpt-3.5-turbo",
    )

    endpoint1 = store.create_gateway_endpoint(
        name="binding-test-endpoint-1",
        model_configs=[
            GatewayEndpointModelConfig(
                model_definition_id=model_def1.model_definition_id,
                linkage_type=GatewayModelLinkageType.PRIMARY,
                weight=1.0,
            ),
        ],
    )

    endpoint2 = store.create_gateway_endpoint(
        name="binding-test-endpoint-2",
        model_configs=[
            GatewayEndpointModelConfig(
                model_definition_id=model_def2.model_definition_id,
                linkage_type=GatewayModelLinkageType.PRIMARY,
                weight=1.0,
            ),
        ],
    )

    # Two bindings on endpoint1 and one on endpoint2 so the filters below
    # have something to distinguish.
    binding1 = store.create_endpoint_binding(
        endpoint_id=endpoint1.endpoint_id,
        resource_type=GatewayResourceType.SCORER,
        resource_id="job-123",
    )

    binding2 = store.create_endpoint_binding(
        endpoint_id=endpoint1.endpoint_id,
        resource_type=GatewayResourceType.SCORER,
        resource_id="job-456",
    )

    binding3 = store.create_endpoint_binding(
        endpoint_id=endpoint2.endpoint_id,
        resource_type=GatewayResourceType.SCORER,
        resource_id="job-789",
    )

    assert binding1.endpoint_id == endpoint1.endpoint_id
    assert binding1.resource_type == GatewayResourceType.SCORER
    assert binding1.resource_id == "job-123"

    # Filter by endpoint_id: only endpoint1's two bindings come back.
    bindings_endpoint1 = store.list_endpoint_bindings(endpoint_id=endpoint1.endpoint_id)
    assert len(bindings_endpoint1) == 2
    resource_ids = {b.resource_id for b in bindings_endpoint1}
    assert binding1.resource_id in resource_ids
    assert binding2.resource_id in resource_ids
    assert binding3.resource_id not in resource_ids

    # Filter by resource_type: >= 3 because other tests may add SCORER bindings.
    bindings_by_type = store.list_endpoint_bindings(resource_type=GatewayResourceType.SCORER)
    assert len(bindings_by_type) >= 3

    # Filter by resource_id: exact match on a single binding.
    bindings_by_resource = store.list_endpoint_bindings(resource_id="job-123")
    assert len(bindings_by_resource) == 1
    assert bindings_by_resource[0].resource_id == binding1.resource_id

    # Combined filters are ANDed together.
    bindings_multi = store.list_endpoint_bindings(
        endpoint_id=endpoint1.endpoint_id,
        resource_type=GatewayResourceType.SCORER,
    )
    assert len(bindings_multi) == 2

    # Delete takes the (endpoint_id, resource_type, resource_id) triple;
    # note resource_type is passed as its string value here.
    store.delete_endpoint_binding(
        endpoint_id=binding1.endpoint_id,
        resource_type=binding1.resource_type.value,
        resource_id=binding1.resource_id,
    )

    bindings_after = store.list_endpoint_bindings(endpoint_id=endpoint1.endpoint_id)
    assert len(bindings_after) == 1
    assert not any(b.resource_id == binding1.resource_id for b in bindings_after)
4868  
4869  
def test_secrets_and_endpoints_integration(mlflow_client_with_secrets):
    """End-to-end flow across the gateway store: secret -> model definitions ->
    endpoint -> attach second model -> binding, then tear everything down in
    reverse dependency order.
    """
    store = mlflow_client_with_secrets._tracking_client.store

    secret = store.create_gateway_secret(
        secret_name="integration-test-key",
        secret_value={"api_key": "sk-integration-test"},
        provider="openai",
    )

    model_def1 = store.create_gateway_model_definition(
        name="integration-model-def-1",
        secret_id=secret.secret_id,
        provider="openai",
        model_name="gpt-3.5-turbo",
    )

    model_def2 = store.create_gateway_model_definition(
        name="integration-model-def-2",
        secret_id=secret.secret_id,
        provider="openai",
        model_name="gpt-4",
    )

    endpoint = store.create_gateway_endpoint(
        name="integration-endpoint",
        model_configs=[
            GatewayEndpointModelConfig(
                model_definition_id=model_def1.model_definition_id,
                linkage_type=GatewayModelLinkageType.PRIMARY,
                weight=1.0,
            ),
        ],
    )

    mapping = store.attach_model_to_endpoint(
        endpoint_id=endpoint.endpoint_id,
        model_config=GatewayEndpointModelConfig(
            model_definition_id=model_def2.model_definition_id,
            linkage_type=GatewayModelLinkageType.PRIMARY,
            weight=1.0,
        ),
    )

    binding = store.create_endpoint_binding(
        endpoint_id=endpoint.endpoint_id,
        resource_type=GatewayResourceType.SCORER,
        resource_id="integration-job",
    )

    # Both models (original + attached) are visible on a fresh read.
    fetched_endpoint = store.get_gateway_endpoint(endpoint.endpoint_id)
    assert len(fetched_endpoint.model_mappings) == 2
    mapping_ids = {m.mapping_id for m in fetched_endpoint.model_mappings}
    assert mapping.mapping_id in mapping_ids

    bindings = store.list_endpoint_bindings(resource_id="integration-job")
    assert len(bindings) == 1
    assert bindings[0].resource_id == binding.resource_id

    # Teardown: binding first, then detach, then endpoint, definitions, secret.
    store.delete_endpoint_binding(
        endpoint_id=binding.endpoint_id,
        resource_type=binding.resource_type.value,
        resource_id=binding.resource_id,
    )
    store.detach_model_from_endpoint(
        endpoint_id=endpoint.endpoint_id,
        model_definition_id=model_def2.model_definition_id,
    )
    store.delete_gateway_endpoint(endpoint.endpoint_id)
    store.delete_gateway_model_definition(model_def1.model_definition_id)
    store.delete_gateway_model_definition(model_def2.model_definition_id)
    store.delete_gateway_secret(secret.secret_id)
4941  
4942  
def test_list_providers(mlflow_client_with_secrets):
    """The supported-providers endpoint returns a non-empty list of provider
    names that includes "openai".
    """
    # NOTE: `requests` is imported at module level; the previous function-local
    # re-import was redundant and has been removed.
    base_url = mlflow_client_with_secrets._tracking_client.tracking_uri
    response = requests.get(f"{base_url}/ajax-api/3.0/mlflow/gateway/supported-providers")
    assert response.status_code == 200
    data = response.json()
    assert "providers" in data
    assert isinstance(data["providers"], list)
    assert len(data["providers"]) > 0
    assert "openai" in data["providers"]
4954  
4955  
def test_list_models(mlflow_client_with_secrets):
    """The supported-models endpoint returns model records with model/provider/
    mode fields, excludes fine-tuned ("ft:") models, and honors the `provider`
    query filter.
    """
    # NOTE: `requests` is imported at module level; the previous function-local
    # re-import was redundant and has been removed.
    base_url = mlflow_client_with_secrets._tracking_client.tracking_uri
    response = requests.get(f"{base_url}/ajax-api/3.0/mlflow/gateway/supported-models")
    assert response.status_code == 200
    data = response.json()
    assert "models" in data
    assert isinstance(data["models"], list)
    assert len(data["models"]) > 0

    # Every record carries the three expected keys; spot-check the first.
    model = data["models"][0]
    assert "model" in model
    assert "provider" in model
    assert "mode" in model
    # Fine-tuned models must be filtered out of the catalog.
    assert all(not m["model"].startswith("ft:") for m in data["models"])

    # Filtering by provider restricts the result set to that provider only.
    response = requests.get(
        f"{base_url}/ajax-api/3.0/mlflow/gateway/supported-models", params={"provider": "openai"}
    )
    assert response.status_code == 200
    filtered_data = response.json()
    assert all(m["provider"] == "openai" for m in filtered_data["models"])
4979  
4980  
def test_get_provider_config(mlflow_client_with_secrets):
    """The provider-config endpoint describes auth modes per provider:
    single api_key mode for simple providers, multiple modes for bedrock,
    a generic api_key fallback for unknown providers, and 400 when the
    `provider` query parameter is missing.
    """
    # NOTE: `requests` is imported at module level; the previous function-local
    # re-import was redundant and has been removed.
    base_url = mlflow_client_with_secrets._tracking_client.tracking_uri

    # Test simple provider (openai) - should have single api_key auth mode
    response = requests.get(
        f"{base_url}/ajax-api/3.0/mlflow/gateway/provider-config",
        params={"provider": "openai"},
    )
    assert response.status_code == 200
    data = response.json()
    assert "auth_modes" in data
    assert "default_mode" in data
    assert data["default_mode"] == "api_key"
    assert len(data["auth_modes"]) >= 1
    api_key_mode = data["auth_modes"][0]
    assert api_key_mode["mode"] == "api_key"

    # Test multi-mode provider (bedrock) - should have multiple auth modes
    response = requests.get(
        f"{base_url}/ajax-api/3.0/mlflow/gateway/provider-config",
        params={"provider": "bedrock"},
    )
    assert response.status_code == 200
    data = response.json()
    assert "auth_modes" in data
    assert data["default_mode"] == "api_key"
    assert len(data["auth_modes"]) >= 2  # api_key, access_keys, iam_role

    # Check access_keys mode structure
    access_keys_mode = next(m for m in data["auth_modes"] if m["mode"] == "access_keys")
    assert len(access_keys_mode["secret_fields"]) == 2  # access_key_id, secret_access_key
    assert any(f["name"] == "aws_secret_access_key" for f in access_keys_mode["secret_fields"])
    assert any(f["name"] == "aws_region_name" for f in access_keys_mode["config_fields"])

    # Check iam_role mode exists
    iam_role_mode = next(m for m in data["auth_modes"] if m["mode"] == "iam_role")
    assert any(f["name"] == "aws_role_name" for f in iam_role_mode["config_fields"])

    # Unknown providers get a generic fallback
    response = requests.get(
        f"{base_url}/ajax-api/3.0/mlflow/gateway/provider-config",
        params={"provider": "unknown_provider"},
    )
    assert response.status_code == 200
    data = response.json()
    assert data["default_mode"] == "api_key"
    assert data["auth_modes"][0]["mode"] == "api_key"
    assert data["auth_modes"][0]["config_fields"][0]["name"] == "api_base"

    # Missing provider parameter returns 400
    response = requests.get(f"{base_url}/ajax-api/3.0/mlflow/gateway/provider-config")
    assert response.status_code == 400
5035  
5036  
def test_get_secrets_config_with_custom_passphrase(mlflow_client_with_secrets):
    """Secrets config reports availability and a non-default passphrase when the
    fixture server was started with a custom passphrase.
    """
    tracking_uri = mlflow_client_with_secrets._tracking_client.tracking_uri

    resp = requests.get(f"{tracking_uri}/ajax-api/3.0/mlflow/gateway/secrets/config")
    assert resp.status_code == 200
    payload = resp.json()
    assert payload["secrets_available"] is True
    assert payload["using_default_passphrase"] is False
5045  
5046  
def test_get_secrets_config_with_default_passphrase(tmp_path: Path, monkeypatch):
    """With MLFLOW_CRYPTO_KEK_PASSPHRASE unset, the secrets config endpoint
    reports that the default passphrase is in use.

    Spins up its own in-process server (rather than using the shared fixture)
    so the environment variable can be removed before the stores initialize.
    """
    from tests.tracking.integration_test_utils import ServerThread, get_safe_port

    # Ensure no custom passphrase is visible to the backend being started.
    monkeypatch.delenv("MLFLOW_CRYPTO_KEK_PASSPHRASE", raising=False)

    backend_uri = f"sqlite:///{tmp_path}/mlflow.db"
    artifact_uri = (tmp_path / "artifacts").as_uri()

    # Create the SQLite schema up front, then release the connection pool so
    # the server's own store can open the same file.
    store = SqlAlchemyStore(backend_uri, artifact_uri)
    store.engine.dispose()

    # Reset the handler module's cached stores so initialize_backend_stores
    # rebuilds them against this test's backend. (handlers / app /
    # initialize_backend_stores are imported earlier in this module.)
    handlers._tracking_store = None
    handlers._model_registry_store = None
    initialize_backend_stores(backend_uri, default_artifact_root=artifact_uri)

    with ServerThread(app, get_safe_port()) as url:
        response = requests.get(f"{url}/ajax-api/3.0/mlflow/gateway/secrets/config")
        assert response.status_code == 200
        data = response.json()
        assert data["secrets_available"] is True
        assert data["using_default_passphrase"] is True
5068  
5069  
def test_endpoint_with_orphaned_model_definition(mlflow_client_with_secrets):
    """Deleting a secret orphans dependent model definitions: the endpoint's
    mapping survives, but the definition's secret_id / secret_name become None.
    """
    store = mlflow_client_with_secrets._tracking_client.store

    secret = store.create_gateway_secret(
        secret_name="orphan-test-key",
        secret_value={"api_key": "sk-orphan-test"},
        provider="openai",
    )

    model_def = store.create_gateway_model_definition(
        name="orphan-model-def",
        secret_id=secret.secret_id,
        provider="openai",
        model_name="gpt-4",
    )

    endpoint = store.create_gateway_endpoint(
        name="orphan-test-endpoint",
        model_configs=[
            GatewayEndpointModelConfig(
                model_definition_id=model_def.model_definition_id,
                linkage_type=GatewayModelLinkageType.PRIMARY,
                weight=1.0,
            ),
        ],
    )

    # Before deletion the mapping resolves the secret's id and name.
    assert len(endpoint.model_mappings) == 1
    assert endpoint.model_mappings[0].model_definition.secret_id == secret.secret_id
    assert endpoint.model_mappings[0].model_definition.secret_name == "orphan-test-key"

    store.delete_gateway_secret(secret.secret_id)

    # After deletion the mapping remains but its secret references are nulled.
    fetched_endpoint = store.get_gateway_endpoint(endpoint.endpoint_id)
    assert len(fetched_endpoint.model_mappings) == 1
    assert fetched_endpoint.model_mappings[0].model_definition.secret_id is None
    assert fetched_endpoint.model_mappings[0].model_definition.secret_name is None
5107  
5108  
def test_update_model_definition_provider(mlflow_client_with_secrets):
    """A model definition's provider and model name can be changed together,
    and the change is visible on a fresh read.
    """
    gateway_store = mlflow_client_with_secrets._tracking_client.store

    api_secret = gateway_store.create_gateway_secret(
        secret_name="provider-update-secret",
        secret_value={"api_key": "sk-provider-test"},
        provider="openai",
    )
    definition = gateway_store.create_gateway_model_definition(
        name="provider-update-model-def",
        secret_id=api_secret.secret_id,
        provider="openai",
        model_name="gpt-4",
    )
    assert (definition.provider, definition.model_name) == ("openai", "gpt-4")

    changed = gateway_store.update_gateway_model_definition(
        model_definition_id=definition.model_definition_id,
        provider="anthropic",
        model_name="claude-3-5-haiku-latest",
    )
    assert (changed.provider, changed.model_name) == (
        "anthropic",
        "claude-3-5-haiku-latest",
    )

    # The update must persist, not just show on the returned object.
    reread = gateway_store.get_gateway_model_definition(definition.model_definition_id)
    assert (reread.provider, reread.model_name) == (
        "anthropic",
        "claude-3-5-haiku-latest",
    )

    gateway_store.delete_gateway_model_definition(definition.model_definition_id)
    gateway_store.delete_gateway_secret(api_secret.secret_id)
5143  
5144  
def test_create_issue_with_all_fields(mlflow_client, store_type):
    """POST /api/3.0/mlflow/issues echoes back every optional field
    (status, source run, root causes, severity, creator) plus generated
    id and timestamps.
    """
    if store_type == "file":
        pytest.skip("Issues are only supported in SqlAlchemyStore")

    mlflow.set_tracking_uri(mlflow_client.tracking_uri)
    experiment_id = mlflow_client.create_experiment("Issue Test")
    # A real run so source_run_id references an existing run.
    run = mlflow_client.create_run(experiment_id)

    response = requests.post(
        f"{mlflow_client.tracking_uri}/api/3.0/mlflow/issues",
        json={
            "experiment_id": experiment_id,
            "name": "High latency issue",
            "description": "API calls are taking too long",
            "status": IssueStatus.PENDING.value,
            "source_run_id": run.info.run_id,
            "root_causes": ["Database query inefficiency", "Network latency"],
            "severity": IssueSeverity.HIGH.value,
            "created_by": "test-user",
        },
    )
    assert response.status_code == 200
    data = response.json()
    assert "issue" in data
    issue = data["issue"]
    # Every supplied field round-trips unchanged.
    assert issue["experiment_id"] == experiment_id
    assert issue["name"] == "High latency issue"
    assert issue["description"] == "API calls are taking too long"
    assert issue["status"] == IssueStatus.PENDING.value
    assert issue["source_run_id"] == run.info.run_id
    assert issue["root_causes"] == ["Database query inefficiency", "Network latency"]
    assert issue["severity"] == IssueSeverity.HIGH.value
    assert issue["created_by"] == "test-user"
    # Server-generated fields are present.
    assert "issue_id" in issue
    assert "created_timestamp" in issue
    assert "last_updated_timestamp" in issue
5181  
5182  
def test_create_issue_minimal_fields(mlflow_client, store_type):
    """Creating an issue with only the required fields succeeds and defaults
    the status to PENDING.
    """
    if store_type == "file":
        pytest.skip("Issues are only supported in SqlAlchemyStore")
    experiment_id = mlflow_client.create_experiment("Issue Test Minimal")

    payload = {
        "experiment_id": experiment_id,
        "name": "Test issue",
        "description": "Test description",
    }
    response = requests.post(
        f"{mlflow_client.tracking_uri}/api/3.0/mlflow/issues",
        json=payload,
    )
    assert response.status_code == 200
    issue = response.json()["issue"]
    # Each supplied field round-trips unchanged.
    for field, expected in payload.items():
        assert issue[field] == expected
    assert issue["status"] == IssueStatus.PENDING.value
    assert "issue_id" in issue
5204  
5205  
def test_create_issue_with_required_fields(mlflow_client, store_type):
    """An explicit status supplied at creation time is honored, and the server
    generates id and timestamp fields.
    """
    if store_type == "file":
        pytest.skip("Issues are only supported in SqlAlchemyStore")
    experiment_id = mlflow_client.create_experiment("Issue Test Required Fields")

    resp = requests.post(
        f"{mlflow_client.tracking_uri}/api/3.0/mlflow/issues",
        json={
            "experiment_id": experiment_id,
            "name": "Issue with required fields only",
            "description": "Testing issue creation with required fields",
            "status": IssueStatus.RESOLVED.value,
        },
    )
    assert resp.status_code == 200
    issue = resp.json()["issue"]
    assert issue["status"] == IssueStatus.RESOLVED.value
    for generated_key in ("issue_id", "created_timestamp", "last_updated_timestamp"):
        assert generated_key in issue
5227  
5228  
def test_create_issue_invalid_experiment(mlflow_client, store_type):
    """Creating an issue against a nonexistent experiment returns 404 with the
    standard RESOURCE_DOES_NOT_EXIST error code.
    """
    if store_type == "file":
        pytest.skip("Issues are only supported in SqlAlchemyStore")
    resp = requests.post(
        f"{mlflow_client.tracking_uri}/api/3.0/mlflow/issues",
        json={
            "experiment_id": "999999",
            "name": "Test issue",
            "description": "Test description",
        },
    )
    assert resp.status_code == 404
    assert resp.json()["error_code"] == "RESOURCE_DOES_NOT_EXIST"
5243  
5244  
def test_get_issue(mlflow_client, store_type):
    """A freshly created issue can be fetched back by id with its fields intact."""
    if store_type == "file":
        pytest.skip("Issues are only supported in SqlAlchemyStore")
    experiment_id = mlflow_client.create_experiment("Issue Test Get")

    created = requests.post(
        f"{mlflow_client.tracking_uri}/api/3.0/mlflow/issues",
        json={
            "experiment_id": experiment_id,
            "name": "Test issue",
            "description": "Test description",
            "severity": IssueSeverity.MEDIUM.value,
        },
    ).json()["issue"]
    issue_id = created["issue_id"]

    fetched = requests.get(f"{mlflow_client.tracking_uri}/api/3.0/mlflow/issues/{issue_id}")
    assert fetched.status_code == 200
    issue = fetched.json()["issue"]
    assert issue["issue_id"] == issue_id
    assert issue["name"] == "Test issue"
    assert issue["severity"] == IssueSeverity.MEDIUM.value
5268  
5269  
def test_get_issue_not_found(mlflow_client, store_type):
    """Fetching an unknown issue id yields 404 with RESOURCE_DOES_NOT_EXIST."""
    if store_type == "file":
        pytest.skip("Issues are only supported in SqlAlchemyStore")
    resp = requests.get(f"{mlflow_client.tracking_uri}/api/3.0/mlflow/issues/nonexistent-issue")
    assert resp.status_code == 404
    assert resp.json()["error_code"] == "RESOURCE_DOES_NOT_EXIST"
5277  
5278  
def test_update_issue(mlflow_client, store_type):
    """PATCHing an issue updates its name, description, status, and severity."""
    if store_type == "file":
        pytest.skip("Issues are only supported in SqlAlchemyStore")
    experiment_id = mlflow_client.create_experiment("Issue Test Update")

    create_response = requests.post(
        f"{mlflow_client.tracking_uri}/api/3.0/mlflow/issues",
        json={
            "experiment_id": experiment_id,
            "name": "Original name",
            "description": "Original description",
            "status": IssueStatus.PENDING.value,
        },
    )
    # Guard the setup step so a failed create does not surface as a cryptic
    # KeyError when extracting the issue ID.
    assert create_response.status_code == 200, create_response.text
    issue_id = create_response.json()["issue"]["issue_id"]

    update_response = requests.patch(
        f"{mlflow_client.tracking_uri}/api/3.0/mlflow/issues/{issue_id}",
        json={
            "issue_id": issue_id,
            "name": "Updated name",
            "description": "Updated description",
            "status": IssueStatus.RESOLVED.value,
            "severity": IssueSeverity.HIGH.value,
        },
    )
    assert update_response.status_code == 200
    issue = update_response.json()["issue"]
    assert issue["issue_id"] == issue_id
    assert issue["name"] == "Updated name"
    assert issue["description"] == "Updated description"
    assert issue["status"] == IssueStatus.RESOLVED.value
    assert issue["severity"] == IssueSeverity.HIGH.value
5313  
5314  
def test_search_issues_no_filters(mlflow_client, store_type):
    """Searching with an empty request body returns all created issues."""
    if store_type == "file":
        pytest.skip("Issues are only supported in SqlAlchemyStore")
    experiment_id = mlflow_client.create_experiment("Issue Test Search")

    for i in range(3):
        create_response = requests.post(
            f"{mlflow_client.tracking_uri}/api/3.0/mlflow/issues",
            json={
                "experiment_id": experiment_id,
                "name": f"Issue {i}",
                "description": f"Description {i}",
            },
        )
        # Verify each setup create succeeded; otherwise the count assertion
        # below would fail with no indication of why an issue is missing.
        assert create_response.status_code == 200, create_response.text

    search_response = requests.post(
        f"{mlflow_client.tracking_uri}/api/3.0/mlflow/issues/search", json={}
    )
    assert search_response.status_code == 200
    data = search_response.json()
    assert "issues" in data
    assert len(data["issues"]) == 3
    assert {issue["name"] for issue in data["issues"]} == {"Issue 0", "Issue 1", "Issue 2"}
    # Issues are created without an explicit status, so all default to PENDING.
    assert {issue["status"] for issue in data["issues"]} == {IssueStatus.PENDING.value}
5339  
5340  
def test_search_issues_by_experiment(mlflow_client, store_type):
    """Searching with an experiment_id filter returns only that experiment's issues."""
    if store_type == "file":
        pytest.skip("Issues are only supported in SqlAlchemyStore")
    exp1 = mlflow_client.create_experiment("Issue Test Search Exp1")
    exp2 = mlflow_client.create_experiment("Issue Test Search Exp2")

    # Create one issue per experiment and confirm both setup creates succeed.
    for exp_id, name in [(exp1, "Issue in exp1"), (exp2, "Issue in exp2")]:
        create_response = requests.post(
            f"{mlflow_client.tracking_uri}/api/3.0/mlflow/issues",
            json={
                "experiment_id": exp_id,
                "name": name,
                "description": "Description",
            },
        )
        assert create_response.status_code == 200, create_response.text

    search_response = requests.post(
        f"{mlflow_client.tracking_uri}/api/3.0/mlflow/issues/search",
        json={"experiment_id": exp1},
    )
    assert search_response.status_code == 200
    issues = search_response.json()["issues"]
    assert len(issues) == 1
    assert issues[0]["experiment_id"] == exp1
    assert issues[0]["name"] == "Issue in exp1"
5375  
5376  
def test_search_issues_by_status(mlflow_client, store_type):
    """A filter_string on status returns only issues with the matching status."""
    if store_type == "file":
        pytest.skip("Issues are only supported in SqlAlchemyStore")
    experiment_id = mlflow_client.create_experiment("Issue Test Search Status")

    # Create one PENDING and one RESOLVED issue; assert setup creates succeed
    # so a server-side failure does not masquerade as a filter bug below.
    for name, status in [
        ("Draft issue", IssueStatus.PENDING.value),
        ("Confirmed issue", IssueStatus.RESOLVED.value),
    ]:
        create_response = requests.post(
            f"{mlflow_client.tracking_uri}/api/3.0/mlflow/issues",
            json={
                "experiment_id": experiment_id,
                "name": name,
                "description": "Description",
                "status": status,
            },
        )
        assert create_response.status_code == 200, create_response.text

    search_response = requests.post(
        f"{mlflow_client.tracking_uri}/api/3.0/mlflow/issues/search",
        json={"experiment_id": experiment_id, "filter_string": "status = 'resolved'"},
    )
    assert search_response.status_code == 200
    issues = search_response.json()["issues"]
    assert all(issue["status"] == IssueStatus.RESOLVED.value for issue in issues)
    assert any(issue["name"] == "Confirmed issue" for issue in issues)
5411  
5412  
def test_search_issues_with_pagination(mlflow_client, store_type):
    """Search pagination: 15 issues split into a page of 10 and a page of 5."""
    if store_type == "file":
        pytest.skip("Issues are only supported in SqlAlchemyStore")
    experiment_id = mlflow_client.create_experiment("Issue Test Pagination")

    for i in range(15):
        create_response = requests.post(
            f"{mlflow_client.tracking_uri}/api/3.0/mlflow/issues",
            json={
                "experiment_id": experiment_id,
                "name": f"Issue {i}",
                "description": f"Description {i}",
            },
        )
        # A silently failed create would break the page-size math below.
        assert create_response.status_code == 200, create_response.text

    first_page = requests.post(
        f"{mlflow_client.tracking_uri}/api/3.0/mlflow/issues/search",
        json={"experiment_id": experiment_id, "max_results": 10},
    )
    assert first_page.status_code == 200
    first_data = first_page.json()
    assert len(first_data["issues"]) == 10
    # A non-empty token signals more results remain.
    assert "next_page_token" in first_data
    assert first_data["next_page_token"] != ""

    second_page = requests.post(
        f"{mlflow_client.tracking_uri}/api/3.0/mlflow/issues/search",
        json={
            "experiment_id": experiment_id,
            "max_results": 10,
            "page_token": first_data["next_page_token"],
        },
    )
    assert second_page.status_code == 200
    second_data = second_page.json()
    assert len(second_data["issues"]) == 5
    # An empty token on the final page signals there are no further results.
    assert second_data["next_page_token"] == ""
5450  
5451  
def test_search_issues_sorted_by_timestamp(mlflow_client, store_type):
    """Search returns every created issue (default order is by created_timestamp)."""
    if store_type == "file":
        pytest.skip("Issues are only supported in SqlAlchemyStore")
    experiment_id = mlflow_client.create_experiment("Issue Test Sort")

    # Create issues with slight delays to ensure different timestamps
    issue_ids = []
    for i in range(3):
        response = requests.post(
            f"{mlflow_client.tracking_uri}/api/3.0/mlflow/issues",
            json={
                "experiment_id": experiment_id,
                "name": f"Issue {i}",
                "description": f"Description {i}",
            },
        )
        # Assert the create succeeded before extracting the ID so a setup
        # failure is reported clearly rather than as a KeyError.
        assert response.status_code == 200, response.text
        issue_ids.append(response.json()["issue"]["issue_id"])
        time.sleep(0.01)  # Small delay to ensure different timestamps

    search_response = requests.post(
        f"{mlflow_client.tracking_uri}/api/3.0/mlflow/issues/search",
        json={"experiment_id": experiment_id},
    )
    assert search_response.status_code == 200
    issues = search_response.json()["issues"]
    assert len(issues) == 3
    # Issues should be returned (default order is by created_timestamp descending)
    assert {issue["issue_id"] for issue in issues} == set(issue_ids)