# tests/tracking/test_client.py
   1  import json
   2  import os
   3  import pickle
   4  import threading
   5  import time
   6  import uuid
   7  from pathlib import Path
   8  from unittest import mock
   9  from unittest.mock import Mock, patch
  10  
  11  import pytest
  12  from opentelemetry import trace as trace_api
  13  from opentelemetry.sdk.trace import ReadableSpan as OTelReadableSpan
  14  from pydantic import BaseModel
  15  
  16  import mlflow
  17  from mlflow import MlflowClient, flush_async_logging
  18  from mlflow.config import enable_async_logging
  19  from mlflow.entities import (
  20      EvaluationDataset,
  21      ExperimentTag,
  22      IssueSeverity,
  23      IssueStatus,
  24      LoggedModel,
  25      Run,
  26      RunInfo,
  27      RunStatus,
  28      RunTag,
  29      SourceType,
  30      Span,
  31      SpanStatusCode,
  32      SpanType,
  33      Trace,
  34      TraceInfo,
  35      ViewType,
  36  )
  37  from mlflow.entities.file_info import FileInfo
  38  from mlflow.entities.logged_model_status import LoggedModelStatus
  39  from mlflow.entities.metric import Metric
  40  from mlflow.entities.model_registry import ModelVersion, ModelVersionTag
  41  from mlflow.entities.model_registry.model_version_status import ModelVersionStatus
  42  from mlflow.entities.model_registry.prompt_version import IS_PROMPT_TAG_KEY
  43  from mlflow.entities.param import Param
  44  from mlflow.entities.span import create_mlflow_span
  45  from mlflow.entities.trace_data import TraceData
  46  from mlflow.entities.trace_location import TraceLocation, TraceLocationType, UCSchemaLocation
  47  from mlflow.entities.trace_state import TraceState
  48  from mlflow.entities.trace_status import TraceStatus
  49  from mlflow.environment_variables import MLFLOW_TRACKING_USERNAME
  50  from mlflow.exceptions import (
  51      MlflowException,
  52      MlflowNotImplementedException,
  53      MlflowTraceDataCorrupted,
  54      MlflowTraceDataNotFound,
  55  )
  56  from mlflow.prompt.registry_utils import PromptCache
  57  from mlflow.store.artifact.artifact_repo import ArtifactRepository
  58  from mlflow.store.entities.paged_list import PagedList
  59  from mlflow.store.model_registry.sqlalchemy_store import (
  60      SqlAlchemyStore as SqlAlchemyModelRegistryStore,
  61  )
  62  from mlflow.store.tracking import SEARCH_EVALUATION_DATASETS_MAX_RESULTS, SEARCH_MAX_RESULTS_DEFAULT
  63  from mlflow.store.tracking.sqlalchemy_store import SqlAlchemyStore as SqlAlchemyTrackingStore
  64  from mlflow.tracing.constant import SpansLocation, TraceMetadataKey, TraceTagKey
  65  from mlflow.tracing.provider import _get_tracer, trace_disabled
  66  from mlflow.tracing.utils import TraceJSONEncoder
  67  from mlflow.tracking import set_registry_uri
  68  from mlflow.tracking._model_registry import DEFAULT_AWAIT_MAX_SLEEP_SECONDS
  69  from mlflow.tracking._model_registry.utils import (
  70      _get_store_registry as _get_model_registry_store_registry,
  71  )
  72  from mlflow.tracking._tracking_service.utils import _register, _use_tracking_uri
  73  from mlflow.tracking.default_experiment import DEFAULT_EXPERIMENT_ID
  74  from mlflow.utils.databricks_utils import _construct_databricks_run_url
  75  from mlflow.utils.mlflow_tags import (
  76      MLFLOW_GIT_COMMIT,
  77      MLFLOW_PARENT_RUN_ID,
  78      MLFLOW_PROJECT_ENTRY_POINT,
  79      MLFLOW_SOURCE_NAME,
  80      MLFLOW_SOURCE_TYPE,
  81      MLFLOW_USER,
  82  )
  83  from mlflow.utils.os import is_windows
  84  
  85  from tests.tracing.conftest import async_logging_enabled  # noqa: F401
  86  from tests.tracing.helper import create_test_trace_info, get_traces
  87  
  88  
@pytest.fixture(autouse=True)
def reset_registry_uri():
    """Automatically restore the model registry URI to its default after every test."""
    yield
    # Teardown: clear any registry URI a test may have set so tests stay isolated.
    set_registry_uri(None)
  93  
  94  
  95  @pytest.fixture
  96  def mock_store():
  97      with mock.patch("mlflow.tracking._tracking_service.utils._get_store") as mock_get_store:
  98          mock_store = mock_get_store.return_value
  99          with mock.patch("mlflow.tracing.client._get_store", return_value=mock_store):
 100              yield mock_store
 101  
 102  
 103  @pytest.fixture
 104  def mock_artifact_repo():
 105      with mock.patch(
 106          "mlflow.tracking._tracking_service.client.get_artifact_repository"
 107      ) as mock_get_repo:
 108          mock_repo = mock_get_repo.return_value
 109          with mock.patch("mlflow.tracing.client.get_artifact_repository", return_value=mock_repo):
 110              yield mock_repo
 111  
 112  
 113  @pytest.fixture
 114  def mock_registry_store():
 115      mock_store = mock.MagicMock()
 116      mock_store.create_model_version = mock.create_autospec(
 117          SqlAlchemyModelRegistryStore.create_model_version
 118      )
 119      with mock.patch("mlflow.tracking._model_registry.utils._get_store", return_value=mock_store):
 120          yield mock_store
 121  
 122  
 123  @pytest.fixture
 124  def mock_databricks_tracking_store():
 125      experiment_id = "test-exp-id"
 126      run_id = "runid"
 127  
 128      class MockDatabricksTrackingStore:
 129          def __init__(self, run_id, experiment_id):
 130              self.run_id = run_id
 131              self.experiment_id = experiment_id
 132  
 133          def get_run(self, *args, **kwargs):
 134              return Run(
 135                  RunInfo(self.run_id, self.experiment_id, "userid", "status", 0, 1, None), None
 136              )
 137  
 138      mock_tracking_store = MockDatabricksTrackingStore(run_id, experiment_id)
 139  
 140      with mock.patch(
 141          "mlflow.tracking._tracking_service.utils._tracking_store_registry.get_store",
 142          return_value=mock_tracking_store,
 143      ):
 144          yield mock_tracking_store
 145  
 146  
 147  @pytest.fixture
 148  def mock_store_start_trace():
 149      def _mock_start_trace(trace_info):
 150          return create_test_trace_info(
 151              trace_id="tr-123",
 152              experiment_id=trace_info.experiment_id,
 153              request_time=trace_info.request_time,
 154              execution_duration=trace_info.execution_duration,
 155              state=trace_info.state,
 156              trace_metadata=trace_info.trace_metadata,
 157              tags={
 158                  "mlflow.user": "bob",
 159                  "mlflow.artifactLocation": "test",
 160                  **trace_info.tags,
 161              },
 162          )
 163  
 164      with mock.patch(
 165          "mlflow.tracing.client.TracingClient.start_trace", side_effect=_mock_start_trace
 166      ) as mock_start_trace:
 167          yield mock_start_trace
 168  
 169  
 170  @pytest.fixture
 171  def mock_spark_session():
 172      with mock.patch(
 173          "mlflow.utils.databricks_utils._get_active_spark_session"
 174      ) as mock_spark_session:
 175          yield mock_spark_session.return_value
 176  
 177  
 178  @pytest.fixture
 179  def mock_time():
 180      time = 1552319350.244724
 181      with mock.patch("time.time", return_value=time):
 182          yield time
 183  
 184  
@pytest.fixture
def setup_async_logging():
    """Enable async logging for the test, then flush and disable it on teardown."""
    enable_async_logging(True)
    yield
    # Drain pending async log requests before turning the feature back off.
    flush_async_logging()
    enable_async_logging(False)
 191  
 192  
 193  def test_client_create_run(mock_store, mock_time):
 194      experiment_id = mock.Mock()
 195  
 196      MlflowClient().create_run(experiment_id)
 197  
 198      mock_store.create_run.assert_called_once_with(
 199          experiment_id=experiment_id,
 200          user_id="unknown",
 201          start_time=int(mock_time * 1000),
 202          tags=[],
 203          run_name=None,
 204      )
 205  
 206  
 207  def test_client_create_run_with_name(mock_store, mock_time):
 208      experiment_id = mock.Mock()
 209  
 210      MlflowClient().create_run(experiment_id, run_name="my name")
 211  
 212      mock_store.create_run.assert_called_once_with(
 213          experiment_id=experiment_id,
 214          user_id="unknown",
 215          start_time=int(mock_time * 1000),
 216          tags=[],
 217          run_name="my name",
 218      )
 219  
 220  
def test_client_get_trace(mock_store, mock_artifact_repo):
    """UC-schema trace IDs are fetched via batch_get_traces, not the artifact repo."""
    trace_id = "trace:/catalog.schema/123"
    mock_store.batch_get_traces.return_value = [
        Trace(
            TraceInfo(
                trace_id=trace_id,
                trace_location=TraceLocation(
                    type=TraceLocationType.UC_SCHEMA,
                    uc_schema=UCSchemaLocation(catalog_name="catalog", schema_name="schema"),
                ),
                request_time=123,
                execution_duration=456,
                state=TraceState.OK,
                tags={"mlflow.artifactLocation": "dbfs:/path/to/artifacts"},
            ),
            TraceData(
                spans=[
                    Span.from_dict({
                        "name": "predict",
                        "context": {
                            "trace_id": "0x123456789",
                            "span_id": "0x12345",
                        },
                        "parent_id": None,
                        "start_time": 123000000,
                        "end_time": 579000000,
                        "status_code": "OK",
                        "status_message": "",
                        # Attribute values are JSON-encoded strings, matching how
                        # MLflow serializes span attributes.
                        "attributes": {
                            "mlflow.traceRequestId": f'"{trace_id}"',
                            "mlflow.spanType": '"LLM"',
                            "mlflow.spanFunctionName": '"predict"',
                            "mlflow.spanInputs": '{"prompt": "What is the meaning of life?"}',
                            "mlflow.spanOutputs": '{"answer": 42}',
                        },
                        "events": [],
                    })
                ]
            ),
        )
    ]
    trace = MlflowClient().get_trace(trace_id)
    # Spans come back inline from the store; no artifact download should happen.
    mock_store.batch_get_traces.assert_called_once_with([trace_id], "catalog.schema")
    mock_artifact_repo.download_trace_data.assert_not_called()

    assert trace.info.trace_id == trace_id
    assert trace.info.trace_location.uc_schema.catalog_name == "catalog"
    assert trace.info.trace_location.uc_schema.schema_name == "schema"
    assert trace.info.timestamp_ms == 123
    assert trace.info.execution_time_ms == 456
    assert trace.info.status == TraceStatus.OK
    assert trace.info.tags == {"mlflow.artifactLocation": "dbfs:/path/to/artifacts"}
    # Request/response are derived from the root span's inputs/outputs attributes.
    assert trace.data.request == '{"prompt": "What is the meaning of life?"}'
    assert trace.data.response == '{"answer": 42}'
    assert len(trace.data.spans) == 1
    assert trace.data.spans[0].name == "predict"
    assert trace.data.spans[0].trace_id == trace_id
    assert trace.data.spans[0].inputs == {"prompt": "What is the meaning of life?"}
    assert trace.data.spans[0].outputs == {"answer": 42}
    assert trace.data.spans[0].start_time_ns == 123000000
    assert trace.data.spans[0].end_time_ns == 579000000
    assert trace.data.spans[0].status.status_code == SpanStatusCode.OK
 283  
 284  
 285  def test_client_get_trace_empty_result(mock_store):
 286      mock_store.batch_get_traces.return_value = []
 287      with pytest.raises(MlflowException, match="not found"):
 288          MlflowClient().get_trace("trace:/catalog.schema/123")
 289  
 290  
def test_client_get_trace_from_artifact_repo(mock_store, mock_artifact_repo):
    """Experiment-located traces load their span data from the artifact repository."""
    mock_store.get_trace_info.return_value = TraceInfo(
        trace_id="tr-1234567",
        trace_location=TraceLocation.from_experiment_id("0"),
        request_time=123,
        execution_duration=456,
        state=TraceState.OK,
        tags={"mlflow.artifactLocation": "dbfs:/path/to/artifacts"},
    )
    # Span payload as the artifact repo would return it after downloading.
    mock_artifact_repo.download_trace_data.return_value = {
        "request": '{"prompt": "What is the meaning of life?"}',
        "response": '{"answer": 42}',
        "spans": [
            {
                "name": "predict",
                "context": {
                    "trace_id": "0x123456789",
                    "span_id": "0x12345",
                },
                "parent_id": None,
                "start_time": 123000000,
                "end_time": 579000000,
                "status_code": "OK",
                "status_message": "",
                "attributes": {
                    "mlflow.traceRequestId": '"tr-1234567"',
                    "mlflow.spanType": '"LLM"',
                    "mlflow.spanFunctionName": '"predict"',
                    "mlflow.spanInputs": '{"prompt": "What is the meaning of life?"}',
                    "mlflow.spanOutputs": '{"answer": 42}',
                },
                "events": [],
            }
        ],
    }
    trace = MlflowClient().get_trace("1234567")
    mock_store.get_trace_info.assert_called_once_with("1234567")
    mock_artifact_repo.download_trace_data.assert_called_once()

    assert trace.info.trace_id == "tr-1234567"
    assert trace.info.experiment_id == "0"
    assert trace.info.timestamp_ms == 123
    assert trace.info.execution_time_ms == 456
    assert trace.info.status == TraceStatus.OK
    assert trace.info.tags == {"mlflow.artifactLocation": "dbfs:/path/to/artifacts"}
    assert trace.data.request == '{"prompt": "What is the meaning of life?"}'
    assert trace.data.response == '{"answer": 42}'
    assert len(trace.data.spans) == 1
    assert trace.data.spans[0].name == "predict"
    assert trace.data.spans[0].trace_id == "tr-1234567"
    assert trace.data.spans[0].inputs == {"prompt": "What is the meaning of life?"}
    assert trace.data.spans[0].outputs == {"answer": 42}
    assert trace.data.spans[0].start_time_ns == 123000000
    assert trace.data.spans[0].end_time_ns == 579000000
    assert trace.data.spans[0].status.status_code == SpanStatusCode.OK
 346  
 347  
 348  def test_client_get_trace_throws_for_missing_or_corrupted_data(mock_store, mock_artifact_repo):
 349      mock_store.get_trace_info.return_value = TraceInfo(
 350          trace_id="1234567",
 351          trace_location=TraceLocation.from_experiment_id("0"),
 352          request_time=123,
 353          execution_duration=456,
 354          state=TraceState.OK,
 355          tags={"mlflow.artifactLocation": "dbfs:/path/to/artifacts"},
 356      )
 357      mock_artifact_repo.download_trace_data.side_effect = MlflowTraceDataNotFound("1234567")
 358  
 359      with pytest.raises(
 360          MlflowException,
 361          match="Trace with ID 1234567 cannot be loaded because it is missing span data",
 362      ):
 363          MlflowClient().get_trace("1234567")
 364  
 365      mock_artifact_repo.download_trace_data.side_effect = MlflowTraceDataCorrupted("1234567")
 366      with pytest.raises(
 367          MlflowException,
 368          match="Trace with ID 1234567 cannot be loaded because its span data is corrupted",
 369      ):
 370          MlflowClient().get_trace("1234567")
 371  
 372  
@pytest.mark.parametrize("include_spans", [True, False])
@pytest.mark.parametrize("num_results", [0, 5])
def test_client_search_traces_with_get_traces(
    mock_store, mock_artifact_repo, include_spans, num_results
):
    """search_traces over a UC schema batch-fetches spans only when requested."""
    mock_trace_infos = [
        TraceInfo(
            trace_id=f"trace:/catalog.schema/{i}",
            trace_location=TraceLocation(
                type=TraceLocationType.UC_SCHEMA,
                uc_schema=UCSchemaLocation(catalog_name="catalog", schema_name="schema"),
            ),
            request_time=123,
            execution_duration=456,
            state=TraceState.OK,
        )
        for i in range(num_results)
    ]
    mock_store.search_traces.return_value = (mock_trace_infos, None)
    mock_store.batch_get_traces.return_value = [
        Trace(info=info, data=TraceData(spans=[])) for info in mock_trace_infos
    ]

    results = MlflowClient().search_traces(
        locations=["catalog.schema"],
        include_spans=include_spans,
    )
    mock_store.search_traces.assert_called_once_with(
        experiment_ids=None,
        filter_string=None,
        max_results=100,
        order_by=None,
        page_token=None,
        model_id=None,
        locations=["catalog.schema"],
    )
    assert len(results) == num_results

    # Spans are only fetched when requested and when there is something to fetch.
    if include_spans and num_results > 0:
        mock_store.batch_get_traces.assert_called_once_with(
            [f"trace:/catalog.schema/{i}" for i in range(num_results)],
            "catalog.schema",
        )
    else:
        mock_store.batch_get_traces.assert_not_called()

    mock_artifact_repo.download_trace_data.assert_not_called()

    # The TraceInfo is already fetched prior to the upload_trace_data call,
    # so we should not call _get_trace_info again
    mock_store.get_trace_info.assert_not_called()
 424  
 425  
 426  def test_client_search_traces_with_large_results(mock_store, mock_artifact_repo):
 427      mock_trace_infos = [
 428          TraceInfo(
 429              trace_id=f"trace:/catalog.schema/{i}",
 430              trace_location=TraceLocation(
 431                  type=TraceLocationType.UC_SCHEMA,
 432                  uc_schema=UCSchemaLocation(catalog_name="catalog", schema_name="schema"),
 433              ),
 434              request_time=123,
 435              execution_duration=456,
 436              state=TraceState.OK,
 437          )
 438          for i in range(100)
 439      ]
 440      mock_store.search_traces.return_value = (mock_trace_infos, None)
 441  
 442      mock_store.batch_get_traces.return_value = [
 443          Trace(info=mock_trace_infos[0], data=TraceData(spans=[])) for i in range(10)
 444      ]
 445  
 446      results = MlflowClient().search_traces(locations=["catalog.schema"])
 447      mock_store.search_traces.assert_called_once_with(
 448          experiment_ids=None,
 449          filter_string=None,
 450          max_results=100,
 451          order_by=None,
 452          page_token=None,
 453          model_id=None,
 454          locations=["catalog.schema"],
 455      )
 456      assert len(results) == 100
 457      assert mock_store.batch_get_traces.call_count == 10
 458      assert mock_store.batch_get_traces.has_calls([
 459          mock.call([f"trace:/catalog.schema/{j * 10 + i}" for i in range(10)], "catalog.schema")
 460          for j in range(10)
 461      ])
 462      mock_artifact_repo.download_trace_data.assert_not_called()
 463  
 464  
@pytest.mark.parametrize("include_spans", [True, False])
def test_client_search_traces_mixed(mock_store, mock_artifact_repo, include_spans):
    """Mixed UC-schema and experiment locations use both span-fetch paths."""
    mock_traces = [
        # UC-schema trace: spans come from the store via batch_get_traces.
        TraceInfo(
            trace_id="1234567",
            trace_location=TraceLocation(
                type=TraceLocationType.UC_SCHEMA,
                uc_schema=UCSchemaLocation(catalog_name="catalog", schema_name="schema"),
            ),
            request_time=123,
            execution_duration=456,
            state=TraceState.OK,
            tags={"mlflow.artifactLocation": "dbfs:/path/to/artifacts/1"},
        ),
        # Experiment-located trace: spans come from the artifact repository.
        TraceInfo(
            trace_id="8910",
            trace_location=TraceLocation.from_experiment_id("1"),
            request_time=456,
            execution_duration=789,
            state=TraceState.OK,
            tags={"mlflow.artifactLocation": "dbfs:/path/to/artifacts/2"},
        ),
    ]
    mock_store.search_traces.return_value = (mock_traces, None)
    mock_store.batch_get_traces.return_value = [
        Trace(info=mock_traces[0], data=TraceData(spans=[]))
    ]
    mock_artifact_repo.download_trace_data.return_value = {}
    results = MlflowClient().search_traces(
        locations=["1", "catalog.schema"], include_spans=include_spans
    )

    mock_store.search_traces.assert_called_once_with(
        experiment_ids=None,
        filter_string=None,
        max_results=100,
        order_by=None,
        page_token=None,
        model_id=None,
        locations=["1", "catalog.schema"],
    )
    assert len(results) == 2
    if include_spans:
        mock_store.batch_get_traces.assert_called_once_with(["1234567"], "catalog.schema")
        mock_artifact_repo.download_trace_data.assert_called()
    else:
        mock_store.batch_get_traces.assert_not_called()
        mock_artifact_repo.download_trace_data.assert_not_called()
 513  
 514  
@pytest.mark.parametrize("include_spans", [True, False])
@pytest.mark.parametrize("num_results", [0, 5])
def test_client_search_traces_with_get_traces_tracking_store(
    mock_store, mock_artifact_repo, include_spans, num_results
):
    """Traces tagged with SpansLocation.TRACKING_STORE batch-fetch spans from the store."""
    mock_trace_infos = [
        TraceInfo(
            trace_id=f"tr-123456789{i}",
            trace_location=TraceLocation.from_experiment_id(f"exp-{i}"),
            request_time=123,
            execution_duration=456,
            state=TraceState.OK,
            # This tag routes span retrieval to the tracking store instead of
            # the artifact repository.
            tags={TraceTagKey.SPANS_LOCATION: SpansLocation.TRACKING_STORE},
        )
        for i in range(num_results)
    ]
    mock_store.search_traces.return_value = (mock_trace_infos, None)
    mock_store.batch_get_traces.return_value = [
        Trace(info=info, data=TraceData(spans=[])) for info in mock_trace_infos
    ]

    results = MlflowClient().search_traces(
        locations=["exp-0", "exp-1", "exp-2"],
        include_spans=include_spans,
    )
    mock_store.search_traces.assert_called_once_with(
        experiment_ids=None,
        filter_string=None,
        max_results=100,
        order_by=None,
        page_token=None,
        model_id=None,
        locations=["exp-0", "exp-1", "exp-2"],
    )
    assert len(results) == num_results

    # For experiment locations, the second argument to batch_get_traces is None.
    if include_spans and num_results > 0:
        mock_store.batch_get_traces.assert_called_once_with(
            [f"tr-123456789{i}" for i in range(num_results)],
            None,
        )
    else:
        mock_store.batch_get_traces.assert_not_called()

    mock_artifact_repo.download_trace_data.assert_not_called()

    # The TraceInfo is already fetched prior to the upload_trace_data call,
    # so we should not call _get_trace_info again
    mock_store.get_trace_info.assert_not_called()
 564  
 565  
@pytest.mark.parametrize("include_spans", [True, False])
def test_client_search_traces_with_artifact_repo(mock_store, mock_artifact_repo, include_spans):
    """Experiment-located traces download span data via the artifact repository."""
    mock_traces = [
        TraceInfo(
            trace_id="tr-1234567",
            trace_location=TraceLocation.from_experiment_id("1"),
            request_time=123,
            execution_duration=456,
            state=TraceState.OK,
            tags={"mlflow.artifactLocation": "dbfs:/path/to/artifacts/1"},
        ),
        TraceInfo(
            trace_id="tr-8910",
            trace_location=TraceLocation.from_experiment_id("2"),
            request_time=456,
            execution_duration=789,
            state=TraceState.OK,
            tags={"mlflow.artifactLocation": "dbfs:/path/to/artifacts/2"},
        ),
    ]
    mock_store.search_traces.return_value = (mock_traces, None)
    mock_artifact_repo.download_trace_data.return_value = {}
    results = MlflowClient().search_traces(locations=["1", "2", "3"], include_spans=include_spans)

    mock_store.search_traces.assert_called_once_with(
        experiment_ids=None,
        filter_string=None,
        max_results=100,
        order_by=None,
        page_token=None,
        model_id=None,
        locations=["1", "2", "3"],
    )
    assert len(results) == 2
    # Downloads only happen when spans were requested.
    if include_spans:
        mock_artifact_repo.download_trace_data.assert_called()
    else:
        mock_artifact_repo.download_trace_data.assert_not_called()

    # The TraceInfo is already fetched prior to the upload_trace_data call,
    # so we should not call _get_trace_info again
    mock_store.get_trace_info.assert_not_called()
 608  
 609  
@pytest.mark.parametrize("include_spans", [True, False])
def test_client_search_traces_trace_data_download_error(mock_store, include_spans):
    """Traces whose span-data download fails are silently dropped from the results."""

    # Minimal concrete ArtifactRepository whose only reachable path
    # (_download_file) always fails, simulating a broken artifact backend.
    class CustomArtifactRepository(ArtifactRepository):
        def log_artifact(self, local_file, artifact_path=None):
            raise NotImplementedError("Should not be called")

        def log_artifacts(self, local_dir, artifact_path=None):
            raise NotImplementedError("Should not be called")

        def list_artifacts(self, path):
            raise NotImplementedError("Should not be called")

        def _download_file(self, *args, **kwargs):
            raise Exception("Failed to download trace data")

    with mock.patch(
        "mlflow.tracing.client.get_artifact_repository",
        return_value=CustomArtifactRepository("test"),
    ) as mock_get_artifact_repository:
        mock_traces = [
            TraceInfo(
                trace_id="1234567",
                trace_location=TraceLocation.from_experiment_id("1"),
                request_time=123,
                execution_duration=456,
                state=TraceState.OK,
                tags={"mlflow.artifactLocation": "test"},
            ),
        ]
        mock_store.search_traces.return_value = (mock_traces, None)
        traces = MlflowClient().search_traces(locations=["1"], include_spans=include_spans)

        if include_spans:
            # The failed download drops the trace rather than raising.
            assert traces == []
            mock_get_artifact_repository.assert_called()
        else:
            assert len(traces) == 1
            assert traces[0].info.trace_id == "1234567"
            mock_get_artifact_repository.assert_not_called()
 649  
 650  
 651  def test_client_search_traces_validates_experiment_ids_type():
 652      with pytest.raises(MlflowException, match=r"locations must be a list"):
 653          MlflowClient().search_traces(locations=4)
 654  
 655      with pytest.raises(MlflowException, match=r"locations must be a list"):
 656          MlflowClient().search_traces(locations="4")
 657  
 658  
 659  def test_client_delete_traces(mock_store):
 660      MlflowClient().delete_traces(
 661          experiment_id="0",
 662          max_timestamp_millis=1,
 663          max_traces=2,
 664          trace_ids=["tr-1234"],
 665      )
 666      mock_store.delete_traces.assert_called_once_with(
 667          experiment_id="0",
 668          max_timestamp_millis=1,
 669          max_traces=2,
 670          trace_ids=["tr-1234"],
 671      )
 672  
 673  
@pytest.fixture
def disable_prompt_cache():
    """Disable prompt caching by setting both cache TTL env vars to zero."""
    from mlflow.environment_variables import (
        MLFLOW_ALIAS_PROMPT_CACHE_TTL_SECONDS,
        MLFLOW_VERSION_PROMPT_CACHE_TTL_SECONDS,
    )

    # A TTL of 0 effectively disables caching for both alias and version lookups.
    MLFLOW_ALIAS_PROMPT_CACHE_TTL_SECONDS.set(0)
    MLFLOW_VERSION_PROMPT_CACHE_TTL_SECONDS.set(0)
    yield
    MLFLOW_ALIAS_PROMPT_CACHE_TTL_SECONDS.unset()
    MLFLOW_VERSION_PROMPT_CACHE_TTL_SECONDS.unset()
 686  
 687  
@pytest.fixture(autouse=True)
def reset_prompt_cache():
    """Reset the PromptCache singleton before and after every test for isolation."""
    PromptCache._reset_instance()
    yield
    PromptCache._reset_instance()
 693  
 694  
 695  @pytest.fixture(params=["file", "sqlalchemy"])
 696  def tracking_uri(request, tmp_path, db_uri):
 697      """Set an MLflow Tracking URI with different type of backend."""
 698      if request.param == "file":
 699          pytest.skip("FileStore is no longer supported.")
 700      if "MLFLOW_SKINNY" in os.environ and request.param == "sqlalchemy":
 701          pytest.skip("SQLAlchemy store is not available in skinny.")
 702  
 703      original_tracking_uri = mlflow.get_tracking_uri()
 704  
 705      if request.param == "file":
 706          tracking_uri = tmp_path.joinpath("file").as_uri()
 707      elif request.param == "sqlalchemy":
 708          tracking_uri = db_uri
 709  
 710      # NB: MLflow tracer does not handle the change of tracking URI well,
 711      # so we need to reset the tracer to switch the tracking URI during testing.
 712      mlflow.tracing.disable()
 713      mlflow.set_tracking_uri(tracking_uri)
 714      mlflow.tracing.enable()
 715  
 716      yield tracking_uri
 717  
 718      # Reset tracking URI
 719      mlflow.set_tracking_uri(original_tracking_uri)
 720  
 721  
@pytest.mark.parametrize("with_active_run", [True, False])
def test_start_and_end_trace(tracking_uri, with_active_run, async_logging_enabled):
    """End-to-end lifecycle test for client-managed tracing.

    Builds a trace with a root span and two child spans via explicit
    start_trace / start_span / end_span / end_trace calls, optionally inside an
    active MLflow run, then verifies the trace persisted to the backend:
    status, request/response previews, experiment/run linkage, span
    parent-child relationships, attributes, and timing.
    """
    client = MlflowClient(tracking_uri)

    experiment_id = client.create_experiment("test_experiment")

    class TestModel:
        def predict(self, x, y):
            # Root span carries the trace-level inputs and tags.
            root_span = client.start_trace(
                name="predict",
                inputs={"x": x, "y": y},
                tags={"tag": "tag_value"},
                experiment_id=experiment_id,
            )
            trace_id = root_span.trace_id

            z = x + y

            child_span = client.start_span(
                "child_span_1",
                span_type=SpanType.LLM,
                trace_id=trace_id,
                parent_id=root_span.span_id,
                inputs={"z": z},
            )

            z = z + 2

            client.end_span(
                trace_id=trace_id,
                span_id=child_span.span_id,
                outputs={"output": z},
                attributes={"delta": 2},
            )

            res = self.square(z, trace_id, root_span.span_id)
            client.end_trace(trace_id, outputs={"output": res}, status="OK")
            return res

        def square(self, t, trace_id, parent_id):
            # Second child span, attached directly under the root span.
            span = client.start_span(
                "child_span_2",
                trace_id=trace_id,
                parent_id=parent_id,
                inputs={"t": t},
            )

            res = t**2
            # Sleep so the trace has a measurable execution time (asserted below).
            time.sleep(0.1)

            client.end_span(
                trace_id=trace_id,
                span_id=span.span_id,
                outputs={"output": res},
            )
            return res

    model = TestModel()
    if with_active_run:
        with mlflow.start_run(experiment_id=experiment_id) as run:
            model.predict(1, 2)
            run_id = run.info.run_id
    else:
        model.predict(1, 2)

    if async_logging_enabled:
        mlflow.flush_trace_async_logging(terminate=True)

    trace_id = mlflow.get_trace(mlflow.get_last_active_trace_id(), flush=True).info.trace_id

    # Validate that trace is logged to the backend
    trace = client.get_trace(trace_id)
    assert trace is not None

    assert trace.info.trace_id is not None
    assert trace.info.execution_time_ms >= 0.1 * 1e3  # at least 0.1 sec
    assert trace.info.status == TraceStatus.OK
    assert trace.info.request_metadata[TraceMetadataKey.INPUTS] == '{"x": 1, "y": 2}'
    assert trace.info.request_metadata[TraceMetadataKey.OUTPUTS] == '{"output": 25}'
    if with_active_run:
        # The trace should be linked to the run that was active at creation time.
        assert trace.info.request_metadata[TraceMetadataKey.SOURCE_RUN] == run_id
        assert trace.info.experiment_id == run.info.experiment_id
    else:
        assert trace.info.experiment_id == experiment_id

    assert trace.data.request == '{"x": 1, "y": 2}'
    assert trace.data.response == '{"output": 25}'
    assert len(trace.data.spans) == 3

    span_name_to_span = {span.name: span for span in trace.data.spans}
    root_span = span_name_to_span["predict"]
    # NB: Start time of root span and trace info does not match because there is some
    #   latency for starting the trace within the backend
    # assert root_span.start_time_ns // 1e6 == trace.info.timestamp_ms
    assert root_span.parent_id is None
    assert root_span.attributes == {
        "mlflow.experimentId": experiment_id,
        "mlflow.traceRequestId": trace.info.trace_id,
        "mlflow.spanType": "UNKNOWN",
        "mlflow.spanInputs": {"x": 1, "y": 2},
        "mlflow.spanOutputs": {"output": 25},
    }

    child_span_1 = span_name_to_span["child_span_1"]
    assert child_span_1.parent_id == root_span.span_id
    assert child_span_1.attributes == {
        "mlflow.traceRequestId": trace.info.trace_id,
        "mlflow.spanType": "LLM",
        "mlflow.spanInputs": {"z": 3},
        "mlflow.spanOutputs": {"output": 5},
        "delta": 2,
    }

    child_span_2 = span_name_to_span["child_span_2"]
    assert child_span_2.parent_id == root_span.span_id
    assert child_span_2.attributes == {
        "mlflow.traceRequestId": trace.info.trace_id,
        "mlflow.spanType": "UNKNOWN",
        "mlflow.spanInputs": {"t": 5},
        "mlflow.spanOutputs": {"output": 25},
    }
    # NOTE(review): 0.1 * 1e6 ns is 0.1 ms, far looser than the 0.1 s sleep in
    # `square` — possibly intended to be 1e8; confirm before tightening.
    assert child_span_2.start_time_ns <= child_span_2.end_time_ns - 0.1 * 1e6
 844  
 845  
 846  def test_start_and_end_trace_capture_falsy_input_and_output(tracking_uri):
 847      # This test is to verify that falsy input and output values are correctly logged
 848      client = MlflowClient(tracking_uri)
 849      experiment_id = client.create_experiment("test_experiment")
 850  
 851      root = client.start_trace(name="root", experiment_id=experiment_id, inputs=[])
 852      span = client.start_span(name="child", trace_id=root.trace_id, parent_id=root.span_id, inputs=0)
 853      client.end_span(trace_id=root.trace_id, span_id=span.span_id, outputs=False)
 854      client.end_trace(trace_id=root.trace_id, outputs="")
 855  
 856      trace = client.get_trace(root.trace_id, flush=True)
 857      assert trace.data.spans[0].inputs == []
 858      assert trace.data.spans[0].outputs == ""
 859      assert trace.data.spans[1].inputs == 0
 860      assert trace.data.spans[1].outputs is False
 861  
 862  
# TODO: we should investigate whether we need to support this
@pytest.mark.skip(reason="This is not supported by latest span-level export")
@pytest.mark.usefixtures("reset_active_experiment")
def test_start_and_end_trace_before_all_span_end(async_logging_enabled):
    # This test is to verify that the trace is still exported even if some spans are not ended
    exp_id = mlflow.set_experiment("test_experiment_1").experiment_id

    class TestModel:
        def __init__(self):
            self._client = MlflowClient()

        def predict(self, x):
            root_span = self._client.start_trace(name="predict")
            trace_id = root_span.trace_id
            child_span = self._client.start_span(
                "ended-span",
                trace_id=trace_id,
                parent_id=root_span.span_id,
            )
            time.sleep(0.1)
            self._client.end_span(trace_id, child_span.span_id)

            res = self.square(x, trace_id, root_span.span_id)
            # End the trace while "non-ended-span" (opened in square) is still open.
            self._client.end_trace(trace_id)
            return res

        def square(self, t, trace_id, parent_id):
            self._client.start_span("non-ended-span", trace_id=trace_id, parent_id=parent_id)
            time.sleep(0.1)
            # The span created above is not ended
            return t**2

    model = TestModel()
    model.predict(1)

    if async_logging_enabled:
        mlflow.flush_trace_async_logging(terminate=True)

    traces = MlflowClient().search_traces(locations=[exp_id])
    assert len(traces) == 1

    trace_info = traces[0].info
    assert trace_info.trace_id is not None
    assert trace_info.experiment_id == exp_id
    assert trace_info.timestamp_ms is not None
    assert trace_info.execution_time_ms is not None
    assert trace_info.status == TraceStatus.OK

    trace_data = traces[0].data
    # No inputs/outputs were supplied anywhere, so request/response are absent.
    assert trace_data.request is None
    assert trace_data.response is None
    assert len(trace_data.spans) == 3  # The non-ended span should be also included in the trace

    span_name_to_span = {span.name: span for span in trace_data.spans}
    root_span = span_name_to_span["predict"]
    assert root_span.parent_id is None
    assert root_span.status.status_code == SpanStatusCode.OK

    ended_span = span_name_to_span["ended-span"]
    assert ended_span.parent_id == root_span.span_id
    assert ended_span.start_time_ns < ended_span.end_time_ns
    assert ended_span.status.status_code == SpanStatusCode.OK

    # The non-ended span should have null end_time and UNSET status
    non_ended_span = span_name_to_span["non-ended-span"]
    assert non_ended_span.parent_id == root_span.span_id
    assert non_ended_span.start_time_ns is not None
    assert non_ended_span.end_time_ns is None
    assert non_ended_span.status.status_code == SpanStatusCode.UNSET
 932  
 933  
def test_log_trace_with_databricks_tracking_uri(mock_store_start_trace, monkeypatch):
    """Verify trace logging against a (fully mocked) Databricks backend.

    Every network-touching TracingClient method is patched out; the test only
    asserts that starting the trace hits the store exactly once and that the
    trace data upload is attempted exactly once.
    """
    monkeypatch.setenv("MLFLOW_EXPERIMENT_NAME", "test")
    monkeypatch.setenv(MLFLOW_TRACKING_USERNAME.name, "bob")

    mock_experiment = mock.MagicMock()
    mock_experiment.experiment_id = "test_experiment_id"

    class TestModel:
        def __init__(self):
            self._client = MlflowClient()

        def predict(self, x, y):
            root_span = self._client.start_trace(
                name="predict",
                inputs={"x": x, "y": y},
                # Trying to override mlflow.user tag, which will be ignored
                tags={"tag": "tag_value", "mlflow.user": "unknown"},
            )
            trace_id = root_span.trace_id

            z = x + y

            child_span = self._client.start_span(
                "child_span_1",
                span_type=SpanType.LLM,
                trace_id=trace_id,
                parent_id=root_span.span_id,
                inputs={"z": z},
            )

            z = z + 2

            self._client.end_span(
                trace_id=trace_id,
                span_id=child_span.span_id,
                outputs={"output": z},
                attributes={"delta": 2},
            )
            self._client.end_trace(trace_id, outputs=z, status="OK")
            return z

    model = TestModel()

    # Patch all backend interactions so no real Databricks calls are made.
    with (
        mock.patch("mlflow.get_tracking_uri", return_value="databricks"),
        mock.patch("mlflow.tracking.context.default_context._get_source_name", return_value="test"),
        mock.patch(
            "mlflow.tracing.client.TracingClient._upload_trace_data"
        ) as mock_upload_trace_data,
        mock.patch(
            "mlflow.tracing.client.TracingClient.set_trace_tags",
        ),
        mock.patch(
            "mlflow.tracing.client.TracingClient.set_trace_tag",
        ),
        mock.patch(
            "mlflow.tracing.client.TracingClient.get_trace_info",
        ),
        mock.patch(
            "mlflow.MlflowClient.get_experiment_by_name",
            return_value=mock_experiment,
        ),
    ):
        model.predict(1, 2)
        mlflow.flush_trace_async_logging(terminate=True)

    mock_store_start_trace.assert_called_once()
    mock_upload_trace_data.assert_called_once()
1002  
1003  
def test_start_and_end_trace_does_not_log_trace_when_disabled(
    tracking_uri, monkeypatch, async_logging_enabled
):
    # With tracing disabled, the full span lifecycle must be a silent no-op:
    # nothing is exported to the backend and no warning is logged.
    client = MlflowClient(tracking_uri)
    experiment_id = client.create_experiment("test_experiment")

    @trace_disabled
    def func():
        root = client.start_trace(
            name="predict",
            experiment_id=experiment_id,
            inputs={"x": 1, "y": 2},
            attributes={"attr": "value"},
            tags={"tag": "tag_value"},
        )
        child = client.start_span(
            "child_span_1",
            trace_id=root.trace_id,
            parent_id=root.span_id,
        )
        client.end_span(
            trace_id=root.trace_id,
            span_id=child.span_id,
            outputs={"output": 5},
        )
        client.end_trace(root.trace_id, outputs=5, status="OK")
        return "done"

    mock_logger = mock.MagicMock()
    monkeypatch.setattr(mlflow.tracking.client, "_logger", mock_logger)

    assert func() == "done"
    assert client.search_traces(locations=[experiment_id]) == []
    # No warning should be issued
    mock_logger.warning.assert_not_called()
1041  
1042  
def test_start_trace_within_active_run(async_logging_enabled):
    # A client-managed trace started inside an active run must keep the
    # explicitly requested experiment, not the run's.
    exp_id = mlflow.create_experiment("test")
    client = mlflow.MlflowClient()

    with mlflow.start_run():
        root = client.start_trace(name="test", experiment_id=exp_id)
        client.end_trace(root.trace_id)

    if async_logging_enabled:
        mlflow.flush_trace_async_logging(terminate=True)

    traces = client.search_traces(locations=[exp_id])
    assert len(traces) == 1
    assert traces[0].info.experiment_id == exp_id
1060  
1061  
def test_start_trace_raise_error_when_active_trace_exists():
    # Starting a client-managed trace while a fluent span is active must fail.
    error_pattern = r"Another trace is already set in the global"
    with mlflow.start_span("fluent_span"), pytest.raises(MlflowException, match=error_pattern):
        mlflow.tracking.MlflowClient().start_trace("test")
1066  
1067  
def test_end_trace_raise_error_when_trace_not_exist():
    # end_trace must fail loudly when the backend has no trace for the given ID.
    client = mlflow.tracking.MlflowClient()
    client._tracing_client = mock.MagicMock(**{"get_trace.return_value": None})

    with pytest.raises(MlflowException, match=r"Trace with ID test not found"):
        client.end_trace("test")
1076  
1077  
def test_end_trace_works_for_trace_in_pending_status():
    """Ending a trace that is still IN_PROGRESS should succeed without raising."""
    client = mlflow.tracking.MlflowClient()
    mock_tracing_client = mock.MagicMock()
    mock_tracing_client.get_trace.return_value = Trace(
        info=create_test_trace_info("test", state=TraceState.IN_PROGRESS), data=TraceData()
    )
    client._tracing_client = mock_tracing_client
    # Stub out the actual span termination. Accept keyword arguments too: the
    # previous `lambda *args: None` would raise TypeError if end_trace ever
    # forwarded kwargs, making the stub fail for the wrong reason.
    client.end_span = lambda *args, **kwargs: None

    client.end_trace("test")
1088  
1089  
@pytest.mark.parametrize("state", [TraceState.OK, TraceState.ERROR])
def test_end_trace_raise_error_for_trace_in_end_status(state):
    # Ending an already-finished trace must raise rather than silently succeed.
    client = mlflow.tracking.MlflowClient()
    tracing_client = mock.MagicMock()
    tracing_client.get_trace.return_value = Trace(
        data=TraceData(), info=create_test_trace_info("test", state=state)
    )
    client._tracing_client = tracing_client

    with pytest.raises(MlflowException, match=r"Trace with ID test already finished"):
        client.end_trace("test")
1101  
1102  
def test_trace_status_either_pending_or_end():
    # Every TraceStatus value must be classified as either pending or ended.
    classified = TraceStatus.pending_statuses() | TraceStatus.end_statuses()
    unclassified_statuses = {status.value for status in TraceStatus} - classified
    assert not unclassified_statuses, (
        f"Please add {unclassified_statuses} to "
        "either pending_statuses or end_statuses in TraceStatus class definition"
    )
1111  
1112  
def test_start_span_raise_error_when_parent_id_is_not_provided():
    # A child span cannot be created without identifying its parent span.
    client = mlflow.tracking.MlflowClient()
    with pytest.raises(MlflowException, match=r"start_span\(\) must be called with"):
        client.start_span("span_name", trace_id="test", parent_id=None)
1116  
1117  
def test_log_trace(tracking_uri):
    """_log_trace should re-log an existing trace under a freshly assigned ID."""
    client = MlflowClient(tracking_uri)
    experiment_id = client.create_experiment("test_experiment")

    root_span = client.start_trace(
        name="test",
        span_type=SpanType.LLM,
        experiment_id=experiment_id,
        tags={"custom_tag": "tag_value"},
    )
    client.end_trace(root_span.trace_id, status="OK")

    original = mlflow.get_trace(mlflow.get_last_active_trace_id(), flush=True)

    # Wipe the backend so only the manually re-logged copy remains afterwards.
    client.delete_traces(experiment_id=experiment_id, trace_ids=[original.info.trace_id])
    assert client.search_traces(locations=[experiment_id]) == []

    # _log_trace triggers async export through the span processor.
    copied_trace_id = client._log_trace(original)

    # flush=True waits for the async writes to reach the backend.
    stored = client.search_traces(locations=[experiment_id], flush=True)
    assert len(stored) == 1
    copy = stored[0]
    assert copy.info.trace_id == copied_trace_id  # a new request ID is assigned
    assert copy.info.experiment_id == experiment_id
    assert copy.info.status == original.info.status
    assert copy.info.tags["custom_tag"] == "tag_value"
    assert copy.info.request_preview == original.info.request_preview
    assert copy.info.response_preview == original.info.response_preview
    assert len(copy.data.spans) == len(original.data.spans)
    assert copy.data.spans[0].name == original.data.spans[0].name

    # A trace without an experiment ID falls back to the default experiment.
    original.info.experiment_id = None
    copied_trace_id = client._log_trace(original)
    assert len(client.search_traces(locations=[DEFAULT_EXPERIMENT_ID], flush=True)) == 1
1156  
1157  
@pytest.mark.filterwarnings("ignore::FutureWarning")
def test_search_traces_experiment_ids_deprecation_warning():
    # The legacy `experiment_ids` argument still works but must warn,
    # pointing callers at `locations`.
    client = MlflowClient()
    exp_id = mlflow.set_experiment("test_experiment_deprecation").experiment_id

    with pytest.warns(FutureWarning, match="experiment_ids.*deprecated.*use.*locations"):
        traces = client.search_traces(experiment_ids=[exp_id])
    assert isinstance(traces, list)
1167  
1168  
def test_ignore_exception_from_tracing_logic(monkeypatch, async_logging_enabled):
    """Exceptions inside the tracing machinery must not propagate to user code.

    The span processor's on_start/on_end hooks are forced to raise; model
    prediction must still return normally, and no trace may be logged.
    """
    exp_id = mlflow.set_experiment("test_experiment_1").experiment_id
    client = MlflowClient()

    class TestModel:
        def predict(self, x):
            root_span = client.start_trace(experiment_id=exp_id, name="predict")
            trace_id = root_span.trace_id
            child_span = client.start_span(
                name="child", trace_id=trace_id, parent_id=root_span.span_id
            )
            client.end_span(trace_id, child_span.span_id)
            client.end_trace(trace_id)
            return x

    model = TestModel()

    # Mock the span processor's on_end handler to raise an exception
    processor = _get_tracer(__name__).span_processor

    def _always_fail(*args, **kwargs):
        raise ValueError("Some error")

    # Exception while starting the trace should be caught not raise
    monkeypatch.setattr(processor, "on_start", _always_fail)
    response = model.predict(1)
    assert response == 1
    assert len(get_traces()) == 0

    # Exception while ending the trace should be caught not raise
    monkeypatch.setattr(processor, "on_end", _always_fail)
    response = model.predict(1)
    assert response == 1
    assert len(get_traces()) == 0
1203  
1204  
def test_set_and_delete_trace_tag_on_active_trace(monkeypatch):
    # A tag set while the trace is still in progress must survive into the
    # logged trace.
    monkeypatch.setenv(MLFLOW_TRACKING_USERNAME.name, "bob")
    monkeypatch.setattr(mlflow.tracking.context.default_context, "_get_source_name", lambda: "test")

    client = mlflow.tracking.MlflowClient()
    root = client.start_trace(name="test")
    client.set_trace_tag(root.trace_id, "foo", "bar")
    client.end_trace(root.trace_id)

    logged = mlflow.get_trace(mlflow.get_last_active_trace_id(), flush=True)
    assert logged.info.tags["foo"] == "bar"
1218  
1219  
def test_set_trace_tag_on_logged_trace(mock_store):
    # Tag writes on logged traces go straight to the backend store — even for
    # reserved "mlflow."-prefixed keys.
    mlflow.tracking.MlflowClient().set_trace_tag("test", "foo", "bar")
    mlflow.tracking.MlflowClient().set_trace_tag("test", "mlflow.some.reserved.tag", "value")

    expected_calls = [
        mock.call("test", "foo", "bar"),
        mock.call("test", "mlflow.some.reserved.tag", "value"),
    ]
    mock_store.set_trace_tag.assert_has_calls(expected_calls)
1227  
1228  
def test_delete_trace_tag_on_active_trace(monkeypatch):
    # Deleting a tag while the trace is in progress removes it from the logged
    # trace, leaving the remaining tags intact.
    monkeypatch.setenv(MLFLOW_TRACKING_USERNAME.name, "bob")
    monkeypatch.setattr(mlflow.tracking.context.default_context, "_get_source_name", lambda: "test")

    client = mlflow.tracking.MlflowClient()
    root = client.start_trace(name="test", tags={"foo": "bar", "baz": "qux"})
    client.delete_trace_tag(root.trace_id, "foo")
    client.end_trace(root.trace_id)

    logged = mlflow.get_trace(mlflow.get_last_active_trace_id(), flush=True)
    assert "foo" not in logged.info.tags
    assert "baz" in logged.info.tags
1242  
1243  
def test_delete_trace_tag_on_logged_trace(mock_store):
    # Tag deletion on an already-logged trace delegates straight to the store.
    client = mlflow.tracking.MlflowClient()
    client.delete_trace_tag("test", "foo")
    mock_store.delete_trace_tag.assert_called_once_with("test", "foo")
1247  
1248  
def test_client_create_experiment(mock_store):
    # The tag dict is converted to ExperimentTag entities before hitting the store.
    client = MlflowClient()
    client.create_experiment("someName", "someLocation", {"key1": "val1", "key2": "val2"})

    expected_tags = [ExperimentTag("key1", "val1"), ExperimentTag("key2", "val2")]
    mock_store.create_experiment.assert_called_once_with(
        name="someName",
        artifact_location="someLocation",
        tags=expected_tags,
    )
1257  
1258  
def test_client_create_run_overrides(mock_store):
    # The mlflow.user tag is lifted into create_run's user_id argument while
    # the full tag dict is still forwarded verbatim as RunTag entities.
    experiment_id = mock.Mock()
    user = mock.Mock()
    start_time = mock.Mock()
    run_name = mock.Mock()
    tags = {
        MLFLOW_USER: user,
        MLFLOW_PARENT_RUN_ID: mock.Mock(),
        MLFLOW_SOURCE_TYPE: SourceType.to_string(SourceType.JOB),
        MLFLOW_SOURCE_NAME: mock.Mock(),
        MLFLOW_PROJECT_ENTRY_POINT: mock.Mock(),
        MLFLOW_GIT_COMMIT: mock.Mock(),
        "other-key": "other-value",
    }
    expected_tags = [RunTag(k, v) for k, v in tags.items()]

    MlflowClient().create_run(experiment_id, start_time, tags, run_name)
    mock_store.create_run.assert_called_once_with(
        experiment_id=experiment_id,
        user_id=user,
        start_time=start_time,
        tags=expected_tags,
        run_name=run_name,
    )

    # Omitting run_name must forward None rather than inventing a name.
    mock_store.reset_mock()
    MlflowClient().create_run(experiment_id, start_time, tags)
    mock_store.create_run.assert_called_once_with(
        experiment_id=experiment_id,
        user_id=user,
        start_time=start_time,
        tags=expected_tags,
        run_name=None,
    )
1292  
1293  
def test_client_set_terminated_no_change_name(mock_store):
    # set_terminated without an explicit name must not overwrite the run name.
    experiment_id = mock.Mock()
    client = MlflowClient()
    run = client.create_run(experiment_id, run_name="my name")
    MlflowClient().set_terminated(run.info.run_id)

    _, kwargs = mock_store.update_run_info.call_args
    assert kwargs["run_name"] is None
1300  
1301  
def test_client_search_runs_defaults(mock_store):
    # Defaults: empty filter, active-only view, default page size, no ordering.
    MlflowClient().search_runs([1, 2, 3])
    expected = {
        "experiment_ids": [1, 2, 3],
        "filter_string": "",
        "run_view_type": ViewType.ACTIVE_ONLY,
        "max_results": SEARCH_MAX_RESULTS_DEFAULT,
        "order_by": None,
        "page_token": None,
    }
    mock_store.search_runs.assert_called_once_with(**expected)
1312  
1313  
def test_client_search_runs_filter(mock_store):
    # A positional filter string is forwarded as filter_string.
    MlflowClient().search_runs(["a", "b", "c"], "my filter")
    expected = {
        "experiment_ids": ["a", "b", "c"],
        "filter_string": "my filter",
        "run_view_type": ViewType.ACTIVE_ONLY,
        "max_results": SEARCH_MAX_RESULTS_DEFAULT,
        "order_by": None,
        "page_token": None,
    }
    mock_store.search_runs.assert_called_once_with(**expected)
1324  
1325  
def test_client_search_runs_view_type(mock_store):
    # The requested view type overrides the ACTIVE_ONLY default.
    MlflowClient().search_runs(["a", "b", "c"], "my filter", ViewType.DELETED_ONLY)
    expected = {
        "experiment_ids": ["a", "b", "c"],
        "filter_string": "my filter",
        "run_view_type": ViewType.DELETED_ONLY,
        "max_results": SEARCH_MAX_RESULTS_DEFAULT,
        "order_by": None,
        "page_token": None,
    }
    mock_store.search_runs.assert_called_once_with(**expected)
1336  
1337  
def test_client_search_runs_max_results(mock_store):
    # A caller-supplied max_results replaces the default page size.
    MlflowClient().search_runs([5], "my filter", ViewType.ALL, 2876)
    expected = {
        "experiment_ids": [5],
        "filter_string": "my filter",
        "run_view_type": ViewType.ALL,
        "max_results": 2876,
        "order_by": None,
        "page_token": None,
    }
    mock_store.search_runs.assert_called_once_with(**expected)
1348  
1349  
def test_client_search_runs_int_experiment_id(mock_store):
    # A bare int experiment ID is wrapped into a single-element list.
    MlflowClient().search_runs(123)
    expected = {
        "experiment_ids": [123],
        "filter_string": "",
        "run_view_type": ViewType.ACTIVE_ONLY,
        "max_results": SEARCH_MAX_RESULTS_DEFAULT,
        "order_by": None,
        "page_token": None,
    }
    mock_store.search_runs.assert_called_once_with(**expected)
1360  
1361  
def test_client_search_runs_string_experiment_id(mock_store):
    # A bare string experiment ID is wrapped into a single-element list.
    MlflowClient().search_runs("abc")
    expected = {
        "experiment_ids": ["abc"],
        "filter_string": "",
        "run_view_type": ViewType.ACTIVE_ONLY,
        "max_results": SEARCH_MAX_RESULTS_DEFAULT,
        "order_by": None,
        "page_token": None,
    }
    mock_store.search_runs.assert_called_once_with(**expected)
1372  
1373  
def test_client_search_runs_order_by(mock_store):
    # The order_by clauses are forwarded verbatim to the store.
    MlflowClient().search_runs([5], order_by=["a", "b"])
    expected = {
        "experiment_ids": [5],
        "filter_string": "",
        "run_view_type": ViewType.ACTIVE_ONLY,
        "max_results": SEARCH_MAX_RESULTS_DEFAULT,
        "order_by": ["a", "b"],
        "page_token": None,
    }
    mock_store.search_runs.assert_called_once_with(**expected)
1384  
1385  
def test_client_search_runs_page_token(mock_store):
    # The pagination token is forwarded verbatim to the store.
    MlflowClient().search_runs([5], page_token="blah")
    expected = {
        "experiment_ids": [5],
        "filter_string": "",
        "run_view_type": ViewType.ACTIVE_ONLY,
        "max_results": SEARCH_MAX_RESULTS_DEFAULT,
        "order_by": None,
        "page_token": "blah",
    }
    mock_store.search_runs.assert_called_once_with(**expected)
1396  
1397  
def test_update_registered_model(mock_registry_store):
    """Update registered model no longer supports name change."""
    mock_registry_store.rename_registered_model.return_value = "some expected return value."
    updated = "other expected return value."
    mock_registry_store.update_registered_model.return_value = updated

    client = MlflowClient(registry_uri="sqlite:///somedb.db")
    result = client.update_registered_model(name="orig name", description="new description")

    assert result == updated
    mock_registry_store.update_registered_model.assert_called_once_with(
        name="orig name", description="new description", deployment_job_id=None
    )
    # The rename path must never be invoked by update.
    mock_registry_store.rename_registered_model.assert_not_called()
1414  
1415  
def test_create_model_version(mock_registry_store):
    """Basic test for create model version."""
    mock_registry_store.create_model_version.return_value = _default_model_version()

    client = MlflowClient(registry_uri="sqlite:///somedb.db")
    result = client.create_model_version(
        "orig name", "source", "run-id", tags={"key": "value"}, description="desc"
    )

    assert result == _default_model_version()
    # The tag dict is converted to ModelVersionTag entities for the store call.
    mock_registry_store.create_model_version.assert_called_once_with(
        "orig name",
        "source",
        "run-id",
        [ModelVersionTag(key="key", value="value")],
        None,
        "desc",
        local_model_path=None,
        model_id=None,
    )
1435  
1436  
def test_update_model_version(mock_registry_store):
    """Update registered model no longer support state changes."""
    mock_registry_store.update_model_version.return_value = _default_model_version()

    client = MlflowClient(registry_uri="sqlite:///somedb.db")
    result = client.update_model_version(name="orig name", version="1", description="desc")

    assert result == _default_model_version()
    mock_registry_store.update_model_version.assert_called_once_with(
        name="orig name", version="1", description="desc"
    )
    # Stage transitions must go through the dedicated API, not update.
    mock_registry_store.transition_model_version_stage.assert_not_called()
1450  
1451  
def test_transition_model_version_stage(mock_registry_store):
    # The client forwards the transition request and returns the store's result.
    name, version, stage = "Model 1", "12", "Production"
    expected_result = ModelVersion(name, version, creation_timestamp=123, current_stage=stage)
    mock_registry_store.transition_model_version_stage.return_value = expected_result

    client = MlflowClient(registry_uri="sqlite:///somedb.db")
    actual_result = client.transition_model_version_stage(name, version, stage)

    mock_registry_store.transition_model_version_stage.assert_called_once_with(
        name=name, version=version, stage=stage, archive_existing_versions=False
    )
    assert actual_result == expected_result
1465  
1466  
def test_registry_uri_set_as_param():
    # An explicit registry_uri constructor argument wins over any default.
    registry_uri = "sqlite:///somedb.db"
    client = MlflowClient(tracking_uri="databricks://tracking", registry_uri=registry_uri)
    assert client._registry_uri == registry_uri
1471  
1472  
def test_registry_uri_from_set_registry_uri():
    """set_registry_uri() should define the default registry URI for new clients."""
    registry_uri = "sqlite:///somedb.db"
    set_registry_uri(registry_uri)
    try:
        client = MlflowClient(tracking_uri="databricks://tracking")
        assert client._registry_uri == registry_uri
    finally:
        # Always restore the global default: the previous version skipped this
        # cleanup when the assertion failed, leaking state into later tests.
        set_registry_uri(None)
1479  
1480  
def test_registry_uri_from_tracking_uri_param():
    tracking_uri = "databricks://tracking_vhawoierj"
    patched_get_uri = mock.patch(
        "mlflow.tracking._tracking_service.utils.get_tracking_uri",
        return_value=tracking_uri,
    )
    with patched_get_uri:
        client = MlflowClient(tracking_uri=tracking_uri)
        # Databricks tracking URIs default the registry to Unity Catalog on the
        # same profile.
        assert client._registry_uri == "databricks-uc://tracking_vhawoierj"
1490  
1491  
def test_registry_uri_from_implicit_tracking_uri():
    tracking_uri = "databricks://tracking_wierojasdf"
    patched_get_uri = mock.patch(
        "mlflow.tracking._tracking_service.utils.get_tracking_uri",
        return_value=tracking_uri,
    )
    with patched_get_uri:
        # No explicit tracking URI: the client resolves it via get_tracking_uri.
        client = MlflowClient()
        # Databricks tracking URIs default the registry to Unity Catalog on the
        # same profile.
        assert client._registry_uri == "databricks-uc://tracking_wierojasdf"
1501  
1502  
def test_create_model_version_nondatabricks_source_no_runlink(mock_registry_store):
    run_id = "runid"
    client = MlflowClient(tracking_uri="http://10.123.1231.11")
    mock_registry_store.create_model_version.return_value = ModelVersion(
        "name", 1, 0, 1, source="source", run_id=run_id
    )

    model_version = client.create_model_version("name", "source", "runid")

    assert model_version.name == "name"
    assert model_version.source == "source"
    assert model_version.run_id == "runid"
    # verify that the store was not provided a run link
    mock_registry_store.create_model_version.assert_called_once_with(
        "name", "source", "runid", [], None, None, local_model_path=None, model_id=None
    )
1522  
1523  
def test_create_model_version_nondatabricks_source_no_run_id(mock_registry_store):
    """Omitting run_id forwards None to the registry store and the returned version."""
    client = MlflowClient(tracking_uri="http://10.123.1231.11")
    mock_registry_store.create_model_version.return_value = ModelVersion(
        "name", 1, 0, 1, source="source"
    )

    created = client.create_model_version("name", "source")
    assert created.name == "name"
    assert created.source == "source"
    assert created.run_id is None

    # The run id (3rd positional argument) must be None.
    mock_registry_store.create_model_version.assert_called_once_with(
        "name", "source", None, [], None, None, local_model_path=None, model_id=None
    )
1537  
1538  
def test_create_model_version_explicitly_set_run_link(
    mock_registry_store, mock_databricks_tracking_store
):
    """An explicitly passed run_link is forwarded verbatim to the registry store,
    even inside a databricks notebook where the client could auto-generate one."""
    run_link = "my-run-link"
    hostname = "https://workspace.databricks.com/"
    workspace_id = "10002"
    mock_registry_store.create_model_version.return_value = ModelVersion(
        "name",
        1,
        0,
        1,
        source="source",
        run_id=mock_databricks_tracking_store.run_id,
        run_link=run_link,
    )

    # mocks to make sure that even if you're in a notebook, this setting is respected.
    with (
        mock.patch("mlflow.utils.databricks_utils.is_in_databricks_notebook", return_value=True),
        mock.patch(
            "mlflow.utils.databricks_utils.get_workspace_info_from_dbutils",
            return_value=(hostname, workspace_id),
        ),
    ):
        client = MlflowClient(tracking_uri="databricks", registry_uri="otherplace")
        model_version = client.create_model_version("name", "source", "runid", run_link=run_link)
        assert model_version.run_link == run_link
        # verify that the store was provided with the explicitly passed in run link
        mock_registry_store.create_model_version.assert_called_once_with(
            "name", "source", "runid", [], run_link, None, local_model_path=None, model_id=None
        )
1570  
1571  
def test_create_model_version_run_link_in_notebook_with_default_profile(
    mock_registry_store, mock_databricks_tracking_store
):
    """When no run_link is passed inside a databricks notebook, the client builds the
    workspace run URL from dbutils workspace info and hands it to the registry store."""
    hostname = "https://workspace.databricks.com/"
    workspace_id = "10002"
    # Expected URL the client should construct from host/experiment/run/workspace.
    workspace_url = _construct_databricks_run_url(
        hostname,
        mock_databricks_tracking_store.experiment_id,
        mock_databricks_tracking_store.run_id,
        workspace_id,
    )

    with (
        mock.patch("mlflow.utils.databricks_utils.is_in_databricks_notebook", return_value=True),
        mock.patch(
            "mlflow.utils.databricks_utils.get_workspace_info_from_dbutils",
            return_value=(hostname, workspace_id),
        ),
    ):
        client = MlflowClient(tracking_uri="databricks", registry_uri="otherplace")
        mock_registry_store.create_model_version.return_value = ModelVersion(
            "name",
            1,
            0,
            1,
            source="source",
            run_id=mock_databricks_tracking_store.run_id,
            run_link=workspace_url,
        )
        model_version = client.create_model_version("name", "source", "runid")
        assert model_version.run_link == workspace_url
        # verify that the client generated the right URL
        mock_registry_store.create_model_version.assert_called_once_with(
            "name", "source", "runid", [], workspace_url, None, local_model_path=None, model_id=None
        )
1607  
1608  
def test_create_model_version_with_source(mock_registry_store, mock_databricks_tracking_store):
    """A `models:/<id>` source is forwarded verbatim for databricks registries.

    The original test duplicated the whole arrange/act/assert section for the
    workspace registry and the Unity Catalog registry; both scenarios are now
    exercised by a single loop. Expected behavior is identical in both cases:
    the `models:/` URI is passed through unchanged and the returned version
    carries the model_id.
    """
    model_id = "model_id"
    mock_logged_model = LoggedModel(
        experiment_id="exp_id",
        model_id="model_id",
        name="name",
        artifact_location="/path/to/source",
        creation_timestamp=0,
        last_updated_timestamp=0,
    )

    # None -> default (workspace) registry; "databricks-uc" -> Unity Catalog.
    for registry_uri in (None, "databricks-uc"):
        mock_registry_store.create_model_version.reset_mock()
        mock_registry_store.create_model_version.return_value = ModelVersion(
            "name",
            1,
            0,
            1,
            source="/path/to/source",
            run_id=mock_databricks_tracking_store.run_id,
            run_link=None,
            model_id=model_id,
        )
        with mock.patch(
            "mlflow.tracking.client.MlflowClient.get_logged_model", return_value=mock_logged_model
        ):
            if registry_uri is None:
                client = MlflowClient(tracking_uri="databricks")
            else:
                client = MlflowClient(tracking_uri="databricks", registry_uri=registry_uri)
            model_version = client.create_model_version(
                "name", f"models:/{model_id}", "runid", run_link=None, model_id=model_id
            )
            assert model_version.model_id == model_id
            # The models:/ URI must reach the store unresolved.
            mock_registry_store.create_model_version.assert_called_once_with(
                "name",
                f"models:/{model_id}",
                "runid",
                [],
                None,
                None,
                local_model_path=None,
                model_id="model_id",
            )
1668  
1669  
def test_create_model_version_with_nondatabricks_source_uc_registry(mock_registry_store):
    """With a UC registry and a non-databricks tracking server, a `models:/` URI is
    resolved to the logged model's artifact location before reaching the store."""
    name = "name"
    model_id = "model_id"
    run_id = "runid"
    source = "/path/to/source"
    mock_registry_store.create_model_version.return_value = ModelVersion(
        name,
        1,
        0,
        1,
        source=source,
        run_id=run_id,
        run_link=None,
        model_id=model_id,
    )
    fake_logged_model = LoggedModel(
        experiment_id="exp_id",
        model_id=model_id,
        name=name,
        artifact_location=source,
        creation_timestamp=0,
        last_updated_timestamp=0,
    )

    with mock.patch(
        "mlflow.tracking.client.MlflowClient.get_logged_model", return_value=fake_logged_model
    ):
        client = MlflowClient(tracking_uri="http://10.123.1231.11", registry_uri="databricks-uc")
        created = client.create_model_version(
            name, f"models:/{model_id}", run_id, run_link=None, model_id=model_id
        )
        assert created.model_id == model_id
        # The store receives the resolved artifact path (not the models:/ URI),
        # and model_id is not forwarded.
        mock_registry_store.create_model_version.assert_called_once_with(
            name,
            source,
            run_id,
            [],
            None,
            None,
            local_model_path=None,
            model_id=None,
        )
1713  
1714  
def test_creation_default_values_in_unity_catalog(mock_registry_store):
    """UC registry store calls receive default tags=[] and None run_link/description."""
    client = MlflowClient(tracking_uri="databricks", registry_uri="databricks-uc")
    mock_registry_store.create_model_version.return_value = ModelVersion(
        "name", 1, 0, 1, source="source", run_id="runid"
    )

    client.create_model_version("name", "source", "runid")
    # Defaults: tags=[], run_link=None, description=None.
    mock_registry_store.create_model_version.assert_called_once_with(
        "name", "source", "runid", [], None, None, local_model_path=None, model_id=None
    )

    client.create_registered_model(name="name", description="description")
    # Defaults: tags=[], deployment_job_id=None.
    mock_registry_store.create_registered_model.assert_called_once_with(
        "name", [], "description", None
    )
1735  
1736  
def test_await_model_version_creation(mock_registry_store):
    """create_model_version waits on the store's creation helper with the default timeout."""
    pending_version = ModelVersion(
        name="name",
        version=1,
        creation_timestamp=123,
        status=ModelVersionStatus.to_string(ModelVersionStatus.FAILED_REGISTRATION),
    )
    mock_registry_store.create_model_version.return_value = pending_version

    client = MlflowClient(tracking_uri="http://10.123.1231.11")
    client.create_model_version("name", "source")

    mock_registry_store._await_model_version_creation.assert_called_once_with(
        pending_version, DEFAULT_AWAIT_MAX_SLEEP_SECONDS
    )
1752  
1753  
def test_create_model_version_run_link_with_configured_profile(
    mock_registry_store, mock_databricks_tracking_store
):
    """Outside a notebook, workspace info comes from databricks secrets and the client
    still constructs the workspace run URL for the registry store."""
    hostname = "https://workspace.databricks.com/"
    workspace_id = "10002"
    # Expected URL the client should construct from host/experiment/run/workspace.
    workspace_url = _construct_databricks_run_url(
        hostname,
        mock_databricks_tracking_store.experiment_id,
        mock_databricks_tracking_store.run_id,
        workspace_id,
    )

    with (
        mock.patch("mlflow.utils.databricks_utils.is_in_databricks_notebook", return_value=False),
        mock.patch(
            "mlflow.utils.databricks_utils.get_workspace_info_from_databricks_secrets",
            return_value=(hostname, workspace_id),
        ),
    ):
        client = MlflowClient(tracking_uri="databricks", registry_uri="otherplace")
        mock_registry_store.create_model_version.return_value = ModelVersion(
            "name",
            1,
            0,
            1,
            source="source",
            run_id=mock_databricks_tracking_store.run_id,
            run_link=workspace_url,
        )
        model_version = client.create_model_version("name", "source", "runid")
        assert model_version.run_link == workspace_url
        # verify that the client generated the right URL
        mock_registry_store.create_model_version.assert_called_once_with(
            "name", "source", "runid", [], workspace_url, None, local_model_path=None, model_id=None
        )
1789  
1790  
def test_create_model_version_copy_called_db_to_db(mock_registry_store):
    """Artifacts are copied when tracking and registry are different databricks workspaces."""
    client = MlflowClient(
        tracking_uri="databricks://tracking",
        registry_uri="databricks://registry:workspace",
    )
    mock_registry_store.create_model_version.return_value = _default_model_version()

    with mock.patch("mlflow.tracking.client._upload_artifacts_to_databricks") as upload_mock:
        client.create_model_version(
            "model name", "dbfs:/source", "run_12345", run_link="not:/important/for/test"
        )
        upload_mock.assert_called_once_with(
            "dbfs:/source",
            "run_12345",
            "databricks://tracking",
            "databricks://registry:workspace",
        )
1810  
1811  
def test_create_model_version_copy_called_nondb_to_db(mock_registry_store):
    """Artifacts are copied when moving from a non-databricks tracker to a databricks registry."""
    client = MlflowClient(
        tracking_uri="https://tracking", registry_uri="databricks://registry:workspace"
    )
    mock_registry_store.create_model_version.return_value = _default_model_version()

    with mock.patch("mlflow.tracking.client._upload_artifacts_to_databricks") as upload_mock:
        client.create_model_version(
            "model name",
            "s3:/source",
            "run_12345",
            run_link="not:/important/for/test",
        )
        upload_mock.assert_called_once_with(
            "s3:/source",
            "run_12345",
            "https://tracking",
            "databricks://registry:workspace",
        )
1827  
1828  
def test_create_model_version_copy_not_called_to_db(mock_registry_store):
    """No artifact copy when tracking and registry point at the same databricks workspace."""
    same_workspace = "databricks://registry:workspace"
    client = MlflowClient(tracking_uri=same_workspace, registry_uri=same_workspace)
    mock_registry_store.create_model_version.return_value = _default_model_version()

    with mock.patch("mlflow.tracking.client._upload_artifacts_to_databricks") as upload_mock:
        client.create_model_version(
            "model name", "dbfs:/source", "run_12345", run_link="not:/important/for/test"
        )
        upload_mock.assert_not_called()
1843  
1844  
def test_create_model_version_copy_not_called_to_nondb(mock_registry_store):
    """No artifact copy when the registry is not a databricks workspace."""
    client = MlflowClient(tracking_uri="databricks://tracking", registry_uri="https://registry")
    mock_registry_store.create_model_version.return_value = _default_model_version()

    with mock.patch("mlflow.tracking.client._upload_artifacts_to_databricks") as upload_mock:
        client.create_model_version(
            "model name", "dbfs:/source", "run_12345", run_link="not:/important/for/test"
        )
        upload_mock.assert_not_called()
1856  
1857  
def _default_model_version():
    """Return a minimal READY model version used by the copy/alias tests above."""
    return ModelVersion(name="model name", version=1, creation_timestamp=123, status="READY")
1860  
1861  
def test_client_can_be_serialized_with_pickle(tmp_path):
    """
    Verifies that instances of `MlflowClient` can be serialized using pickle, even if the underlying
    Tracking and Model Registry stores used by the client are not serializable using pickle
    """

    # Subclasses defined inside the test body: pickling their instances fails
    # because the classes are not importable at module level.
    class MockUnpickleableTrackingStore(SqlAlchemyTrackingStore):
        pass

    class MockUnpickleableModelRegistryStore(SqlAlchemyModelRegistryStore):
        pass

    backend_store_path = tmp_path.joinpath("test.db")
    artifact_store_path = tmp_path.joinpath("artifacts")

    mock_tracking_store = MockUnpickleableTrackingStore(
        f"sqlite:///{backend_store_path}", str(artifact_store_path)
    )
    mock_model_registry_store = MockUnpickleableModelRegistryStore(
        f"sqlite:///{backend_store_path}"
    )

    # Verify that the mock stores cannot be pickled because they are defined within a function
    # (i.e. the test function)
    with pytest.raises(AttributeError, match="<locals>.MockUnpickleableTrackingStore'"):
        pickle.dumps(mock_tracking_store)

    with pytest.raises(AttributeError, match="<locals>.MockUnpickleableModelRegistryStore'"):
        pickle.dumps(mock_model_registry_store)

    # Route the custom "pickle" URI scheme to the unpicklable store instances.
    # NOTE(review): this mutates global store registries and is not undone here;
    # presumably isolated by the surrounding fixtures — confirm.
    _register("pickle", lambda *args, **kwargs: mock_tracking_store)
    _get_model_registry_store_registry().register(
        "pickle", lambda *args, **kwargs: mock_model_registry_store
    )

    # Create an MlflowClient with the store that cannot be pickled, perform
    # tracking & model registry operations, and verify that the client can still be pickled
    client = MlflowClient("pickle://foo")
    client.create_experiment("test_experiment")
    client.create_registered_model("test_model")
    pickle.dumps(client)
1903  
1904  
@pytest.fixture
def mock_registry_store_with_get_latest_version(mock_registry_store):
    """Registry-store mock whose get_latest_versions always yields a single version 1."""
    mock_registry_store.get_latest_versions = mock.Mock(
        return_value=[ModelVersion("model_name", 1, 0)]
    )
    return mock_registry_store
1918  
1919  
def test_set_model_version_tag(mock_registry_store_with_get_latest_version):
    """set_model_version_tag accepts either a version or a stage, but not both or neither."""
    store = mock_registry_store_with_get_latest_version
    expected_tag = ModelVersionTag(key="tag1", value="foobar")

    # By explicit version.
    MlflowClient().set_model_version_tag("model_name", 1, "tag1", "foobar")
    store.set_model_version_tag.assert_called_once_with("model_name", 1, expected_tag)

    store.set_model_version_tag.reset_mock()

    # By stage: resolves to the latest version in that stage (version 1 per fixture).
    MlflowClient().set_model_version_tag("model_name", key="tag1", value="foobar", stage="Staging")
    store.set_model_version_tag.assert_called_once_with("model_name", 1, expected_tag)

    # Supplying both version and stage is rejected.
    with pytest.raises(MlflowException, match="version and stage cannot be set together"):
        MlflowClient().set_model_version_tag("model_name", 1, "tag1", "foobar", stage="Staging")

    # Supplying neither version nor stage is rejected.
    with pytest.raises(MlflowException, match="version or stage must be set"):
        MlflowClient().set_model_version_tag("model_name", key="tag1", value="foobar")
1942  
1943  
def test_delete_model_version_tag(mock_registry_store_with_get_latest_version):
    """delete_model_version_tag accepts either a version or a stage, but not both or neither."""
    store = mock_registry_store_with_get_latest_version

    # By explicit version.
    MlflowClient().delete_model_version_tag("model_name", 1, "tag1")
    store.delete_model_version_tag.assert_called_once_with("model_name", 1, "tag1")

    store.delete_model_version_tag.reset_mock()

    # By stage: resolves to the latest version in that stage (version 1 per fixture).
    MlflowClient().delete_model_version_tag("model_name", key="tag1", stage="Staging")
    store.delete_model_version_tag.assert_called_once_with("model_name", 1, "tag1")

    # Supplying both version and stage is rejected.
    with pytest.raises(MlflowException, match="version and stage cannot be set together"):
        MlflowClient().delete_model_version_tag(
            "model_name", version=1, key="tag1", stage="staging"
        )

    # Supplying neither version nor stage is rejected.
    with pytest.raises(MlflowException, match="version or stage must be set"):
        MlflowClient().delete_model_version_tag("model_name", key="tag1")
1968  
1969  
def test_set_registered_model_alias(mock_registry_store):
    """set_registered_model_alias forwards name, alias, and version to the registry store."""
    client = MlflowClient()
    client.set_registered_model_alias("model_name", "test_alias", 1)
    mock_registry_store.set_registered_model_alias.assert_called_once_with(
        "model_name", "test_alias", 1
    )
1975  
1976  
def test_delete_registered_model_alias(mock_registry_store):
    """delete_registered_model_alias forwards name and alias to the registry store."""
    client = MlflowClient()
    client.delete_registered_model_alias("model_name", "test_alias")
    mock_registry_store.delete_registered_model_alias.assert_called_once_with(
        "model_name", "test_alias"
    )
1982  
1983  
def test_get_model_version_by_alias(mock_registry_store):
    """get_model_version_by_alias returns the registry store's result unchanged."""
    mock_registry_store.get_model_version_by_alias.return_value = _default_model_version()
    result = MlflowClient().get_model_version_by_alias("model_name", "test_alias")
    assert result == _default_model_version()
    mock_registry_store.get_model_version_by_alias.assert_called_once_with(
        "model_name", "test_alias"
    )
1991  
1992  
def test_update_run(mock_store):
    """update_run forwards the parsed status, new name, and an end time to the store."""
    expected_status = RunStatus.from_string("FINISHED")
    client = MlflowClient()
    client.update_run(run_id="run_id", status="FINISHED", name="my name")
    mock_store.update_run_info.assert_called_once_with(
        run_id="run_id",
        run_status=expected_status,
        end_time=mock.ANY,
        run_name="my name",
    )
2001  
2002  
def test_client_log_metric_params_tags_overrides(mock_store):
    """synchronous=False logging routes metrics/params/tags (and batches) through the
    tracking store's async APIs."""
    experiment_id = mock.Mock()
    start_time = mock.Mock()
    run_name = mock.Mock()
    run = MlflowClient().create_run(experiment_id, start_time, tags={}, run_name=run_name)
    run_id = run.info.run_id

    # Each async call returns an operation handle; wait() blocks until flushed.
    run_operation = MlflowClient().log_metric(run_id, "m1", 0.87, 123456789, 1, synchronous=False)
    run_operation.wait()

    run_operation = MlflowClient().log_param(run_id, "p1", "pv1", synchronous=False)
    run_operation.wait()

    run_operation = MlflowClient().set_tag(run_id, "t1", "tv1", synchronous=False)
    run_operation.wait()

    mock_store.log_metric_async.assert_called_once_with(run_id, Metric("m1", 0.87, 123456789, 1))
    mock_store.log_param_async.assert_called_once_with(run_id, Param("p1", "pv1"))
    mock_store.set_tag_async.assert_called_once_with(run_id, RunTag("t1", "tv1"))

    # Clear call records before exercising the batch path.
    mock_store.reset_mock()

    # log_batch_async
    MlflowClient().create_run(experiment_id, start_time, {})
    metrics = [Metric("m1", 0.87, 123456789, 1), Metric("m2", 0.87, 123456789, 1)]
    tags = [RunTag("t1", "tv1"), RunTag("t2", "tv2")]
    params = [Param("p1", "pv1"), Param("p2", "pv2")]
    run_operation = MlflowClient().log_batch(run_id, metrics, params, tags, synchronous=False)
    run_operation.wait()

    mock_store.log_batch_async.assert_called_once_with(
        run_id=run_id, metrics=metrics, params=params, tags=tags
    )
2036  
2037  
def test_invalid_run_id_log_artifact():
    """A trace-style id ('tr-' prefix) is rejected by log_artifact."""
    client = MlflowClient()
    with pytest.raises(MlflowException, match=r"Invalid run id.*"):
        client.log_artifact("tr-123", "path")
2044  
2045  
def test_enable_async_logging(mock_store, setup_async_logging):
    """With async logging enabled, log_param and log_metric hit the async store APIs."""
    MlflowClient().log_param(run_id="run_id", key="key", value="val")
    mock_store.log_param_async.assert_called_once_with("run_id", Param("key", "val"))

    MlflowClient().log_metric(run_id="run_id", key="key", value="val", step=1, timestamp=1)
    mock_store.log_metric_async.assert_called_once_with("run_id", Metric("key", "val", 1, 1))
2052  
2053  
def test_file_store_download_upload_trace_data(tmp_path):
    # FileStore support was removed, so this test is skipped unconditionally.
    # Everything below the skip is unreachable and kept only for reference.
    pytest.skip("FileStore is no longer supported.")
    with _use_tracking_uri(tmp_path.joinpath("mlruns").as_uri()):
        client = MlflowClient()
        span = client.start_trace("test", inputs={"test": 1})
        client.end_trace(span.trace_id, outputs={"result": 2})
        trace = mlflow.get_trace(span.trace_id, flush=True)
        trace_data = client.get_trace(span.trace_id, flush=True).data
        assert trace_data.request == trace.data.request
        assert trace_data.response == trace.data.response
2064  
2065  
def test_get_trace_throw_if_trace_id_is_online_trace_id(db_uri):
    """A UUID-style (inference-table) trace id errors on databricks and is not found elsewhere."""
    online_trace_id = "3a3c3b56-910a-4721-8d02-0333eda5f37e"

    with pytest.raises(MlflowException, match="Traces from inference tables can only be loaded"):
        MlflowClient("databricks").get_trace(online_trace_id)

    with pytest.raises(MlflowException, match=r"Trace with ID '[\w-]+' not found"):
        MlflowClient(db_uri).get_trace(online_trace_id)
2075  
2076  
@pytest.fixture(params=["file", "sqlalchemy"])
def registry_uri(request, tmp_path, db_uri):
    """Set an MLflow Model Registry URI with different type of backend.

    Fixes two defects in the original fixture:
    - Teardown saved ``mlflow.get_registry_uri()`` but then passed it to
      ``mlflow.set_tracking_uri`` (the comment even said "Reset tracking URI"),
      clobbering the tracking URI instead of restoring the registry URI.
    - The ``"file"`` branch below the unconditional skip was dead code.
    """
    if request.param == "file":
        pytest.skip("FileStore is no longer supported.")
    if "MLFLOW_SKINNY" in os.environ and request.param == "sqlalchemy":
        pytest.skip("SQLAlchemy store is not available in skinny.")

    original_registry_uri = mlflow.get_registry_uri()

    # Only the "sqlalchemy" param reaches this point (the "file" param skips above).
    yield db_uri

    # Restore the registry URI that was active before the test.
    mlflow.set_registry_uri(original_registry_uri)
2096  
2097  
def test_crud_prompts(tracking_uri):
    """Register multiple versions of a prompt and load them back by version and by URI."""
    client = MlflowClient(tracking_uri=tracking_uri)

    client.register_prompt(
        name="prompt_1",
        template="Hi, {{title}} {{name}}! How are you today?",
        commit_message="A friendly greeting",
    )

    prompt = client.load_prompt("prompt_1", version=1)
    assert prompt.name == "prompt_1"
    assert prompt.template == "Hi, {{title}} {{name}}! How are you today?"
    assert prompt.commit_message == "A friendly greeting"

    # Registering under the same name creates version 2.
    client.register_prompt(
        name="prompt_1",
        template="Hi, {{title}} {{name}}! What's up?",
        commit_message="New greeting",
    )

    prompt = client.load_prompt("prompt_1", version=2)
    assert prompt.template == "Hi, {{title}} {{name}}! What's up?"

    # Version 1 remains loadable and unchanged.
    prompt = client.load_prompt("prompt_1", version=1)
    assert prompt.template == "Hi, {{title}} {{name}}! How are you today?"

    # Loading by "prompts:/<name>/<version>" URI is equivalent.
    prompt = client.load_prompt("prompts:/prompt_1/2")
    assert prompt.template == "Hi, {{title}} {{name}}! What's up?"

    # Test loading non-existent prompts
    assert mlflow.load_prompt("does_not_exist", version=1, allow_missing=True) is None
2129  
2130  
def test_create_prompt_with_tags_and_metadata(tracking_uri, disable_prompt_cache):
    """Version-level tags and prompt-level tags are stored and updated independently."""

    def wait_for_prompt_linking():
        """Wait for background prompt linking threads to complete."""
        for t in threading.enumerate():
            if t.name.startswith("link_prompt_to_experiment_thread"):
                t.join(timeout=5.0)
                if t.is_alive():
                    raise TimeoutError(f"Thread {t.name} did not complete within timeout.")

    client = MlflowClient(tracking_uri=tracking_uri)

    # Create prompt with version-specific tags
    client.register_prompt(
        name="prompt_1",
        template="Hi, {{name}}!",
        tags={"author": "Alice"},  # This will be version-level tags now
    )

    # Wait for the background linking thread to complete
    wait_for_prompt_linking()

    # Set some prompt-level tags separately
    client.set_prompt_tag("prompt_1", "application", "greeting")
    client.set_prompt_tag("prompt_1", "language", "en")

    # Test version 1
    prompt_v1 = client.load_prompt("prompt_1", version=1)
    assert prompt_v1.template == "Hi, {{name}}!"
    # Version tags are separate from prompt tags
    assert prompt_v1.tags == {"author": "Alice"}

    # Wait for the background linking thread from load_prompt
    wait_for_prompt_linking()

    # Test prompt-level tags (separate from version)
    prompt_entity = client.get_prompt("prompt_1")
    # Note: Currently includes the version tags too, but we expect this behavior to change
    assert prompt_entity.tags == {
        "author": "Alice",  # This appears due to current implementation
        "application": "greeting",
        "language": "en",
        "_mlflow_experiment_ids": ",0,",  # Linked to Default experiment
    }

    # Create version 2 with different version-level tags
    client.register_prompt(
        name="prompt_1",
        template="こんにちは、{{name}}!",
        tags={"author": "Bob", "date": "2022-01-01"},  # Version-level tags
    )

    # Wait for the background linking thread from register_prompt
    wait_for_prompt_linking()

    # Update some prompt-level tags
    client.set_prompt_tag("prompt_1", "project", "toy")
    client.set_prompt_tag("prompt_1", "language", "ja")

    # Test version 2
    prompt_v2 = client.load_prompt("prompt_1", version=2)
    assert prompt_v2.template == "こんにちは、{{name}}!"
    # Version 2 has its own version tags (decoupled from prompt and version 1)
    assert prompt_v2.tags == {"author": "Bob", "date": "2022-01-01"}

    # Wait for the background linking thread from load_prompt
    wait_for_prompt_linking()

    # Verify prompt-level tags are updated and separate
    prompt_entity_updated = client.get_prompt("prompt_1")
    # Note: Currently the prompt tags get overwritten by the newest version's tags
    assert prompt_entity_updated.tags == {
        "author": "Bob",  # This appears due to current implementation
        "date": "2022-01-01",  # This appears due to current implementation
        "application": "greeting",
        "project": "toy",
        "language": "ja",
        "_mlflow_experiment_ids": ",0,",  # Linked to Default experiment
    }

    # Version 1 tags should be unchanged (decoupled from prompt tags)
    prompt_v1_after_update = client.load_prompt("prompt_1", version=1)
    assert prompt_v1_after_update.tags == {"author": "Alice"}  # Unchanged
2213  
2214  
def test_create_prompt_error_handling(tracking_uri, disable_prompt_cache):
    """A failed version creation neither creates nor destroys the backing RegisteredModel."""
    client = MlflowClient(tracking_uri=tracking_uri)
    oversized_template = "a" * 100_001

    # First version fails -> no RegisteredModel should have been created.
    with pytest.raises(MlflowException, match=r"Prompt text exceeds max length of"):
        client.register_prompt(name="prompt_1", template=oversized_template)
    with pytest.raises(MlflowException, match=r"Prompt with name=prompt_1 not found"):
        client.load_prompt("prompt_1", version=1)

    # A valid registration succeeds.
    client.register_prompt("prompt_1", template="Hi, {{title}} {{name}}!")
    assert client.load_prompt("prompt_1", version=1) is not None

    # A later failed version leaves the existing RegisteredModel intact.
    with pytest.raises(MlflowException, match=r"Prompt text exceeds max length of"):
        client.register_prompt(name="prompt_1", template=oversized_template)
    assert client.load_prompt("prompt_1", version=1) is not None
2234  
2235  
def test_create_prompt_with_invalid_name(tracking_uri):
    """Prompt names must be non-empty strings of allowed characters, distinct from models."""
    client = MlflowClient(tracking_uri=tracking_uri)

    # Empty or non-string names are rejected outright.
    for bad_name in ("", 123):
        with pytest.raises(MlflowException, match=r"Prompt name must be a non-empty string"):
            client.register_prompt(name=bad_name, template="Hi, {{name}}!")

    # Disallowed characters are rejected.
    for bad_name in (
        "prompt_1/2",
        "m%6fdel",
        "prompt?!?",
        "prompt with space",
    ):
        with pytest.raises(MlflowException, match=r"Prompt name can only contain alphanumeric"):
            client.register_prompt(name=bad_name, template="Hi, {{name}}!")

    # A registered model with the same name blocks prompt registration.
    client.create_registered_model("model")
    with pytest.raises(MlflowException, match=r"Model 'model' exists with the same name."):
        client.register_prompt(name="model", template="Hi, {{name}}!")
2258  
2259  
def test_load_prompt_error(tracking_uri):
    # Loading a missing prompt raises, and a registered-model name cannot be
    # loaded through the prompt API regardless of `allow_missing`.
    client = MlflowClient(tracking_uri=tracking_uri)

    with pytest.raises(MlflowException, match=r"Prompt with name=test not found"):
        client.load_prompt("test", version=1)

    # Both file and sqlalchemy return the same error format now
    error_msg = r"Prompt with name=test not found"

    with pytest.raises(MlflowException, match=error_msg):
        client.load_prompt("test", version=2)

    with pytest.raises(MlflowException, match=error_msg):
        client.load_prompt("test", version=2, allow_missing=False)

    # Load prompt with a model name
    client.create_registered_model("model")
    client.create_model_version("model", "source")

    model_name_error = r"Name `model` is registered as a model"

    # Each lookup is issued twice so that any cached code path is also
    # exercised; the error must be identical on every call.
    for _ in range(2):
        with pytest.raises(MlflowException, match=model_name_error):
            client.load_prompt("model", version=1)

    for _ in range(2):
        with pytest.raises(MlflowException, match=model_name_error):
            client.load_prompt("model", version=1, allow_missing=False)
2290  
2291  
def test_link_prompt_version_to_run(tracking_uri):
    # Linking a prompt version to a run records it as a JSON list under the
    # "mlflow.linkedPrompts" run tag; non-prompt arguments are rejected.
    client = MlflowClient(tracking_uri=tracking_uri)
    prompt = client.register_prompt("prompt", template="Hi, {{name}}!")

    # Create actual runs to link to
    first_run = client.create_run(experiment_id="0").info.run_id
    second_run = client.create_run(experiment_id="0").info.run_id

    client.link_prompt_version_to_run(first_run, prompt)
    client.link_prompt_version_to_run(second_run, prompt)

    # The tag must be present on the first run after linking.
    tag_value = client.get_run(first_run).data.tags.get("mlflow.linkedPrompts")
    assert tag_value is not None

    # The tag payload is JSON with name/version entries.
    entries = json.loads(tag_value)
    assert any(e["name"] == "prompt" and e["version"] == "1" for e in entries)

    # Anything other than a prompt object/URI must be rejected.
    with pytest.raises(MlflowException, match=r"The `prompt` argument must be"):
        client.link_prompt_version_to_run(first_run, 123)
2317  
2318  
@pytest.mark.parametrize("registry_uri", ["databricks"])
def test_crud_prompt_on_unsupported_registry(registry_uri):
    # The workspace-based Databricks registry does not support prompt APIs;
    # both register and load must fail with a clear "not supported" error.
    client = MlflowClient(registry_uri=registry_uri)

    register_kwargs = {
        "name": "prompt_1",
        "template": "Hi, {{title}} {{name}}! How are you today?",
        "commit_message": "A friendly greeting",
        "tags": {"model": "my-model"},
    }
    with pytest.raises(MlflowException, match=r"The 'register_prompt' API is not supported"):
        client.register_prompt(**register_kwargs)

    with pytest.raises(MlflowException, match=r"The 'load_prompt' API is not supported"):
        client.load_prompt("prompt_1")
2333  
2334  
def test_block_create_model_with_prompt_tag(tracking_uri):
    """The internal prompt marker tag cannot be set through the model APIs."""
    client = MlflowClient(tracking_uri=tracking_uri)

    with pytest.raises(MlflowException, match=r"Prompts cannot be registered"):
        client.create_registered_model(
            name="model",
            tags={IS_PROMPT_TAG_KEY: "true"},
        )

    # NOTE(review): the tag value here is "false" yet creation is still
    # expected to be rejected — presumably the mere presence of the tag key
    # is blocked regardless of its value; confirm that this is intentional.
    with pytest.raises(MlflowException, match=r"Prompts cannot be registered"):
        client.create_model_version(
            name="model",
            source="source",
            tags={IS_PROMPT_TAG_KEY: "false"},
        )
2350  
2351  
def test_block_create_prompt_with_existing_model_name(tracking_uri):
    # A prompt may not be registered under the name of an existing model.
    client = MlflowClient(tracking_uri=tracking_uri)
    client.create_registered_model("model")

    greeting_template = "Hi, {{title}} {{name}}! How are you today?"
    with pytest.raises(MlflowException, match=r"Model 'model' exists with"):
        client.register_prompt(
            name="model",
            template=greeting_template,
            commit_message="A friendly greeting",
            tags={"model": "my-model"},
        )
2364  
2365  
def test_block_handling_prompt_with_model_apis(tracking_uri):
    """Model-registry APIs must not operate on prompts.

    Every model API invoked with a prompt's name behaves as if no registered
    model of that name exists.
    """
    client = MlflowClient(tracking_uri=tracking_uri)
    client.register_prompt("prompt", template="Hi, {{name}}!")
    client.set_prompt_alias("prompt", alias="alias", version=1)
    # Validate the prompt is registered
    prompt = client.load_prompt("prompt", version=1)
    assert prompt.name == "prompt"
    assert prompt.aliases == ["alias"]

    # (api, args) pairs covering the model-registry surface; each targets the
    # prompt by name and must be told the registered model does not exist.
    apis_to_args = [
        (client.rename_registered_model, ["prompt", "new_name"]),
        (client.update_registered_model, ["prompt", "new_description"]),
        (client.delete_registered_model, ["prompt"]),
        (client.get_registered_model, ["prompt"]),
        (client.get_latest_versions, ["prompt"]),
        (client.set_registered_model_tag, ["prompt", "tag", "value"]),
        (client.delete_registered_model_tag, ["prompt", "tag"]),
        (client.update_model_version, ["prompt", 1, "new_description"]),
        (client.transition_model_version_stage, ["prompt", 1, "Production"]),
        (client.delete_model_version, ["prompt", 1]),
        (client.get_model_version, ["prompt", 1]),
        (client.get_model_version_download_uri, ["prompt", 1]),
        (client.set_model_version_tag, ["prompt", 1, "tag", "value"]),
        (client.delete_model_version_tag, ["prompt", 1, "tag"]),
        (client.set_registered_model_alias, ["prompt", "alias", 1]),
        (client.delete_registered_model_alias, ["prompt", "alias"]),
        (client.get_model_version_by_alias, ["prompt", "alias"]),
    ]

    for api, args in apis_to_args:
        with pytest.raises(MlflowException, match=r"Registered Model with name='prompt' not found"):
            api(*args)

    # copy_model_version resolves a models:/ URI, so its error text differs.
    with pytest.raises(MlflowException, match=r"Model with uri 'models:/prompt/1' not found"):
        client.copy_model_version("models:/prompt/1", "new_model")
2401  
2402  
def test_log_and_detach_prompt(tracking_uri):
    """Linking prompts by URI accumulates entries in the run's linked-prompts tag.

    NOTE(review): despite the name, no detach/unlink path is exercised here —
    consider renaming the test or extending it to cover detaching.
    """
    client = MlflowClient(tracking_uri=tracking_uri)

    client.register_prompt(name="p1", template="Hi, there!")
    time.sleep(0.001)  # To avoid timestamp precision issue in Windows
    client.register_prompt(name="p2", template="Hi, {{name}}!")

    run_id = client.create_run(experiment_id="0").info.run_id

    # Check that initially no prompts are linked to the run
    run = client.get_run(run_id)
    linked_prompts_tag = run.data.tags.get(TraceTagKey.LINKED_PROMPTS)
    assert linked_prompts_tag is None

    # Linking the first prompt creates the tag with a single JSON entry.
    client.link_prompt_version_to_run(run_id, "prompts:/p1/1")
    run = client.get_run(run_id)
    linked_prompts_tag = run.data.tags.get(TraceTagKey.LINKED_PROMPTS)
    assert linked_prompts_tag is not None
    prompts = json.loads(linked_prompts_tag)
    assert len(prompts) == 1
    assert prompts[0]["name"] == "p1"

    # Linking a second prompt appends to the existing list rather than
    # overwriting it.
    client.link_prompt_version_to_run(run_id, "prompts:/p2/1")
    run = client.get_run(run_id)
    linked_prompts_tag = run.data.tags.get(TraceTagKey.LINKED_PROMPTS)
    prompts = json.loads(linked_prompts_tag)
    assert len(prompts) == 2
    prompt_names = [p["name"] for p in prompts]
    assert "p1" in prompt_names
    assert "p2" in prompt_names
2433  
2434  
def test_search_prompt(tracking_uri):
    # Exercise exact-name filters, LIKE matching, unfiltered search, and
    # max_results truncation.
    client = MlflowClient(tracking_uri=tracking_uri)

    templates_by_name = {
        "prompt_1": "Hi, {{name}}!",
        "prompt_2": "Hello, {{name}}!",
        "prompt_3": "Greetings, {{name}}!",
        "prompt_4": "Howdy, {{name}}!",
        "prompt_5": "Salutations, {{name}}!",
        "prompt_6": "Bonjour, {{name}}!",
        "test": "Test Template",
        "new": "Bonjour, {{name}}!",
    }
    for name, template in templates_by_name.items():
        client.register_prompt(name=name, template=template)

    exact = client.search_prompts(filter_string="name='prompt_1'")
    assert [p.name for p in exact] == ["prompt_1"]

    like_matches = client.search_prompts(filter_string="name LIKE '%prompt%'")
    assert len(like_matches) == 6
    assert all("prompt" in p.name for p in like_matches)

    # Unfiltered search returns everything registered above.
    assert len(client.search_prompts()) == 8

    # max_results truncates the result list.
    assert len(client.search_prompts(max_results=3)) == 3
2460  
2461  
def test_delete_prompt_version_no_auto_cleanup(tracking_uri):
    """Deleting a prompt's only version must not auto-delete the prompt itself."""
    client = MlflowClient(tracking_uri=tracking_uri)

    # Create prompt and version
    client.register_prompt(name="test_prompt", template="Hello {{name}}!")

    # Verify prompt and version exist
    prompt = client.get_prompt("test_prompt")
    assert prompt is not None
    assert prompt.name == "test_prompt"

    prompt_version = client.get_prompt_version("test_prompt", 1)
    assert prompt_version is not None
    assert prompt_version.version == 1

    # Delete the version - prompt should remain
    client.delete_prompt_version("test_prompt", "1")

    # Prompt should still exist even though it has no versions
    prompt = client.get_prompt("test_prompt")
    assert prompt is not None
    assert prompt.name == "test_prompt"

    # Version should be gone
    with pytest.raises(MlflowException, match=r"Prompt.*name=test_prompt.*version=1.*not found"):
        client.get_prompt_version("test_prompt", 1)
2488  
2489  
def test_delete_prompt_version_invalidates_cached_load_prompt(tracking_uri):
    # Deleting a version must evict it from the load_prompt cache so later
    # lookups fail instead of serving stale data.
    client = MlflowClient(tracking_uri=tracking_uri)

    registered = client.register_prompt(name="test_prompt", template="Version 1")
    cached = client.load_prompt(registered.name, version=registered.version)
    assert cached.template == "Version 1"

    client.delete_prompt_version(registered.name, str(registered.version))

    not_found = rf"Prompt.*name={registered.name}.*version={registered.version}.*not found"
    with pytest.raises(MlflowException, match=not_found):
        client.get_prompt_version(registered.name, registered.version)

    with pytest.raises(MlflowException, match=not_found):
        client.load_prompt(registered.name, version=registered.version)
2510  
2511  
def test_delete_prompt_version_invalidates_latest_cache(tracking_uri):
    # After the newest version is deleted, "@latest" must resolve to the
    # previous version rather than a cached copy of the deleted one.
    client = MlflowClient(tracking_uri=tracking_uri)

    v1 = client.register_prompt(name="test_prompt", template="Version 1")
    v2 = client.register_prompt(name=v1.name, template="Version 2")
    latest_uri = f"prompts:/{v1.name}@latest"

    before_delete = client.load_prompt(latest_uri)
    assert (before_delete.version, before_delete.template) == (v2.version, v2.template)

    client.delete_prompt_version(v2.name, str(v2.version))

    after_delete = client.load_prompt(latest_uri)
    assert (after_delete.version, after_delete.template) == (v1.version, v1.template)
2527  
2528  
def test_set_prompt_model_config_invalidates_latest_cache(tracking_uri):
    # Setting a model config must invalidate a still-valid cache entry so the
    # next load observes the new config within the TTL window.
    client = MlflowClient(tracking_uri=tracking_uri)
    ttl = 60

    prompt = client.register_prompt(name="test_prompt", template="test")
    assert client.load_prompt(prompt.name, cache_ttl_seconds=ttl).model_config is None

    new_config = {"model_name": "gpt-4", "temperature": 0.7}
    mlflow.genai.set_prompt_model_config(
        name=prompt.name,
        version=prompt.version,
        model_config=new_config,
    )

    assert client.load_prompt(prompt.name, cache_ttl_seconds=ttl).model_config == new_config
2546  
2547  
def test_delete_prompt_model_config_invalidates_latest_cache(tracking_uri):
    # Deleting a model config must invalidate a still-valid cache entry so the
    # next load observes its removal within the TTL window.
    client = MlflowClient(tracking_uri=tracking_uri)
    ttl = 60

    initial_config = {"model_name": "gpt-4", "temperature": 0.7}
    prompt = client.register_prompt(
        name="test_prompt",
        template="test",
        model_config=initial_config,
    )
    assert client.load_prompt(prompt.name, cache_ttl_seconds=ttl).model_config == initial_config

    mlflow.genai.delete_prompt_model_config(name=prompt.name, version=prompt.version)

    assert client.load_prompt(prompt.name, cache_ttl_seconds=ttl).model_config is None
2565  
2566  
def test_delete_prompt_version_invalidates_alias_cache(tracking_uri):
    """Deleting an aliased version must evict it from the alias-resolution cache."""
    client = MlflowClient(tracking_uri=tracking_uri)

    prompt_v1 = client.register_prompt(name="test_prompt", template="Version 1")
    client.register_prompt(name=prompt_v1.name, template="Version 2")
    client.set_prompt_alias(prompt_v1.name, alias="production", version=prompt_v1.version)

    # The alias URI resolves (and caches) version 1.
    aliased_prompt = client.load_prompt(f"prompts:/{prompt_v1.name}@production")
    assert aliased_prompt.version == prompt_v1.version
    assert aliased_prompt.template == prompt_v1.template

    client.delete_prompt_version(prompt_v1.name, str(prompt_v1.version))

    # The exact error text varies by registry backend, so accept any of the
    # three known "not found" phrasings.
    with pytest.raises(
        MlflowException,
        match=(
            r"Prompt (.*) does not exist.|Prompt alias (.*) not found.|"
            rf"Prompt.*version={prompt_v1.version}.*not found"
        ),
    ):
        client.load_prompt(f"prompts:/{prompt_v1.name}@production")
2588  
2589  
def test_delete_prompt_with_no_versions(tracking_uri):
    # A prompt whose only version was deleted still exists as an entity and
    # can then be deleted itself.
    client = MlflowClient(tracking_uri=tracking_uri)
    mlflow.set_experiment("test_delete_prompt_with_no_versions")

    client.register_prompt(name="empty_prompt", template="Hello {{name}}!")
    client.delete_prompt_version("empty_prompt", "1")

    # The version-less prompt entity remains until explicitly deleted...
    assert client.get_prompt("empty_prompt") is not None

    # ...and deleting it works regardless of registry type.
    client.delete_prompt("empty_prompt")
    assert client.get_prompt("empty_prompt") is None
2608  
2609  
def test_delete_prompt_invalidates_cached_load_prompt(tracking_uri):
    # Deleting the whole prompt must also evict its versions from the
    # load_prompt cache.
    client = MlflowClient(tracking_uri=tracking_uri)

    registered = client.register_prompt(name="test_prompt", template="Version 1")
    cached = client.load_prompt(registered.name, version=registered.version)
    assert cached.template == "Version 1"

    client.delete_prompt(registered.name)
    assert client.get_prompt(registered.name) is None

    with pytest.raises(MlflowException, match=rf"Prompt.*name={registered.name}.*not found"):
        client.load_prompt(registered.name, version=registered.version)
2623  
2624  
def test_delete_prompt_complete_workflow(tracking_uri):
    # Full lifecycle: register three versions, delete each version in order,
    # then delete the (now version-less) prompt itself.
    client = MlflowClient(tracking_uri=tracking_uri)
    name = "workflow_prompt"

    for v in (1, 2, 3):
        client.register_prompt(name=name, template=f"Version {v}: {{{{name}}}}")

    # Each version round-trips with its own template.
    for v in (1, 2, 3):
        assert client.get_prompt_version(name, v).template == f"Version {v}: {{{{name}}}}"

    # Delete versions one by one.
    for v in (1, 2, 3):
        client.delete_prompt_version(name, str(v))

    # The prompt entity survives with no versions...
    assert client.get_prompt(name) is not None

    # ...until the prompt itself is deleted.
    client.delete_prompt(name)
    assert client.get_prompt(name) is None
2656  
2657  
def test_delete_prompt_error_handling(tracking_uri):
    # Deleting a missing prompt, or a missing version of an existing prompt,
    # surfaces a clear "not found" error.
    client = MlflowClient(tracking_uri=tracking_uri)

    with pytest.raises(MlflowException, match=r"Prompt with name=nonexistent not found"):
        client.delete_prompt("nonexistent")

    client.register_prompt(name="test_errors", template="Hello {{name}}!")
    with pytest.raises(MlflowException, match=r"Prompt.*name=test_errors.*version=999.*not found"):
        client.delete_prompt_version("test_errors", "999")
2669  
2670  
def test_delete_prompt_version_behavior_consistency(tracking_uri):
    """Version deletion behaves identically across multiple independent prompts."""
    client = MlflowClient(tracking_uri=tracking_uri)

    # Create multiple prompts with versions
    for i in range(3):
        prompt_name = f"consistency_test_{i}"
        # The quadruple braces render a literal "{{name}}" template variable.
        client.register_prompt(name=prompt_name, template=f"Template {i}: {{{{name}}}}")

        # Delete the version immediately
        client.delete_prompt_version(prompt_name, "1")

        # Prompt should remain but have no versions
        prompt = client.get_prompt(prompt_name)
        assert prompt is not None
        assert prompt.name == prompt_name

        # Version should be gone
        with pytest.raises(MlflowException, match=r"Prompt.*version.*not found"):
            client.get_prompt_version(prompt_name, 1)

    # Clean up - delete all prompts
    for i in range(3):
        client.delete_prompt(f"consistency_test_{i}")
        prompt = client.get_prompt(f"consistency_test_{i}")
        assert prompt is None
2696  
2697  
@pytest.mark.parametrize("registry_uri", ["databricks-uc"])
def test_delete_prompt_with_versions_unity_catalog_error(registry_uri):
    """Unity Catalog registries refuse to delete a prompt that still has versions."""
    # Mock Unity Catalog behavior
    client = MlflowClient(registry_uri=registry_uri)

    # Mock the search_prompt_versions to return a PagedList with versions
    mock_versions = PagedList([Mock(version="1")], None)

    # Patch both the version listing and the private registry URI so the
    # client takes the UC code path without any network access.
    with (
        patch.object(client, "search_prompt_versions", return_value=mock_versions),
        patch.object(client, "_registry_uri", registry_uri),
    ):
        with pytest.raises(
            MlflowException, match=r"Cannot delete prompt .* because it still has undeleted"
        ):
            client.delete_prompt("test_prompt")
2714  
2715  
def test_link_prompt_version_to_model_smoke_test(tracking_uri):
    # Smoke test: linking a registered prompt version to a logged model
    # completes without raising.
    client = MlflowClient(tracking_uri=tracking_uri)

    experiment_id = client.create_experiment("test_experiment")
    with mlflow.start_run(experiment_id=experiment_id):
        logged_model = client.create_logged_model(experiment_id=experiment_id)
        client.register_prompt(name="test_prompt", template="Hello, {{name}}!")

        # The call itself succeeding is the assertion of this smoke test.
        client.link_prompt_version_to_model(
            name="test_prompt", version="1", model_id=logged_model.model_id
        )
2733  
2734  
def test_link_prompts_to_trace_smoke_test(tracking_uri):
    # Smoke test: linking prompt versions to a trace completes without raising.
    client = MlflowClient(tracking_uri=tracking_uri)

    experiment_id = client.create_experiment("test_experiment")
    with mlflow.start_run(experiment_id=experiment_id):
        trace_id = client.start_trace("test_trace").request_id
        client.register_prompt(name="test_prompt", template="Hello, {{name}}!")

        # The call itself succeeding is the assertion of this smoke test.
        version = client.get_prompt_version("test_prompt", "1")
        client.link_prompt_versions_to_trace(prompt_versions=[version], trace_id=trace_id)
2752  
2753  
def test_log_model_artifact(tmp_path: Path, tracking_uri: str) -> None:
    # Individual files logged to a model are listed back with correct sizes.
    client = MlflowClient(tracking_uri=tracking_uri)
    experiment_id = client.create_experiment("test")
    model = client.create_logged_model(experiment_id=experiment_id)

    artifact_dir = tmp_path / "artifacts"
    artifact_dir.mkdir()
    first_file = artifact_dir / "file"
    first_file.write_text("a")
    client.log_model_artifact(model_id=model.model_id, local_path=str(first_file))
    assert client.list_logged_model_artifacts(model_id=model.model_id) == [
        FileInfo(path="file", is_dir=False, file_size=1)
    ]

    second_file = artifact_dir / "another_file"
    second_file.write_text("aa")
    client.log_model_artifact(model_id=model.model_id, local_path=str(second_file))
    listed = sorted(
        client.list_logged_model_artifacts(model_id=model.model_id), key=lambda f: f.path
    )
    assert listed == [
        FileInfo(path="another_file", is_dir=False, file_size=2),
        FileInfo(path="file", is_dir=False, file_size=1),
    ]
2774  
2775  
def test_log_model_artifacts(tmp_path: Path, tracking_uri: str) -> None:
    # A directory tree logged to a model is listable at the top level (with
    # the subdirectory reported as is_dir) and within the subdirectory.
    client = MlflowClient(tracking_uri=tracking_uri)
    experiment_id = client.create_experiment("test")
    model = client.create_logged_model(experiment_id=experiment_id)

    artifact_dir = tmp_path / "artifacts"
    artifact_dir.mkdir()
    (artifact_dir / "file").write_text("a")
    nested_dir = artifact_dir / "dir"
    nested_dir.mkdir()
    (nested_dir / "another_file").write_text("aa")

    client.log_model_artifacts(model_id=model.model_id, local_dir=str(artifact_dir))

    top_level = sorted(
        client.list_logged_model_artifacts(model_id=model.model_id), key=lambda f: f.path
    )
    assert top_level == [
        FileInfo(path="dir", is_dir=True, file_size=None),
        FileInfo(path="file", is_dir=False, file_size=1),
    ]
    assert client.list_logged_model_artifacts(model_id=model.model_id, path="dir") == [
        FileInfo(path="dir/another_file", is_dir=False, file_size=2)
    ]
2797  
2798  
def test_logged_model_model_id_required(tracking_uri):
    # Every logged-model API rejects an empty model_id with the same error.
    client = MlflowClient(tracking_uri=tracking_uri)
    empty_id_error = "`model_id` must be a non-empty string, but got ''"

    calls_with_empty_id = [
        lambda: client.finalize_logged_model("", LoggedModelStatus.READY),
        lambda: client.get_logged_model(""),
        lambda: client.delete_logged_model(""),
        lambda: client.set_logged_model_tags("", {}),
        lambda: client.delete_logged_model_tag("", ""),
    ]
    for call in calls_with_empty_id:
        with pytest.raises(MlflowException, match=empty_id_error):
            call()
2816  
2817  
@pytest.mark.skipif(
    "MLFLOW_SKINNY" in os.environ,
    reason="Skinny client does not support the np or pandas dependencies",
)
def test_log_metric_link_to_active_model(tracking_uri):
    # A metric logged to a run while a model is active is linked to that model.
    external_model = mlflow.create_external_model(name="test_model")
    mlflow.set_active_model(name=external_model.name)
    client = MlflowClient(tracking_uri=tracking_uri)

    with mlflow.start_run() as run:
        client.log_metric(run.info.run_id, "metric", 1)

    fetched = mlflow.get_logged_model(model_id=external_model.model_id)
    assert fetched.name == external_model.name
    assert fetched.model_id == external_model.model_id
    assert (fetched.metrics[0].key, fetched.metrics[0].value) == ("metric", 1)
2833  
2834  
@pytest.mark.skipif(
    "MLFLOW_SKINNY" in os.environ,
    reason="Skinny client does not support the np or pandas dependencies",
)
def test_log_batch_link_to_active_model(tracking_uri):
    # Batched metrics logged while a model is active are linked to that model.
    external_model = mlflow.create_external_model(name="test_model")
    mlflow.set_active_model(name=external_model.name)
    client = MlflowClient(tracking_uri=tracking_uri)

    with mlflow.start_run() as run:
        batch = [Metric("metric1", 1, 0, 0), Metric("metric2", 2, 0, 0)]
        client.log_batch(run.info.run_id, batch)

    fetched = mlflow.get_logged_model(model_id=external_model.model_id)
    assert fetched.name == external_model.name
    assert fetched.model_id == external_model.model_id
    assert {m.key: m.value for m in fetched.metrics} == {"metric1": 1, "metric2": 2}
2852  
2853  
def test_load_prompt_with_alias_uri(tracking_uri, disable_prompt_cache):
    """Alias URIs resolve, track reassignment, fail after deletion, and
    support the built-in 'latest' alias."""
    client = MlflowClient(tracking_uri=tracking_uri)

    # Register two versions of a prompt
    client.register_prompt(name="alias_prompt", template="Hello, world!")
    client.register_prompt(name="alias_prompt", template="Hello, {{name}}!")

    # Assign alias to version 1
    client.set_prompt_alias("alias_prompt", alias="production", version=1)
    prompt = client.load_prompt("prompts:/alias_prompt@production")
    assert prompt.template == "Hello, world!"
    assert "production" in prompt.aliases

    # Reassign alias to version 2
    client.set_prompt_alias("alias_prompt", alias="production", version=2)
    prompt = client.load_prompt("prompts:/alias_prompt@production")
    assert prompt.template == "Hello, {{name}}!"
    assert "production" in prompt.aliases

    # Delete alias and verify loading fails
    # (error text varies by registry backend, hence the alternation)
    client.delete_prompt_alias("alias_prompt", alias="production")
    with pytest.raises(
        MlflowException, match=r"Prompt (.*) does not exist.|Prompt alias (.*) not found."
    ):
        client.load_prompt("prompts:/alias_prompt@production")

    # Loading with the 'latest' alias
    prompt = client.load_prompt("prompts:/alias_prompt@latest")
    assert prompt.template == "Hello, {{name}}!"
2883  
2884  
def test_load_prompt_allow_missing_name_version(tracking_uri):
    # allow_missing=True turns "not found" into a None result for name+version
    # lookups; allow_missing=False raises.
    client = MlflowClient(tracking_uri=tracking_uri)

    assert client.load_prompt("nonexistent_prompt", version=1, allow_missing=True) is None
    with pytest.raises(MlflowException, match="Prompt with name=nonexistent_prompt not found"):
        client.load_prompt("nonexistent_prompt", version=1, allow_missing=False)

    # Same contract when the prompt exists but the version does not.
    client.register_prompt(name="existing_prompt", template="Hello, world!")
    assert client.load_prompt("existing_prompt", version=999, allow_missing=True) is None
    with pytest.raises(
        MlflowException, match=r"Prompt \(name=existing_prompt, version=999\) not found"
    ):
        client.load_prompt("existing_prompt", version=999, allow_missing=False)
2906  
2907  
def test_load_prompt_allow_missing_uri_version(tracking_uri):
    # Same allow_missing contract as name+version lookup, exercised through
    # prompts:/ version URIs.
    client = MlflowClient(tracking_uri=tracking_uri)

    assert client.load_prompt("prompts:/nonexistent_prompt/1", allow_missing=True) is None
    with pytest.raises(MlflowException, match="Prompt with name=nonexistent_prompt not found"):
        client.load_prompt("prompts:/nonexistent_prompt/1", allow_missing=False)

    # Same contract when the prompt exists but the version does not.
    client.register_prompt(name="existing_prompt", template="Hello, world!")
    assert client.load_prompt("prompts:/existing_prompt/999", allow_missing=True) is None
    with pytest.raises(
        MlflowException, match=r"Prompt \(name=existing_prompt, version=999\) not found"
    ):
        client.load_prompt("prompts:/existing_prompt/999", allow_missing=False)
2929  
2930  
def test_load_prompt_allow_missing_uri_alias(tracking_uri):
    """allow_missing also covers prompts:/name@alias URI lookups."""
    client = MlflowClient(tracking_uri=tracking_uri)

    # Unknown prompt addressed by alias: None when misses are allowed.
    assert client.load_prompt("prompts:/nonexistent_prompt@production", allow_missing=True) is None

    # The same alias URI must raise when allow_missing=False.
    with pytest.raises(MlflowException, match="Prompt with name=nonexistent_prompt not found"):
        client.load_prompt("prompts:/nonexistent_prompt@production", allow_missing=False)

    # Known prompt, unknown alias: None when misses are allowed...
    client.register_prompt(name="existing_prompt", template="Hello, world!")
    assert client.load_prompt("prompts:/existing_prompt@nonexistent_alias", allow_missing=True) is None

    # ...and an alias-specific error otherwise.
    with pytest.raises(MlflowException, match="Prompt alias nonexistent_alias not found"):
        client.load_prompt("prompts:/existing_prompt@nonexistent_alias", allow_missing=False)
2950  
2951  
def test_create_prompt_chat_format_client_integration():
    """A chat-style (list-of-messages) prompt round-trips through the client."""
    messages = [
        {"role": "system", "content": "You are a {{style}} assistant."},
        {"role": "user", "content": "{{question}}"},
    ]
    expected_format = {"type": "string"}

    client = MlflowClient()
    registered = client.register_prompt(
        name="test_chat_client",
        template=messages,
        response_format=expected_format,
        commit_message="Test chat prompt via client",
    )

    assert registered.template == messages
    assert registered.response_format == expected_format

    # Fetch it back and confirm it is recognized as a chat (non-text) prompt.
    fetched = client.get_prompt_version("test_chat_client", 1)
    assert not fetched.is_text_prompt
    assert fetched.template == messages
    assert fetched.response_format == expected_format
2977  
2978  
def test_link_chat_prompt_version_to_run():
    """Linking a chat prompt to a run records it in the linked-prompts run tag."""
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello {{name}}!"},
    ]

    client = MlflowClient()
    prompt_version = client.register_prompt(name="test_chat_link", template=messages)

    # Link the registered prompt version to a fresh run.
    run = client.create_run(client.create_experiment("test_exp"))
    client.link_prompt_version_to_run(run.info.run_id, prompt_version)

    # The link is persisted as a JSON list in the run's tags.
    tag_value = client.get_run(run.info.run_id).data.tags.get(TraceTagKey.LINKED_PROMPTS)
    assert tag_value is not None

    entries = json.loads(tag_value)
    assert len(entries) == 1
    assert entries[0]["name"] == "test_chat_link"
    assert entries[0]["version"] == "1"
3001  
3002  
def test_create_prompt_with_pydantic_response_format_client():
    """A Pydantic model passed as response_format is stored as its JSON schema."""

    class ResponseSchema(BaseModel):
        answer: str
        confidence: float

    client = MlflowClient()
    registered = client.register_prompt(
        name="test_pydantic_client",
        template="What is {{question}}?",
        response_format=ResponseSchema,
        commit_message="Test Pydantic response format via client",
    )

    expected_schema = ResponseSchema.model_json_schema()
    assert registered.response_format == expected_schema
    assert registered.commit_message == "Test Pydantic response format via client"

    # The schema survives a round trip through the registry.
    assert client.get_prompt_version("test_pydantic_client", 1).response_format == expected_schema
3022  
3023  
def test_create_prompt_with_dict_response_format_client():
    """A plain JSON-schema dict is accepted as response_format and persisted."""
    schema = {
        "type": "object",
        "properties": {
            "summary": {"type": "string"},
            "key_points": {"type": "array", "items": {"type": "string"}},
        },
    }

    client = MlflowClient()
    registered = client.register_prompt(
        name="test_dict_response_client",
        template="Analyze this: {{text}}",
        response_format=schema,
        tags={"analysis_type": "text"},
    )

    assert registered.response_format == schema
    assert registered.tags["analysis_type"] == "text"

    # The dict schema survives a round trip through the registry.
    assert client.get_prompt_version("test_dict_response_client", 1).response_format == schema
3047  
3048  
def test_create_prompt_text_backward_compatibility_client():
    """Plain string templates keep working and are flagged as text prompts."""
    client = MlflowClient()
    registered = client.register_prompt(
        name="test_text_backward_client",
        template="Hello {{name}}!",
        commit_message="Test backward compatibility via client",
    )

    assert registered.is_text_prompt
    assert registered.template == "Hello {{name}}!"
    assert registered.commit_message == "Test backward compatibility via client"

    # Reload and confirm the text-prompt flag and template persist.
    reloaded = client.get_prompt_version("test_text_backward_client", 1)
    assert reloaded.is_text_prompt
    assert reloaded.template == "Hello {{name}}!"
3065  
3066  
def test_create_prompt_complex_chat_template_client():
    """A multi-turn chat template with several placeholders round-trips intact."""
    conversation = [
        {
            "role": "system",
            "content": "You are a {{style}} assistant named {{name}}.",
        },
        {"role": "user", "content": "{{greeting}}! {{question}}"},
        {
            "role": "assistant",
            "content": "I understand you're asking about {{topic}}.",
        },
    ]

    client = MlflowClient()
    registered = client.register_prompt(
        name="test_complex_chat_client",
        template=conversation,
        tags={"complexity": "high"},
    )

    assert registered.template == conversation
    assert registered.tags["complexity"] == "high"

    # Reload and confirm it remains a chat (non-text) prompt.
    reloaded = client.get_prompt_version("test_complex_chat_client", 1)
    assert not reloaded.is_text_prompt
    assert reloaded.template == conversation
3094  
3095  
def test_create_prompt_with_none_response_format_client():
    """An explicit response_format=None is stored and loaded back as None."""
    client = MlflowClient()
    registered = client.register_prompt(
        name="test_none_response_client",
        template="Hello {{name}}!",
        response_format=None,
    )

    assert registered.response_format is None

    # The absent format is preserved on reload as well.
    assert client.get_prompt_version("test_none_response_client", 1).response_format is None
3109  
3110  
def test_create_prompt_with_single_message_chat_client():
    """Even a one-message list is treated as a chat prompt, not text."""
    messages = [{"role": "user", "content": "Hello {{name}}!"}]

    client = MlflowClient()
    registered = client.register_prompt(name="test_single_message_client", template=messages)

    assert registered.template == messages
    assert registered.variables == {"name"}

    # Reload and confirm the chat classification and template persist.
    reloaded = client.get_prompt_version("test_single_message_client", 1)
    assert not reloaded.is_text_prompt
    assert reloaded.template == messages
3124  
3125  
def test_create_prompt_with_multiple_variables_in_chat_client():
    """Placeholders from every message role are collected into `variables`."""
    conversation = [
        {
            "role": "system",
            "content": "You are a {{style}} assistant named {{name}}.",
        },
        {"role": "user", "content": "{{greeting}}! {{question}}"},
        {
            "role": "assistant",
            "content": "I understand you're asking about {{topic}}.",
        },
    ]

    client = MlflowClient()
    registered = client.register_prompt(name="test_multiple_variables_client", template=conversation)

    expected_variables = {"style", "name", "greeting", "question", "topic"}
    assert registered.variables == expected_variables

    # The variable set is recomputed identically when loading the version back.
    assert client.get_prompt_version("test_multiple_variables_client", 1).variables == expected_variables
3148  
3149  
def test_create_prompt_with_mixed_content_types_client():
    """Messages with and without placeholders coexist; only real placeholders count."""
    conversation = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello {{name}}!"},
        {"role": "assistant", "content": "Hi there! How can I help you today?"},
    ]

    client = MlflowClient()
    registered = client.register_prompt(name="test_mixed_content_client", template=conversation)

    assert registered.template == conversation
    assert registered.variables == {"name"}

    # Reload and confirm the chat classification and template persist.
    reloaded = client.get_prompt_version("test_mixed_content_client", 1)
    assert not reloaded.is_text_prompt
    assert reloaded.template == conversation
3167  
3168  
def test_create_prompt_with_nested_variables_client():
    """Dotted placeholder names ({{a.b.c}}) are kept verbatim in `variables`."""
    conversation = [
        {
            "role": "system",
            "content": "You are a {{user.preferences.style}} assistant.",
        },
        {
            "role": "user",
            "content": "Hello {{user.name}}! {{user.preferences.greeting}}",
        },
    ]

    client = MlflowClient()
    registered = client.register_prompt(name="test_nested_variables_client", template=conversation)

    expected_variables = {
        "user.preferences.style",
        "user.name",
        "user.preferences.greeting",
    }
    assert registered.variables == expected_variables

    # The dotted names survive a reload unchanged.
    assert client.get_prompt_version("test_nested_variables_client", 1).variables == expected_variables
3194  
3195  
def test_link_prompt_with_response_format_to_run():
    """Prompts carrying a response_format link to runs like any other prompt."""
    schema = {
        "type": "object",
        "properties": {"answer": {"type": "string"}},
    }
    client = MlflowClient()
    prompt_version = client.register_prompt(
        name="test_response_link",
        template="What is {{question}}?",
        response_format=schema,
    )

    # Link the registered prompt version to a fresh run.
    run = client.create_run(client.create_experiment("test_exp"))
    client.link_prompt_version_to_run(run.info.run_id, prompt_version)

    # The link is persisted as a JSON list in the run's tags.
    tag_value = client.get_run(run.info.run_id).data.tags.get(TraceTagKey.LINKED_PROMPTS)
    assert tag_value is not None

    entries = json.loads(tag_value)
    assert len(entries) == 1
    assert entries[0]["name"] == "test_response_link"
    assert entries[0]["version"] == "1"
3221  
3222  
def test_link_multiple_prompt_types_to_run():
    """Text and chat prompts can both be linked to the same run."""
    client = MlflowClient()

    # One text prompt and one chat prompt.
    text_prompt = client.register_prompt(name="test_text_link", template="Hello {{name}}!")
    conversation = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "{{question}}"},
    ]
    chat_prompt = client.register_prompt(name="test_chat_link_multiple", template=conversation)

    # Link both to a fresh run.
    run = client.create_run(client.create_experiment("test_exp"))
    client.link_prompt_version_to_run(run.info.run_id, text_prompt)
    client.link_prompt_version_to_run(run.info.run_id, chat_prompt)

    # Both links must appear in the run's linked-prompts tag.
    tag_value = client.get_run(run.info.run_id).data.tags.get(TraceTagKey.LINKED_PROMPTS)
    assert tag_value is not None

    entries = json.loads(tag_value)
    assert len(entries) == 2
    assert {"name": "test_text_link", "version": "1"} in entries
    assert {"name": "test_chat_link_multiple", "version": "1"} in entries
3255  
3256  
def test_mlflow_client_create_dataset(mock_store):
    """create_dataset forwards name/tags/experiment ids to the store and returns its entity."""
    stub_dataset = EvaluationDataset(
        dataset_id="test_dataset_id",
        name="test_dataset",
        digest="abcdef123456",
        created_time=1234567890,
        last_update_time=1234567890,
        tags={"environment": "production", "version": "1.0"},
    )
    stub_dataset.experiment_ids = ["exp1", "exp2"]
    mock_store.create_dataset.return_value = stub_dataset

    # Suppress context-resolved tags (e.g. mlflow.user) so only explicit tags are sent.
    with mock.patch(
        "mlflow.tracking._tracking_service.client.context_registry.resolve_tags", return_value={}
    ):
        result = MlflowClient().create_dataset(
            name="qa_evaluation",
            experiment_id=["exp1", "exp2"],
            tags={"environment": "production", "version": "1.0"},
        )

    # The store's entity is returned as-is.
    assert result.dataset_id == "test_dataset_id"
    assert result.name == "test_dataset"
    assert result.tags == {"environment": "production", "version": "1.0"}

    mock_store.create_dataset.assert_called_once_with(
        name="qa_evaluation",
        tags={"environment": "production", "version": "1.0"},
        experiment_ids=["exp1", "exp2"],
    )
3288  
3289  
def test_mlflow_client_create_evaluation_dataset_minimal(mock_store):
    """Calling create_dataset with only a name sends None for tags and experiments."""
    stub_dataset = EvaluationDataset(
        dataset_id="test_dataset_id",
        name="test_dataset",
        digest="abcdef123456",
        created_time=1234567890,
        last_update_time=1234567890,
    )
    mock_store.create_dataset.return_value = stub_dataset

    # Suppress context-resolved tags (e.g. mlflow.user) so tags stay None.
    with mock.patch(
        "mlflow.tracking._tracking_service.client.context_registry.resolve_tags", return_value={}
    ):
        result = MlflowClient().create_dataset(name="test_dataset")

    assert result.dataset_id == "test_dataset_id"
    assert result.name == "test_dataset"

    mock_store.create_dataset.assert_called_once_with(
        name="test_dataset",
        tags=None,
        experiment_ids=None,
    )
3314  
3315  
def test_mlflow_client_get_dataset(mock_store):
    """get_dataset passes the id through and returns the store's entity."""
    mock_store.get_dataset.return_value = EvaluationDataset(
        dataset_id="dataset_123",
        name="test_dataset",
        digest="abcdef123456",
        created_time=1234567890,
        last_update_time=1234567890,
        tags={"source": "human-annotated"},
    )

    fetched = MlflowClient().get_dataset("dataset_123")

    assert fetched.dataset_id == "dataset_123"
    assert fetched.name == "test_dataset"
    assert fetched.tags == {"source": "human-annotated"}

    mock_store.get_dataset.assert_called_once_with("dataset_123")
3333  
3334  
def test_mlflow_client_delete_dataset(mock_store):
    """delete_dataset is a thin pass-through to the store."""
    client = MlflowClient()
    client.delete_dataset("dataset_123")

    mock_store.delete_dataset.assert_called_once_with("dataset_123")
3339  
3340  
def test_mlflow_client_search_datasets(mock_store):
    """search_datasets forwards all filters and surfaces the store's paged results."""
    mock_store.search_datasets.return_value = PagedList(
        [
            EvaluationDataset(
                dataset_id="dataset_1",
                name="dataset_1",
                digest="digest1",
                created_time=1234567890,
                last_update_time=1234567890,
            ),
            EvaluationDataset(
                dataset_id="dataset_2",
                name="dataset_2",
                digest="digest2",
                created_time=1234567890,
                last_update_time=1234567890,
            ),
        ],
        "next_token",
    )

    page = MlflowClient().search_datasets(
        experiment_ids=["exp1", "exp2"],
        filter_string="name LIKE 'qa_%'",
        max_results=100,
        order_by=["created_time DESC"],
        page_token="page_token_123",
    )

    # Entities and the continuation token come straight from the store.
    assert len(page) == 2
    assert page[0].dataset_id == "dataset_1"
    assert page[1].dataset_id == "dataset_2"
    assert page.token == "next_token"

    mock_store.search_datasets.assert_called_once_with(
        experiment_ids=["exp1", "exp2"],
        filter_string="name LIKE 'qa_%'",
        max_results=100,
        order_by=["created_time DESC"],
        page_token="page_token_123",
    )
3382  
3383  
def test_mlflow_client_search_datasets_empty_results(mock_store):
    """An empty store result yields an empty page with no continuation token."""
    mock_store.search_datasets.return_value = PagedList([], None)

    page = MlflowClient().search_datasets(
        experiment_ids=["exp1"], filter_string="name = 'nonexistent'"
    )

    assert len(page) == 0
    assert page.token is None
3393  
3394  
def test_mlflow_client_search_datasets_defaults(mock_store):
    """With no arguments, search_datasets sends None filters and the default page size."""
    mock_store.search_datasets.return_value = PagedList([], None)

    page = MlflowClient().search_datasets()

    assert len(page) == 0
    assert page.token is None

    # Only max_results has a non-None default.
    mock_store.search_datasets.assert_called_once_with(
        experiment_ids=None,
        filter_string=None,
        max_results=SEARCH_EVALUATION_DATASETS_MAX_RESULTS,
        order_by=None,
        page_token=None,
    )
3410  
3411  
@pytest.mark.skipif(is_windows(), reason="FileStore URI handling issues on Windows")
@pytest.mark.skip(reason="FileStore is no longer supported.")
def test_mlflow_client_datasets_filestore_not_supported(tmp_path):
    """Every dataset API must fail with FEATURE_DISABLED on a FileStore backend.

    Skipped declaratively at collection time (FileStore is no longer supported);
    previously an inline ``pytest.skip()`` made the remainder of the body
    unreachable dead code. The body is kept for reference should FileStore
    support ever be revisited.
    """
    client = MlflowClient(tracking_uri=str(tmp_path))

    # Every dataset operation should be rejected uniformly by the FileStore.
    dataset_operations = [
        lambda: client.create_dataset(name="test_dataset"),
        lambda: client.get_dataset("dataset_123"),
        lambda: client.delete_dataset("dataset_123"),
        lambda: client.search_datasets(),
        lambda: client.set_dataset_tags("dataset_123", {"tag1": "value1"}),
        lambda: client.delete_dataset_tag("dataset_123", "tag1"),
        lambda: client.add_dataset_to_experiments("dataset_123", ["1", "2"]),
        lambda: client.remove_dataset_from_experiments("dataset_123", ["1", "2"]),
    ]
    for operation in dataset_operations:
        with pytest.raises(MlflowException, match="is not supported with FileStore") as exc_info:
            operation()
        assert exc_info.value.error_code == "FEATURE_DISABLED"
3449  
3450  
def test_mlflow_client_set_dataset_tags(mock_store):
    """set_dataset_tags forwards the id and tag dict unchanged to the store."""
    client = MlflowClient()
    client.set_dataset_tags(
        dataset_id="dataset_123",
        tags={"env": "prod", "version": "2.0"},
    )

    mock_store.set_dataset_tags.assert_called_once_with(
        dataset_id="dataset_123",
        tags={"env": "prod", "version": "2.0"},
    )
3461  
3462  
def test_mlflow_client_delete_dataset_tag(mock_store):
    """delete_dataset_tag forwards the id and key unchanged to the store."""
    client = MlflowClient()
    client.delete_dataset_tag(
        dataset_id="dataset_123",
        key="deprecated",
    )

    mock_store.delete_dataset_tag.assert_called_once_with(
        dataset_id="dataset_123",
        key="deprecated",
    )
3473  
3474  
def test_mlflow_client_add_dataset_to_experiments(mock_store):
    """add_dataset_to_experiments passes args positionally and returns the updated entity."""
    updated_dataset = Mock(spec=EvaluationDataset)
    updated_dataset.dataset_id = "dataset_123"
    updated_dataset.experiment_ids = ["1", "2", "3"]
    mock_store.add_dataset_to_experiments.return_value = updated_dataset

    result = MlflowClient().add_dataset_to_experiments(
        dataset_id="dataset_123",
        experiment_ids=["2", "3"],
    )

    assert result == updated_dataset
    assert result.experiment_ids == ["1", "2", "3"]
    # The client forwards positionally, not by keyword.
    mock_store.add_dataset_to_experiments.assert_called_once_with("dataset_123", ["2", "3"])
3490  
3491  
def test_mlflow_client_remove_dataset_from_experiments(mock_store):
    """remove_dataset_from_experiments passes args positionally and returns the updated entity."""
    updated_dataset = Mock(spec=EvaluationDataset)
    updated_dataset.dataset_id = "dataset_123"
    updated_dataset.experiment_ids = ["1"]
    mock_store.remove_dataset_from_experiments.return_value = updated_dataset

    result = MlflowClient().remove_dataset_from_experiments(
        dataset_id="dataset_123",
        experiment_ids=["2", "3"],
    )

    assert result == updated_dataset
    assert result.experiment_ids == ["1"]
    # The client forwards positionally, not by keyword.
    mock_store.remove_dataset_from_experiments.assert_called_once_with("dataset_123", ["2", "3"])
3507  
3508  
def test_mlflow_client_dataset_associations_databricks_blocking(mock_store):
    """Dataset/experiment association APIs are rejected against a Databricks URI."""
    with mock.patch("mlflow.utils.databricks_utils.is_databricks_uri") as mock_is_databricks:
        mock_is_databricks.return_value = True
        client = MlflowClient(tracking_uri="databricks")

        # Adding experiments is blocked with a parameter-validation error.
        with pytest.raises(
            MlflowException, match="not supported when tracking URI is 'databricks'"
        ) as exc_info:
            client.add_dataset_to_experiments("dataset_123", ["1", "2"])
        assert exc_info.value.error_code == "INVALID_PARAMETER_VALUE"

        # Removing experiments is blocked the same way.
        with pytest.raises(
            MlflowException, match="not supported when tracking URI is 'databricks'"
        ) as exc_info:
            client.remove_dataset_from_experiments("dataset_123", ["1", "2"])
        assert exc_info.value.error_code == "INVALID_PARAMETER_VALUE"
3525  
3526  
def test_log_spans_and_get_trace_with_sqlalchemy_store(tmp_path: Path) -> None:
    """Spans logged straight to the SQL store are readable back via mlflow.get_trace().

    Builds a two-span (parent/child) trace from raw OpenTelemetry ReadableSpans,
    logs them through the store's log_spans (bypassing the client, simulating the
    OTLP ingest path), then verifies both the spans-location tag and the full
    span contents round-trip through mlflow.get_trace().
    """
    tracking_uri = f"sqlite:///{tmp_path}/test.db"

    with _use_tracking_uri(tracking_uri):
        client = MlflowClient()

        # Sanity check: the sqlite URI must resolve to the SQLAlchemy-backed store.
        assert isinstance(client._tracking_client.store, SqlAlchemyTrackingStore)

        experiment_id = client.create_experiment("test_log_spans_get_trace")
        trace_id = f"tr-{uuid.uuid4().hex}"

        # Create test spans using OpenTelemetry format.
        # Parent: OTel span_id=111; the MLflow trace id travels in the
        # mlflow.traceRequestId attribute (JSON-encoded, as MLflow stores it).
        otel_span1 = OTelReadableSpan(
            name="parent_span",
            context=trace_api.SpanContext(
                trace_id=12345,
                span_id=111,
                is_remote=False,
                trace_flags=trace_api.TraceFlags(1),
            ),
            parent=None,
            attributes={
                "mlflow.traceRequestId": json.dumps(trace_id, cls=TraceJSONEncoder),
                "llm.model_name": "test-model",
                "custom.attribute": "parent-value",
            },
            start_time=1_000_000_000,
            end_time=2_000_000_000,
            resource=None,
        )

        # Child: same OTel trace_id, parent context points at span_id=111,
        # and its time window lies inside the parent's.
        otel_span2 = OTelReadableSpan(
            name="child_span",
            context=trace_api.SpanContext(
                trace_id=12345,
                span_id=222,
                is_remote=False,
                trace_flags=trace_api.TraceFlags(1),
            ),
            parent=trace_api.SpanContext(
                trace_id=12345,
                span_id=111,
                is_remote=False,
                trace_flags=trace_api.TraceFlags(1),
            ),
            attributes={
                "mlflow.traceRequestId": json.dumps(trace_id, cls=TraceJSONEncoder),
                "operation.type": "database_query",
                "custom.attribute": "child-value",
            },
            start_time=1_200_000_000,
            end_time=1_800_000_000,
            resource=None,
        )

        # Convert to MLflow spans ("LLM" is the span type passed for both)
        mlflow_spans = [
            create_mlflow_span(otel_span1, trace_id, "LLM"),
            create_mlflow_span(otel_span2, trace_id, "LLM"),
        ]

        # Log spans directly to the store (simulating OTLP endpoint)
        store = client._tracking_client.store
        logged_spans = store.log_spans(experiment_id, mlflow_spans)

        # Verify spans were logged
        assert len(logged_spans) == 2

        # Verify the trace has the spans location tag set
        trace_info = store.get_trace_info(trace_id)
        assert trace_info.tags.get(TraceTagKey.SPANS_LOCATION) == SpansLocation.TRACKING_STORE

        # Now test that mlflow.get_trace() works and loads spans from the database
        trace = mlflow.get_trace(trace_id)

        # Verify trace structure
        assert trace.info.trace_id == trace_id
        assert trace.info.tags.get(TraceTagKey.SPANS_LOCATION) == SpansLocation.TRACKING_STORE

        # Verify spans were loaded from database
        assert len(trace.data.spans) == 2

        # Sort spans by start time for consistent testing
        spans_by_start_time = sorted(trace.data.spans, key=lambda s: s.start_time_ns)

        # Verify parent span
        parent_span = spans_by_start_time[0]
        assert parent_span.name == "parent_span"
        assert parent_span.trace_id == trace_id
        assert parent_span.start_time_ns == 1_000_000_000
        assert parent_span.end_time_ns == 2_000_000_000
        assert parent_span.attributes.get("llm.model_name") == "test-model"
        assert parent_span.attributes.get("custom.attribute") == "parent-value"

        # Verify child span
        child_span = spans_by_start_time[1]
        assert child_span.name == "child_span"
        assert child_span.trace_id == trace_id
        assert child_span.start_time_ns == 1_200_000_000
        assert child_span.end_time_ns == 1_800_000_000
        assert child_span.attributes.get("operation.type") == "database_query"
        assert child_span.attributes.get("custom.attribute") == "child-value"
3629  
3630  
def test_mlflow_get_trace_with_sqlalchemy_store(tmp_path: Path) -> None:
    """get_trace hits the store's get_trace, falling back to batch_get_traces."""
    tracking_uri = f"sqlite:///{tmp_path}/test.db"

    with _use_tracking_uri(tracking_uri):
        client = MlflowClient()
        assert isinstance(client._tracking_client.store, SqlAlchemyTrackingStore)

        # Produce a trace to look up, then make sure it is fully persisted.
        with mlflow.start_span() as span:
            pass
        trace_id = span.trace_id
        mlflow.flush_trace_async_logging()

        sql_alchemy_store_module = "mlflow.store.tracking.sqlalchemy_store.SqlAlchemyStore"

        # Happy path: the single-trace store API is used directly.
        with mock.patch(f"{sql_alchemy_store_module}.get_trace") as mock_get_trace:
            mlflow.get_trace(trace_id)

        mock_get_trace.assert_called_once_with(trace_id)

        # If get_trace is not implemented, the client falls back to batch_get_traces.
        with (
            mock.patch(
                f"{sql_alchemy_store_module}.get_trace",
                side_effect=MlflowNotImplementedException,
            ) as mock_get_trace,
            mock.patch(f"{sql_alchemy_store_module}.batch_get_traces") as mock_batch_get_traces,
        ):
            mlflow.get_trace(trace_id)

        mock_get_trace.assert_called_once_with(trace_id)
        mock_batch_get_traces.assert_called_once_with([trace_id])
3663  
3664  
def test_create_issue_basic(tmp_path: Path):
    """_create_issue fills sensible defaults for everything but name/description."""
    tracking_uri = f"sqlite:///{tmp_path}/test.db"

    with _use_tracking_uri(tracking_uri):
        client = MlflowClient()
        exp_id = client.create_experiment("test_create_issue")

        issue = client._tracing_client._create_issue(
            experiment_id=exp_id,
            name="Test issue",
            description="This is a test issue",
        )

        # Identity and caller-supplied fields.
        assert issue.issue_id.startswith("iss-")
        assert issue.experiment_id == exp_id
        assert issue.name == "Test issue"
        assert issue.description == "This is a test issue"
        # Defaults for every field the caller did not set.
        assert issue.status == IssueStatus.PENDING
        assert issue.severity is None
        assert issue.root_causes is None
        assert issue.source_run_id is None
        assert issue.created_by is None
        # Timestamps are set at creation and start out equal.
        assert issue.created_timestamp > 0
        assert issue.last_updated_timestamp == issue.created_timestamp
3690  
3691  
3692  def test_create_issue_with_all_fields(tmp_path: Path):
3693      tracking_uri = f"sqlite:///{tmp_path}/test.db"
3694  
3695      with _use_tracking_uri(tracking_uri):
3696          client = MlflowClient()
3697          exp_id = client.create_experiment("test_create_issue_all_fields")
3698          tracing_client = client._tracing_client
3699          with mlflow.start_run(experiment_id=exp_id) as run:
3700              issue = tracing_client._create_issue(
3701                  experiment_id=exp_id,
3702                  name="High latency",
3703                  description="API response times exceed threshold",
3704                  status=IssueStatus.RESOLVED,
3705                  severity=IssueSeverity.HIGH,
3706                  root_causes=["Database query slow", "Network congestion"],
3707                  source_run_id=run.info.run_id,
3708                  created_by="monitoring_system",
3709              )
3710  
3711      assert issue.issue_id.startswith("iss-")
3712      assert issue.experiment_id == exp_id
3713      assert issue.name == "High latency"
3714      assert issue.description == "API response times exceed threshold"
3715      assert issue.status == IssueStatus.RESOLVED
3716      assert issue.severity == IssueSeverity.HIGH
3717      assert issue.root_causes == ["Database query slow", "Network congestion"]
3718      assert issue.source_run_id == run.info.run_id
3719      assert issue.created_by == "monitoring_system"
3720      assert issue.created_timestamp > 0