test_client.py
import json
import os
import pickle
import threading
import time
import uuid
from pathlib import Path
from unittest import mock
from unittest.mock import Mock, patch

import pytest
from opentelemetry import trace as trace_api
from opentelemetry.sdk.trace import ReadableSpan as OTelReadableSpan
from pydantic import BaseModel

import mlflow
from mlflow import MlflowClient, flush_async_logging
from mlflow.config import enable_async_logging
from mlflow.entities import (
    EvaluationDataset,
    ExperimentTag,
    IssueSeverity,
    IssueStatus,
    LoggedModel,
    Run,
    RunInfo,
    RunStatus,
    RunTag,
    SourceType,
    Span,
    SpanStatusCode,
    SpanType,
    Trace,
    TraceInfo,
    ViewType,
)
from mlflow.entities.file_info import FileInfo
from mlflow.entities.logged_model_status import LoggedModelStatus
from mlflow.entities.metric import Metric
from mlflow.entities.model_registry import ModelVersion, ModelVersionTag
from mlflow.entities.model_registry.model_version_status import ModelVersionStatus
from mlflow.entities.model_registry.prompt_version import IS_PROMPT_TAG_KEY
from mlflow.entities.param import Param
from mlflow.entities.span import create_mlflow_span
from mlflow.entities.trace_data import TraceData
from mlflow.entities.trace_location import TraceLocation, TraceLocationType, UCSchemaLocation
from mlflow.entities.trace_state import TraceState
from mlflow.entities.trace_status import TraceStatus
from mlflow.environment_variables import MLFLOW_TRACKING_USERNAME
from mlflow.exceptions import (
    MlflowException,
    MlflowNotImplementedException,
    MlflowTraceDataCorrupted,
    MlflowTraceDataNotFound,
)
from mlflow.prompt.registry_utils import PromptCache
from mlflow.store.artifact.artifact_repo import ArtifactRepository
from mlflow.store.entities.paged_list import PagedList
from mlflow.store.model_registry.sqlalchemy_store import (
    SqlAlchemyStore as SqlAlchemyModelRegistryStore,
)
from mlflow.store.tracking import SEARCH_EVALUATION_DATASETS_MAX_RESULTS, SEARCH_MAX_RESULTS_DEFAULT
from mlflow.store.tracking.sqlalchemy_store import SqlAlchemyStore as SqlAlchemyTrackingStore
from mlflow.tracing.constant import SpansLocation, TraceMetadataKey, TraceTagKey
from mlflow.tracing.provider import _get_tracer, trace_disabled
from mlflow.tracing.utils import TraceJSONEncoder
from mlflow.tracking import set_registry_uri
from mlflow.tracking._model_registry import DEFAULT_AWAIT_MAX_SLEEP_SECONDS
from mlflow.tracking._model_registry.utils import (
    _get_store_registry as _get_model_registry_store_registry,
)
from mlflow.tracking._tracking_service.utils import _register, _use_tracking_uri
from mlflow.tracking.default_experiment import DEFAULT_EXPERIMENT_ID
from mlflow.utils.databricks_utils import _construct_databricks_run_url
from mlflow.utils.mlflow_tags import (
    MLFLOW_GIT_COMMIT,
    MLFLOW_PARENT_RUN_ID,
    MLFLOW_PROJECT_ENTRY_POINT,
    MLFLOW_SOURCE_NAME,
    MLFLOW_SOURCE_TYPE,
    MLFLOW_USER,
)
from mlflow.utils.os import is_windows

from tests.tracing.conftest import async_logging_enabled  # noqa: F401
from tests.tracing.helper import create_test_trace_info, get_traces


@pytest.fixture(autouse=True)
def reset_registry_uri():
    yield
    set_registry_uri(None)


@pytest.fixture
def mock_store():
    with mock.patch("mlflow.tracking._tracking_service.utils._get_store") as mock_get_store:
        mock_store = mock_get_store.return_value
        with mock.patch("mlflow.tracing.client._get_store", return_value=mock_store):
            yield mock_store


@pytest.fixture
def mock_artifact_repo():
    with mock.patch(
        "mlflow.tracking._tracking_service.client.get_artifact_repository"
    ) as mock_get_repo:
        mock_repo = mock_get_repo.return_value
        with mock.patch("mlflow.tracing.client.get_artifact_repository", return_value=mock_repo):
            yield mock_repo


@pytest.fixture
def mock_registry_store():
    mock_store = mock.MagicMock()
    mock_store.create_model_version = mock.create_autospec(
        SqlAlchemyModelRegistryStore.create_model_version
    )
    with mock.patch("mlflow.tracking._model_registry.utils._get_store", return_value=mock_store):
        yield mock_store


@pytest.fixture
def mock_databricks_tracking_store():
    experiment_id = "test-exp-id"
    run_id = "runid"

    class MockDatabricksTrackingStore:
        def __init__(self, run_id, experiment_id):
            self.run_id = run_id
            self.experiment_id = experiment_id

        def get_run(self, *args, **kwargs):
            return Run(
                RunInfo(self.run_id, self.experiment_id, "userid", "status", 0, 1, None), None
            )

    mock_tracking_store = MockDatabricksTrackingStore(run_id, experiment_id)

    with mock.patch(
        "mlflow.tracking._tracking_service.utils._tracking_store_registry.get_store",
        return_value=mock_tracking_store,
    ):
        yield mock_tracking_store


@pytest.fixture
def mock_store_start_trace():
    def _mock_start_trace(trace_info):
        return create_test_trace_info(
            trace_id="tr-123",
            experiment_id=trace_info.experiment_id,
            request_time=trace_info.request_time,
            execution_duration=trace_info.execution_duration,
            state=trace_info.state,
            trace_metadata=trace_info.trace_metadata,
            tags={
                "mlflow.user": "bob",
                "mlflow.artifactLocation": "test",
                **trace_info.tags,
            },
        )

    with mock.patch(
        "mlflow.tracing.client.TracingClient.start_trace", side_effect=_mock_start_trace
    ) as mock_start_trace:
        yield mock_start_trace


@pytest.fixture
def mock_spark_session():
    with mock.patch(
        "mlflow.utils.databricks_utils._get_active_spark_session"
    ) as mock_spark_session:
        yield mock_spark_session.return_value


@pytest.fixture
def mock_time():
    time = 1552319350.244724
    with mock.patch("time.time", return_value=time):
        yield time


@pytest.fixture
def setup_async_logging():
    enable_async_logging(True)
    yield
    flush_async_logging()
    enable_async_logging(False)


def test_client_create_run(mock_store, mock_time):
    experiment_id = mock.Mock()

    MlflowClient().create_run(experiment_id)

    mock_store.create_run.assert_called_once_with(
        experiment_id=experiment_id,
        user_id="unknown",
        start_time=int(mock_time * 1000),
        tags=[],
        run_name=None,
    )


def test_client_create_run_with_name(mock_store, mock_time):
    experiment_id = mock.Mock()

    MlflowClient().create_run(experiment_id, run_name="my name")

    mock_store.create_run.assert_called_once_with(
        experiment_id=experiment_id,
        user_id="unknown",
        start_time=int(mock_time * 1000),
        tags=[],
        run_name="my name",
    )


def test_client_get_trace(mock_store, mock_artifact_repo):
    trace_id = "trace:/catalog.schema/123"
    mock_store.batch_get_traces.return_value = [
        Trace(
            TraceInfo(
                trace_id=trace_id,
                trace_location=TraceLocation(
                    type=TraceLocationType.UC_SCHEMA,
                    uc_schema=UCSchemaLocation(catalog_name="catalog", schema_name="schema"),
                ),
                request_time=123,
                execution_duration=456,
                state=TraceState.OK,
                tags={"mlflow.artifactLocation": "dbfs:/path/to/artifacts"},
            ),
            TraceData(
                spans=[
                    Span.from_dict({
                        "name": "predict",
                        "context": {
                            "trace_id": "0x123456789",
                            "span_id": "0x12345",
                        },
                        "parent_id": None,
                        "start_time": 123000000,
                        "end_time": 579000000,
                        "status_code": "OK",
                        "status_message": "",
                        "attributes": {
                            "mlflow.traceRequestId": f'"{trace_id}"',
                            "mlflow.spanType": '"LLM"',
                            "mlflow.spanFunctionName": '"predict"',
                            "mlflow.spanInputs": '{"prompt": "What is the meaning of life?"}',
                            "mlflow.spanOutputs": '{"answer": 42}',
                        },
                        "events": [],
                    })
                ]
            ),
        )
    ]
    trace = MlflowClient().get_trace(trace_id)
    mock_store.batch_get_traces.assert_called_once_with([trace_id], "catalog.schema")
    mock_artifact_repo.download_trace_data.assert_not_called()

    assert trace.info.trace_id == trace_id
    assert trace.info.trace_location.uc_schema.catalog_name == "catalog"
    assert trace.info.trace_location.uc_schema.schema_name == "schema"
    assert trace.info.timestamp_ms == 123
    assert trace.info.execution_time_ms == 456
    assert trace.info.status == TraceStatus.OK
    assert trace.info.tags == {"mlflow.artifactLocation": "dbfs:/path/to/artifacts"}
    assert trace.data.request == '{"prompt": "What is the meaning of life?"}'
    assert trace.data.response == '{"answer": 42}'
    assert len(trace.data.spans) == 1
    assert trace.data.spans[0].name == "predict"
    assert trace.data.spans[0].trace_id == trace_id
    assert trace.data.spans[0].inputs == {"prompt": "What is the meaning of life?"}
    assert trace.data.spans[0].outputs == {"answer": 42}
    assert trace.data.spans[0].start_time_ns == 123000000
    assert trace.data.spans[0].end_time_ns == 579000000
    assert trace.data.spans[0].status.status_code == SpanStatusCode.OK


def test_client_get_trace_empty_result(mock_store):
    mock_store.batch_get_traces.return_value = []
    with pytest.raises(MlflowException, match="not found"):
        MlflowClient().get_trace("trace:/catalog.schema/123")


def test_client_get_trace_from_artifact_repo(mock_store, mock_artifact_repo):
    mock_store.get_trace_info.return_value = TraceInfo(
        trace_id="tr-1234567",
        trace_location=TraceLocation.from_experiment_id("0"),
        request_time=123,
        execution_duration=456,
        state=TraceState.OK,
        tags={"mlflow.artifactLocation": "dbfs:/path/to/artifacts"},
    )
    mock_artifact_repo.download_trace_data.return_value = {
        "request": '{"prompt": "What is the meaning of life?"}',
        "response": '{"answer": 42}',
        "spans": [
            {
                "name": "predict",
                "context": {
                    "trace_id": "0x123456789",
                    "span_id": "0x12345",
                },
                "parent_id": None,
                "start_time": 123000000,
                "end_time": 579000000,
                "status_code": "OK",
                "status_message": "",
                "attributes": {
                    "mlflow.traceRequestId": '"tr-1234567"',
                    "mlflow.spanType": '"LLM"',
                    "mlflow.spanFunctionName": '"predict"',
                    "mlflow.spanInputs": '{"prompt": "What is the meaning of life?"}',
                    "mlflow.spanOutputs": '{"answer": 42}',
                },
                "events": [],
            }
        ],
    }
    trace = MlflowClient().get_trace("1234567")
    mock_store.get_trace_info.assert_called_once_with("1234567")
    mock_artifact_repo.download_trace_data.assert_called_once()

    assert trace.info.trace_id == "tr-1234567"
    assert trace.info.experiment_id == "0"
    assert trace.info.timestamp_ms == 123
    assert trace.info.execution_time_ms == 456
    assert trace.info.status == TraceStatus.OK
    assert trace.info.tags == {"mlflow.artifactLocation": "dbfs:/path/to/artifacts"}
    assert trace.data.request == '{"prompt": "What is the meaning of life?"}'
    assert trace.data.response == '{"answer": 42}'
    assert len(trace.data.spans) == 1
    assert trace.data.spans[0].name == "predict"
    assert trace.data.spans[0].trace_id == "tr-1234567"
    assert trace.data.spans[0].inputs == {"prompt": "What is the meaning of life?"}
    assert trace.data.spans[0].outputs == {"answer": 42}
    assert trace.data.spans[0].start_time_ns == 123000000
    assert trace.data.spans[0].end_time_ns == 579000000
    assert trace.data.spans[0].status.status_code == SpanStatusCode.OK


def test_client_get_trace_throws_for_missing_or_corrupted_data(mock_store, mock_artifact_repo):
    mock_store.get_trace_info.return_value = TraceInfo(
        trace_id="1234567",
        trace_location=TraceLocation.from_experiment_id("0"),
        request_time=123,
        execution_duration=456,
        state=TraceState.OK,
        tags={"mlflow.artifactLocation": "dbfs:/path/to/artifacts"},
    )
    mock_artifact_repo.download_trace_data.side_effect = MlflowTraceDataNotFound("1234567")

    with pytest.raises(
        MlflowException,
        match="Trace with ID 1234567 cannot be loaded because it is missing span data",
    ):
        MlflowClient().get_trace("1234567")

    mock_artifact_repo.download_trace_data.side_effect = MlflowTraceDataCorrupted("1234567")
    with pytest.raises(
        MlflowException,
        match="Trace with ID 1234567 cannot be loaded because its span data is corrupted",
    ):
        MlflowClient().get_trace("1234567")


@pytest.mark.parametrize("include_spans", [True, False])
@pytest.mark.parametrize("num_results", [0, 5])
def test_client_search_traces_with_get_traces(
    mock_store, mock_artifact_repo, include_spans, num_results
):
    mock_trace_infos = [
        TraceInfo(
            trace_id=f"trace:/catalog.schema/{i}",
            trace_location=TraceLocation(
                type=TraceLocationType.UC_SCHEMA,
                uc_schema=UCSchemaLocation(catalog_name="catalog", schema_name="schema"),
            ),
            request_time=123,
            execution_duration=456,
            state=TraceState.OK,
        )
        for i in range(num_results)
    ]
    mock_store.search_traces.return_value = (mock_trace_infos, None)
    mock_store.batch_get_traces.return_value = [
        Trace(info=info, data=TraceData(spans=[])) for info in mock_trace_infos
    ]

    results = MlflowClient().search_traces(
        locations=["catalog.schema"],
        include_spans=include_spans,
    )
    mock_store.search_traces.assert_called_once_with(
        experiment_ids=None,
        filter_string=None,
        max_results=100,
        order_by=None,
        page_token=None,
        model_id=None,
        locations=["catalog.schema"],
    )
    assert len(results) == num_results

    if include_spans and num_results > 0:
        mock_store.batch_get_traces.assert_called_once_with(
            [f"trace:/catalog.schema/{i}" for i in range(num_results)],
            "catalog.schema",
        )
    else:
        mock_store.batch_get_traces.assert_not_called()

    mock_artifact_repo.download_trace_data.assert_not_called()

    # The TraceInfo is already fetched prior to the upload_trace_data call,
    # so we should not call _get_trace_info again
    mock_store.get_trace_info.assert_not_called()


def test_client_search_traces_with_large_results(mock_store, mock_artifact_repo):
    mock_trace_infos = [
        TraceInfo(
            trace_id=f"trace:/catalog.schema/{i}",
            trace_location=TraceLocation(
                type=TraceLocationType.UC_SCHEMA,
                uc_schema=UCSchemaLocation(catalog_name="catalog", schema_name="schema"),
            ),
            request_time=123,
            execution_duration=456,
            state=TraceState.OK,
        )
        for i in range(100)
    ]
    mock_store.search_traces.return_value = (mock_trace_infos, None)

    mock_store.batch_get_traces.return_value = [
        Trace(info=mock_trace_infos[0], data=TraceData(spans=[])) for i in range(10)
    ]

    results = MlflowClient().search_traces(locations=["catalog.schema"])
    mock_store.search_traces.assert_called_once_with(
        experiment_ids=None,
        filter_string=None,
        max_results=100,
        order_by=None,
        page_token=None,
        model_id=None,
        locations=["catalog.schema"],
    )
    assert len(results) == 100
    assert mock_store.batch_get_traces.call_count == 10
    mock_store.batch_get_traces.assert_has_calls([
        mock.call([f"trace:/catalog.schema/{j * 10 + i}" for i in range(10)], "catalog.schema")
        for j in range(10)
    ])
    mock_artifact_repo.download_trace_data.assert_not_called()


@pytest.mark.parametrize("include_spans", [True, False])
def test_client_search_traces_mixed(mock_store, mock_artifact_repo, include_spans):
    mock_traces = [
        TraceInfo(
            trace_id="1234567",
            trace_location=TraceLocation(
                type=TraceLocationType.UC_SCHEMA,
                uc_schema=UCSchemaLocation(catalog_name="catalog", schema_name="schema"),
            ),
            request_time=123,
            execution_duration=456,
            state=TraceState.OK,
            tags={"mlflow.artifactLocation": "dbfs:/path/to/artifacts/1"},
        ),
        TraceInfo(
            trace_id="8910",
            trace_location=TraceLocation.from_experiment_id("1"),
            request_time=456,
            execution_duration=789,
            state=TraceState.OK,
            tags={"mlflow.artifactLocation": "dbfs:/path/to/artifacts/2"},
        ),
    ]
    mock_store.search_traces.return_value = (mock_traces, None)
    mock_store.batch_get_traces.return_value = [
        Trace(info=mock_traces[0], data=TraceData(spans=[]))
    ]
    mock_artifact_repo.download_trace_data.return_value = {}
    results = MlflowClient().search_traces(
        locations=["1", "catalog.schema"], include_spans=include_spans
    )

    mock_store.search_traces.assert_called_once_with(
        experiment_ids=None,
        filter_string=None,
        max_results=100,
        order_by=None,
        page_token=None,
        model_id=None,
        locations=["1", "catalog.schema"],
    )
    assert len(results) == 2
    if include_spans:
        mock_store.batch_get_traces.assert_called_once_with(["1234567"], "catalog.schema")
        mock_artifact_repo.download_trace_data.assert_called()
    else:
        mock_store.batch_get_traces.assert_not_called()
        mock_artifact_repo.download_trace_data.assert_not_called()


@pytest.mark.parametrize("include_spans", [True, False])
@pytest.mark.parametrize("num_results", [0, 5])
def test_client_search_traces_with_get_traces_tracking_store(
    mock_store, mock_artifact_repo, include_spans, num_results
):
    mock_trace_infos = [
        TraceInfo(
            trace_id=f"tr-123456789{i}",
            trace_location=TraceLocation.from_experiment_id(f"exp-{i}"),
            request_time=123,
            execution_duration=456,
            state=TraceState.OK,
            tags={TraceTagKey.SPANS_LOCATION: SpansLocation.TRACKING_STORE},
        )
        for i in range(num_results)
    ]
    mock_store.search_traces.return_value = (mock_trace_infos, None)
    mock_store.batch_get_traces.return_value = [
        Trace(info=info, data=TraceData(spans=[])) for info in mock_trace_infos
    ]

    results = MlflowClient().search_traces(
        locations=["exp-0", "exp-1", "exp-2"],
        include_spans=include_spans,
    )
    mock_store.search_traces.assert_called_once_with(
        experiment_ids=None,
        filter_string=None,
        max_results=100,
        order_by=None,
        page_token=None,
        model_id=None,
        locations=["exp-0", "exp-1", "exp-2"],
    )
    assert len(results) == num_results

    if include_spans and num_results > 0:
        mock_store.batch_get_traces.assert_called_once_with(
            [f"tr-123456789{i}" for i in range(num_results)],
            None,
        )
    else:
        mock_store.batch_get_traces.assert_not_called()

    mock_artifact_repo.download_trace_data.assert_not_called()

    # The TraceInfo is already fetched prior to the upload_trace_data call,
    # so we should not call _get_trace_info again
    mock_store.get_trace_info.assert_not_called()


@pytest.mark.parametrize("include_spans", [True, False])
def test_client_search_traces_with_artifact_repo(mock_store, mock_artifact_repo, include_spans):
    mock_traces = [
        TraceInfo(
            trace_id="tr-1234567",
            trace_location=TraceLocation.from_experiment_id("1"),
            request_time=123,
            execution_duration=456,
            state=TraceState.OK,
            tags={"mlflow.artifactLocation": "dbfs:/path/to/artifacts/1"},
        ),
        TraceInfo(
            trace_id="tr-8910",
            trace_location=TraceLocation.from_experiment_id("2"),
            request_time=456,
            execution_duration=789,
            state=TraceState.OK,
            tags={"mlflow.artifactLocation": "dbfs:/path/to/artifacts/2"},
        ),
    ]
    mock_store.search_traces.return_value = (mock_traces, None)
    mock_artifact_repo.download_trace_data.return_value = {}
    results = MlflowClient().search_traces(locations=["1", "2", "3"], include_spans=include_spans)

    mock_store.search_traces.assert_called_once_with(
        experiment_ids=None,
        filter_string=None,
        max_results=100,
        order_by=None,
        page_token=None,
        model_id=None,
        locations=["1", "2", "3"],
    )
    assert len(results) == 2
    if include_spans:
        mock_artifact_repo.download_trace_data.assert_called()
    else:
        mock_artifact_repo.download_trace_data.assert_not_called()

    # The TraceInfo is already fetched prior to the upload_trace_data call,
    # so we should not call _get_trace_info again
    mock_store.get_trace_info.assert_not_called()


@pytest.mark.parametrize("include_spans", [True, False])
def test_client_search_traces_trace_data_download_error(mock_store, include_spans):
    class CustomArtifactRepository(ArtifactRepository):
        def log_artifact(self, local_file, artifact_path=None):
            raise NotImplementedError("Should not be called")

        def log_artifacts(self, local_dir, artifact_path=None):
            raise NotImplementedError("Should not be called")

        def list_artifacts(self, path):
            raise NotImplementedError("Should not be called")

        def _download_file(self, *args, **kwargs):
            raise Exception("Failed to download trace data")

    with mock.patch(
        "mlflow.tracing.client.get_artifact_repository",
        return_value=CustomArtifactRepository("test"),
    ) as mock_get_artifact_repository:
        mock_traces = [
            TraceInfo(
                trace_id="1234567",
                trace_location=TraceLocation.from_experiment_id("1"),
                request_time=123,
                execution_duration=456,
                state=TraceState.OK,
                tags={"mlflow.artifactLocation": "test"},
            ),
        ]
        mock_store.search_traces.return_value = (mock_traces, None)
        traces = MlflowClient().search_traces(locations=["1"], include_spans=include_spans)

        if include_spans:
            assert traces == []
            mock_get_artifact_repository.assert_called()
        else:
            assert len(traces) == 1
            assert traces[0].info.trace_id == "1234567"
            mock_get_artifact_repository.assert_not_called()


def test_client_search_traces_validates_experiment_ids_type():
    with pytest.raises(MlflowException, match=r"locations must be a list"):
        MlflowClient().search_traces(locations=4)

    with pytest.raises(MlflowException, match=r"locations must be a list"):
        MlflowClient().search_traces(locations="4")


def test_client_delete_traces(mock_store):
    MlflowClient().delete_traces(
        experiment_id="0",
        max_timestamp_millis=1,
        max_traces=2,
        trace_ids=["tr-1234"],
    )
    mock_store.delete_traces.assert_called_once_with(
        experiment_id="0",
        max_timestamp_millis=1,
        max_traces=2,
        trace_ids=["tr-1234"],
    )


@pytest.fixture
def disable_prompt_cache():
    from mlflow.environment_variables import (
        MLFLOW_ALIAS_PROMPT_CACHE_TTL_SECONDS,
        MLFLOW_VERSION_PROMPT_CACHE_TTL_SECONDS,
    )

    MLFLOW_ALIAS_PROMPT_CACHE_TTL_SECONDS.set(0)
    MLFLOW_VERSION_PROMPT_CACHE_TTL_SECONDS.set(0)
    yield
    MLFLOW_ALIAS_PROMPT_CACHE_TTL_SECONDS.unset()
    MLFLOW_VERSION_PROMPT_CACHE_TTL_SECONDS.unset()


@pytest.fixture(autouse=True)
def reset_prompt_cache():
    PromptCache._reset_instance()
    yield
    PromptCache._reset_instance()


@pytest.fixture(params=["file", "sqlalchemy"])
def tracking_uri(request, tmp_path, db_uri):
    """Set an MLflow Tracking URI with different types of backends."""
    if request.param == "file":
        pytest.skip("FileStore is no longer supported.")
    if "MLFLOW_SKINNY" in os.environ and request.param == "sqlalchemy":
        pytest.skip("SQLAlchemy store is not available in skinny.")

    original_tracking_uri = mlflow.get_tracking_uri()

    if request.param == "file":
        tracking_uri = tmp_path.joinpath("file").as_uri()
    elif request.param == "sqlalchemy":
        tracking_uri = db_uri

    # NB: MLflow tracer does not handle the change of tracking URI well,
    # so we need to reset the tracer to switch the tracking URI during testing.
    mlflow.tracing.disable()
    mlflow.set_tracking_uri(tracking_uri)
    mlflow.tracing.enable()

    yield tracking_uri

    # Reset tracking URI
    mlflow.set_tracking_uri(original_tracking_uri)


@pytest.mark.parametrize("with_active_run", [True, False])
def test_start_and_end_trace(tracking_uri, with_active_run, async_logging_enabled):
    client = MlflowClient(tracking_uri)

    experiment_id = client.create_experiment("test_experiment")

    class TestModel:
        def predict(self, x, y):
            root_span = client.start_trace(
                name="predict",
                inputs={"x": x, "y": y},
                tags={"tag": "tag_value"},
                experiment_id=experiment_id,
            )
            trace_id = root_span.trace_id

            z = x + y

            child_span = client.start_span(
                "child_span_1",
                span_type=SpanType.LLM,
                trace_id=trace_id,
                parent_id=root_span.span_id,
                inputs={"z": z},
            )

            z = z + 2

            client.end_span(
                trace_id=trace_id,
                span_id=child_span.span_id,
                outputs={"output": z},
                attributes={"delta": 2},
            )

            res = self.square(z, trace_id, root_span.span_id)
            client.end_trace(trace_id, outputs={"output": res}, status="OK")
            return res

        def square(self, t, trace_id, parent_id):
            span = client.start_span(
                "child_span_2",
                trace_id=trace_id,
                parent_id=parent_id,
                inputs={"t": t},
            )

            res = t**2
            time.sleep(0.1)

            client.end_span(
                trace_id=trace_id,
                span_id=span.span_id,
                outputs={"output": res},
            )
            return res

    model = TestModel()
    if with_active_run:
        with mlflow.start_run(experiment_id=experiment_id) as run:
            model.predict(1, 2)
            run_id = run.info.run_id
    else:
        model.predict(1, 2)

    if async_logging_enabled:
        mlflow.flush_trace_async_logging(terminate=True)

    trace_id = mlflow.get_trace(mlflow.get_last_active_trace_id(), flush=True).info.trace_id

    # Validate that trace is logged to the backend
    trace = client.get_trace(trace_id)
    assert trace is not None

    assert trace.info.trace_id is not None
    assert trace.info.execution_time_ms >= 0.1 * 1e3  # at least 0.1 sec
    assert trace.info.status == TraceStatus.OK
    assert trace.info.request_metadata[TraceMetadataKey.INPUTS] == '{"x": 1, "y": 2}'
    assert trace.info.request_metadata[TraceMetadataKey.OUTPUTS] == '{"output": 25}'
    if with_active_run:
        assert trace.info.request_metadata[TraceMetadataKey.SOURCE_RUN] == run_id
        assert trace.info.experiment_id == run.info.experiment_id
    else:
        assert trace.info.experiment_id == experiment_id

    assert trace.data.request == '{"x": 1, "y": 2}'
    assert trace.data.response == '{"output": 25}'
    assert len(trace.data.spans) == 3

    span_name_to_span = {span.name: span for span in trace.data.spans}
    root_span = span_name_to_span["predict"]
    # NB: Start time of root span and trace info does not match because there is some
    # latency for starting the trace within the backend
    # assert root_span.start_time_ns // 1e6 == trace.info.timestamp_ms
    assert root_span.parent_id is None
    assert root_span.attributes == {
        "mlflow.experimentId": experiment_id,
        "mlflow.traceRequestId": trace.info.trace_id,
        "mlflow.spanType": "UNKNOWN",
        "mlflow.spanInputs": {"x": 1, "y": 2},
        "mlflow.spanOutputs": {"output": 25},
    }

    child_span_1 = span_name_to_span["child_span_1"]
    assert child_span_1.parent_id == root_span.span_id
    assert child_span_1.attributes == {
"mlflow.traceRequestId": trace.info.trace_id, 829 "mlflow.spanType": "LLM", 830 "mlflow.spanInputs": {"z": 3}, 831 "mlflow.spanOutputs": {"output": 5}, 832 "delta": 2, 833 } 834 835 child_span_2 = span_name_to_span["child_span_2"] 836 assert child_span_2.parent_id == root_span.span_id 837 assert child_span_2.attributes == { 838 "mlflow.traceRequestId": trace.info.trace_id, 839 "mlflow.spanType": "UNKNOWN", 840 "mlflow.spanInputs": {"t": 5}, 841 "mlflow.spanOutputs": {"output": 25}, 842 } 843 assert child_span_2.start_time_ns <= child_span_2.end_time_ns - 0.1 * 1e6 844 845 846 def test_start_and_end_trace_capture_falsy_input_and_output(tracking_uri): 847 # This test is to verify that falsy input and output values are correctly logged 848 client = MlflowClient(tracking_uri) 849 experiment_id = client.create_experiment("test_experiment") 850 851 root = client.start_trace(name="root", experiment_id=experiment_id, inputs=[]) 852 span = client.start_span(name="child", trace_id=root.trace_id, parent_id=root.span_id, inputs=0) 853 client.end_span(trace_id=root.trace_id, span_id=span.span_id, outputs=False) 854 client.end_trace(trace_id=root.trace_id, outputs="") 855 856 trace = client.get_trace(root.trace_id, flush=True) 857 assert trace.data.spans[0].inputs == [] 858 assert trace.data.spans[0].outputs == "" 859 assert trace.data.spans[1].inputs == 0 860 assert trace.data.spans[1].outputs is False 861 862 863 # TODO: we should investigate whether we need to support this 864 @pytest.mark.skip(reason="This is not supported by latest span-level export") 865 @pytest.mark.usefixtures("reset_active_experiment") 866 def test_start_and_end_trace_before_all_span_end(async_logging_enabled): 867 # This test is to verify that the trace is still exported even if some spans are not ended 868 exp_id = mlflow.set_experiment("test_experiment_1").experiment_id 869 870 class TestModel: 871 def __init__(self): 872 self._client = MlflowClient() 873 874 def predict(self, x): 875 root_span = self._client.start_trace(name="predict") 876 trace_id = root_span.trace_id 877 child_span = self._client.start_span( 878 "ended-span", 879 trace_id=trace_id, 880 parent_id=root_span.span_id, 881 ) 882 time.sleep(0.1) 883 self._client.end_span(trace_id, child_span.span_id) 884 885 res = self.square(x, trace_id, root_span.span_id) 886 self._client.end_trace(trace_id) 887 return res 888 889 def square(self, t, trace_id, parent_id): 890 self._client.start_span("non-ended-span", trace_id=trace_id, parent_id=parent_id) 891 time.sleep(0.1) 892 # The span created above is not ended 893 return t**2 894 895 model = TestModel() 896 model.predict(1) 897 898 if async_logging_enabled: 899 mlflow.flush_trace_async_logging(terminate=True) 900 901 traces = MlflowClient().search_traces(locations=[exp_id]) 902 assert len(traces) == 1 903 904 trace_info = traces[0].info 905 assert trace_info.trace_id is not None 906 assert trace_info.experiment_id == exp_id 907 assert trace_info.timestamp_ms is not None 908 assert trace_info.execution_time_ms is not None 909 assert trace_info.status == TraceStatus.OK 910 911 trace_data = traces[0].data 912 assert trace_data.request is None 913 assert trace_data.response is None 914 assert len(trace_data.spans) == 3 # The non-ended span should be also included in the trace 915 916 span_name_to_span = {span.name: span for span in trace_data.spans} 917 root_span = span_name_to_span["predict"] 918 assert root_span.parent_id is None 919 assert root_span.status.status_code == SpanStatusCode.OK 920 921 ended_span = 
span_name_to_span["ended-span"] 922 assert ended_span.parent_id == root_span.span_id 923 assert ended_span.start_time_ns < ended_span.end_time_ns 924 assert ended_span.status.status_code == SpanStatusCode.OK 925 926 # The non-ended span should have null end_time and UNSET status 927 non_ended_span = span_name_to_span["non-ended-span"] 928 assert non_ended_span.parent_id == root_span.span_id 929 assert non_ended_span.start_time_ns is not None 930 assert non_ended_span.end_time_ns is None 931 assert non_ended_span.status.status_code == SpanStatusCode.UNSET 932 933 934 def test_log_trace_with_databricks_tracking_uri(mock_store_start_trace, monkeypatch): 935 monkeypatch.setenv("MLFLOW_EXPERIMENT_NAME", "test") 936 monkeypatch.setenv(MLFLOW_TRACKING_USERNAME.name, "bob") 937 938 mock_experiment = mock.MagicMock() 939 mock_experiment.experiment_id = "test_experiment_id" 940 941 class TestModel: 942 def __init__(self): 943 self._client = MlflowClient() 944 945 def predict(self, x, y): 946 root_span = self._client.start_trace( 947 name="predict", 948 inputs={"x": x, "y": y}, 949 # Trying to override mlflow.user tag, which will be ignored 950 tags={"tag": "tag_value", "mlflow.user": "unknown"}, 951 ) 952 trace_id = root_span.trace_id 953 954 z = x + y 955 956 child_span = self._client.start_span( 957 "child_span_1", 958 span_type=SpanType.LLM, 959 trace_id=trace_id, 960 parent_id=root_span.span_id, 961 inputs={"z": z}, 962 ) 963 964 z = z + 2 965 966 self._client.end_span( 967 trace_id=trace_id, 968 span_id=child_span.span_id, 969 outputs={"output": z}, 970 attributes={"delta": 2}, 971 ) 972 self._client.end_trace(trace_id, outputs=z, status="OK") 973 return z 974 975 model = TestModel() 976 977 with ( 978 mock.patch("mlflow.get_tracking_uri", return_value="databricks"), 979 mock.patch("mlflow.tracking.context.default_context._get_source_name", return_value="test"), 980 mock.patch( 981 "mlflow.tracing.client.TracingClient._upload_trace_data" 982 ) as mock_upload_trace_data, 983 mock.patch( 984 "mlflow.tracing.client.TracingClient.set_trace_tags", 985 ), 986 mock.patch( 987 "mlflow.tracing.client.TracingClient.set_trace_tag", 988 ), 989 mock.patch( 990 "mlflow.tracing.client.TracingClient.get_trace_info", 991 ), 992 mock.patch( 993 "mlflow.MlflowClient.get_experiment_by_name", 994 return_value=mock_experiment, 995 ), 996 ): 997 model.predict(1, 2) 998 mlflow.flush_trace_async_logging(terminate=True) 999 1000 mock_store_start_trace.assert_called_once() 1001 mock_upload_trace_data.assert_called_once() 1002 1003 1004 def test_start_and_end_trace_does_not_log_trace_when_disabled( 1005 tracking_uri, monkeypatch, async_logging_enabled 1006 ): 1007 client = MlflowClient(tracking_uri) 1008 experiment_id = client.create_experiment("test_experiment") 1009 1010 @trace_disabled 1011 def func(): 1012 span = client.start_trace( 1013 name="predict", 1014 experiment_id=experiment_id, 1015 inputs={"x": 1, "y": 2}, 1016 attributes={"attr": "value"}, 1017 tags={"tag": "tag_value"}, 1018 ) 1019 child_span = client.start_span( 1020 "child_span_1", 1021 trace_id=span.trace_id, 1022 parent_id=span.span_id, 1023 ) 1024 client.end_span( 1025 trace_id=span.trace_id, 1026 span_id=child_span.span_id, 1027 outputs={"output": 5}, 1028 ) 1029 client.end_trace(span.trace_id, outputs=5, status="OK") 1030 return "done" 1031 1032 mock_logger = mock.MagicMock() 1033 monkeypatch.setattr(mlflow.tracking.client, "_logger", mock_logger) 1034 1035 res = func() 1036 1037 assert res == "done" 1038 assert 
    # No warning should be issued
    mock_logger.warning.assert_not_called()


def test_start_trace_within_active_run(async_logging_enabled):
    exp_id = mlflow.create_experiment("test")

    client = mlflow.MlflowClient()
    with mlflow.start_run():
        root_span = client.start_trace(
            name="test",
            experiment_id=exp_id,
        )
        client.end_trace(root_span.trace_id)

    if async_logging_enabled:
        mlflow.flush_trace_async_logging(terminate=True)

    traces = client.search_traces(locations=[exp_id])
    assert len(traces) == 1
    assert traces[0].info.experiment_id == exp_id


def test_start_trace_raise_error_when_active_trace_exists():
    with mlflow.start_span("fluent_span"):
        with pytest.raises(MlflowException, match=r"Another trace is already set in the global"):
            mlflow.tracking.MlflowClient().start_trace("test")


def test_end_trace_raise_error_when_trace_not_exist():
    client = mlflow.tracking.MlflowClient()
    mock_tracing_client = mock.MagicMock()
    mock_tracing_client.get_trace.return_value = None
    client._tracing_client = mock_tracing_client

    with pytest.raises(MlflowException, match=r"Trace with ID test not found"):
        client.end_trace("test")


def test_end_trace_works_for_trace_in_pending_status():
    client = mlflow.tracking.MlflowClient()
    mock_tracing_client = mock.MagicMock()
    mock_tracing_client.get_trace.return_value = Trace(
        info=create_test_trace_info("test", state=TraceState.IN_PROGRESS), data=TraceData()
    )
    client._tracing_client = mock_tracing_client
    client.end_span = lambda *args: None

    client.end_trace("test")


@pytest.mark.parametrize("state", [TraceState.OK, TraceState.ERROR])
def test_end_trace_raise_error_for_trace_in_end_status(state):
    client = mlflow.tracking.MlflowClient()
    mock_tracing_client = mock.MagicMock()
    mock_tracing_client.get_trace.return_value = Trace(
        info=create_test_trace_info("test", state=state), data=TraceData()
    )
    client._tracing_client = mock_tracing_client

    with pytest.raises(MlflowException, match=r"Trace with ID test already finished"):
        client.end_trace("test")


def test_trace_status_either_pending_or_end():
    all_statuses = {status.value for status in TraceStatus}
    pending_or_end_statuses = TraceStatus.pending_statuses() | TraceStatus.end_statuses()
    unclassified_statuses = all_statuses - pending_or_end_statuses
    assert len(unclassified_statuses) == 0, (
        f"Please add {unclassified_statuses} to "
        "either pending_statuses or end_statuses in TraceStatus class definition"
    )


def test_start_span_raise_error_when_parent_id_is_not_provided():
    with pytest.raises(MlflowException, match=r"start_span\(\) must be called with"):
        mlflow.tracking.MlflowClient().start_span("span_name", trace_id="test", parent_id=None)


def test_log_trace(tracking_uri):
    client = MlflowClient(tracking_uri)
    experiment_id = client.create_experiment("test_experiment")

    span = client.start_trace(
        name="test",
        span_type=SpanType.LLM,
        experiment_id=experiment_id,
        tags={"custom_tag": "tag_value"},
    )
    client.end_trace(span.trace_id, status="OK")

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id(), flush=True)

    # Purge all traces in the backend once
    client.delete_traces(experiment_id=experiment_id, trace_ids=[trace.info.trace_id])
    assert client.search_traces(locations=[experiment_id]) == []

    # Log the trace manually; _log_trace triggers async export via span processor
    new_trace_id = client._log_trace(trace)

    # Validate the trace is added to the backend (flush=True waits for async writes)
    backend_traces = client.search_traces(locations=[experiment_id], flush=True)
    assert len(backend_traces) == 1
    assert backend_traces[0].info.trace_id == new_trace_id  # new request ID is assigned
    assert backend_traces[0].info.experiment_id == experiment_id
    assert backend_traces[0].info.status == trace.info.status
    assert backend_traces[0].info.tags["custom_tag"] == "tag_value"
    assert backend_traces[0].info.request_preview == trace.info.request_preview
    assert backend_traces[0].info.response_preview == trace.info.response_preview
    assert len(backend_traces[0].data.spans) == len(trace.data.spans)
    assert backend_traces[0].data.spans[0].name == trace.data.spans[0].name

    # If the experiment ID is None in the given trace, it should be set to the default experiment
    trace.info.experiment_id = None
    new_trace_id = client._log_trace(trace)
    backend_traces = client.search_traces(locations=[DEFAULT_EXPERIMENT_ID], flush=True)
    assert len(backend_traces) == 1


@pytest.mark.filterwarnings("ignore::FutureWarning")
def test_search_traces_experiment_ids_deprecation_warning():
    client = MlflowClient()
    exp_id = mlflow.set_experiment("test_experiment_deprecation").experiment_id

    # Test that using experiment_ids shows a deprecation warning
    with pytest.warns(FutureWarning, match="experiment_ids.*deprecated.*use.*locations"):
        result = client.search_traces(experiment_ids=[exp_id])
    assert isinstance(result, list)


def test_ignore_exception_from_tracing_logic(monkeypatch, async_logging_enabled):
    exp_id = mlflow.set_experiment("test_experiment_1").experiment_id
    client = MlflowClient()

    class TestModel:
        def predict(self, x):
            root_span = client.start_trace(experiment_id=exp_id, name="predict")
            trace_id = root_span.trace_id
            child_span = client.start_span(
                name="child", trace_id=trace_id, parent_id=root_span.span_id
            )
            client.end_span(trace_id, child_span.span_id)
            client.end_trace(trace_id)
            return x

    model = TestModel()

    # Mock the span processor's on_end handler to raise an exception
    processor = _get_tracer(__name__).span_processor

    def _always_fail(*args, **kwargs):
        raise ValueError("Some error")

    # Exception while starting the trace should be caught, not raised
    monkeypatch.setattr(processor, "on_start", _always_fail)
    response = model.predict(1)
    assert response == 1
    assert len(get_traces()) == 0

    # Exception while ending the trace should be caught, not raised
    monkeypatch.setattr(processor, "on_end", _always_fail)
    response = model.predict(1)
    assert response == 1
    assert len(get_traces()) == 0


def test_set_and_delete_trace_tag_on_active_trace(monkeypatch):
    monkeypatch.setenv(MLFLOW_TRACKING_USERNAME.name, "bob")
    monkeypatch.setattr(mlflow.tracking.context.default_context, "_get_source_name", lambda: "test")

    client = mlflow.tracking.MlflowClient()

    root_span = client.start_trace(name="test")
    trace_id = root_span.trace_id
    client.set_trace_tag(trace_id, "foo", "bar")
    client.end_trace(trace_id)

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id(), flush=True)
    assert trace.info.tags["foo"] == "bar"


def test_set_trace_tag_on_logged_trace(mock_store):
    mlflow.tracking.MlflowClient().set_trace_tag("test", "foo", "bar")
    mlflow.tracking.MlflowClient().set_trace_tag("test", "mlflow.some.reserved.tag", "value")
    mock_store.set_trace_tag.assert_has_calls([
        mock.call("test", "foo", "bar"),
        mock.call("test", "mlflow.some.reserved.tag", "value"),
    ])


def test_delete_trace_tag_on_active_trace(monkeypatch):
    monkeypatch.setenv(MLFLOW_TRACKING_USERNAME.name, "bob")
    monkeypatch.setattr(mlflow.tracking.context.default_context, "_get_source_name", lambda: "test")

    client = mlflow.tracking.MlflowClient()
    root_span = client.start_trace(name="test", tags={"foo": "bar", "baz": "qux"})
    trace_id = root_span.trace_id
    client.delete_trace_tag(trace_id, "foo")
    client.end_trace(trace_id)

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id(), flush=True)
    assert "baz" in trace.info.tags
    assert "foo" not in trace.info.tags


def test_delete_trace_tag_on_logged_trace(mock_store):
    mlflow.tracking.MlflowClient().delete_trace_tag("test", "foo")
    mock_store.delete_trace_tag.assert_called_once_with("test", "foo")


def test_client_create_experiment(mock_store):
    MlflowClient().create_experiment("someName", "someLocation", {"key1": "val1", "key2": "val2"})

    mock_store.create_experiment.assert_called_once_with(
        artifact_location="someLocation",
        tags=[ExperimentTag("key1", "val1"), ExperimentTag("key2", "val2")],
        name="someName",
    )


def test_client_create_run_overrides(mock_store):
    experiment_id = mock.Mock()
    user = mock.Mock()
    start_time = mock.Mock()
    run_name = mock.Mock()
    tags = {
        MLFLOW_USER: user,
        MLFLOW_PARENT_RUN_ID: mock.Mock(),
        MLFLOW_SOURCE_TYPE: SourceType.to_string(SourceType.JOB),
        MLFLOW_SOURCE_NAME: mock.Mock(),
        MLFLOW_PROJECT_ENTRY_POINT: mock.Mock(),
        MLFLOW_GIT_COMMIT: mock.Mock(),
        "other-key": "other-value",
    }

    MlflowClient().create_run(experiment_id, start_time, tags, run_name)

    mock_store.create_run.assert_called_once_with(
        experiment_id=experiment_id,
        user_id=user,
        start_time=start_time,
        tags=[RunTag(key, value) for key, value in tags.items()],
        run_name=run_name,
    )
    mock_store.reset_mock()
    MlflowClient().create_run(experiment_id, start_time, tags)
    mock_store.create_run.assert_called_once_with(
        experiment_id=experiment_id,
        user_id=user,
        start_time=start_time,
        tags=[RunTag(key, value) for key, value in tags.items()],
        run_name=None,
    )


def test_client_set_terminated_no_change_name(mock_store):
    experiment_id = mock.Mock()
    run = MlflowClient().create_run(experiment_id, run_name="my name")
    MlflowClient().set_terminated(run.info.run_id)
    _, kwargs = mock_store.update_run_info.call_args
    assert kwargs["run_name"] is None


def test_client_search_runs_defaults(mock_store):
    MlflowClient().search_runs([1, 2, 3])
    mock_store.search_runs.assert_called_once_with(
        experiment_ids=[1, 2, 3],
        filter_string="",
        run_view_type=ViewType.ACTIVE_ONLY,
        max_results=SEARCH_MAX_RESULTS_DEFAULT,
        order_by=None,
        page_token=None,
    )


def test_client_search_runs_filter(mock_store):
    MlflowClient().search_runs(["a", "b", "c"], "my filter")
    mock_store.search_runs.assert_called_once_with(
        experiment_ids=["a", "b", "c"],
        filter_string="my filter",
        run_view_type=ViewType.ACTIVE_ONLY,
        max_results=SEARCH_MAX_RESULTS_DEFAULT,
        order_by=None,
        page_token=None,
    )


def test_client_search_runs_view_type(mock_store):
    MlflowClient().search_runs(["a", "b", "c"], "my filter", ViewType.DELETED_ONLY)
    mock_store.search_runs.assert_called_once_with(
        experiment_ids=["a", "b", "c"],
        filter_string="my filter",
        run_view_type=ViewType.DELETED_ONLY,
        max_results=SEARCH_MAX_RESULTS_DEFAULT,
        order_by=None,
        page_token=None,
    )


def test_client_search_runs_max_results(mock_store):
    MlflowClient().search_runs([5], "my filter", ViewType.ALL, 2876)
    mock_store.search_runs.assert_called_once_with(
        experiment_ids=[5],
        filter_string="my filter",
        run_view_type=ViewType.ALL,
        max_results=2876,
        order_by=None,
        page_token=None,
    )


def test_client_search_runs_int_experiment_id(mock_store):
    MlflowClient().search_runs(123)
    mock_store.search_runs.assert_called_once_with(
        experiment_ids=[123],
        filter_string="",
        run_view_type=ViewType.ACTIVE_ONLY,
        max_results=SEARCH_MAX_RESULTS_DEFAULT,
        order_by=None,
        page_token=None,
    )


def test_client_search_runs_string_experiment_id(mock_store):
    MlflowClient().search_runs("abc")
    mock_store.search_runs.assert_called_once_with(
        experiment_ids=["abc"],
        filter_string="",
        run_view_type=ViewType.ACTIVE_ONLY,
        max_results=SEARCH_MAX_RESULTS_DEFAULT,
        order_by=None,
        page_token=None,
    )


def test_client_search_runs_order_by(mock_store):
    MlflowClient().search_runs([5], order_by=["a", "b"])
    mock_store.search_runs.assert_called_once_with(
        experiment_ids=[5],
        filter_string="",
        run_view_type=ViewType.ACTIVE_ONLY,
        max_results=SEARCH_MAX_RESULTS_DEFAULT,
        order_by=["a", "b"],
        page_token=None,
    )


def test_client_search_runs_page_token(mock_store):
    MlflowClient().search_runs([5], page_token="blah")
    mock_store.search_runs.assert_called_once_with(
        experiment_ids=[5],
        filter_string="",
        run_view_type=ViewType.ACTIVE_ONLY,
        max_results=SEARCH_MAX_RESULTS_DEFAULT,
        order_by=None,
        page_token="blah",
    )


def test_update_registered_model(mock_registry_store):
    """
    Update registered model no longer supports name change.
    """
    expected_return_value = "some expected return value."
    mock_registry_store.rename_registered_model.return_value = expected_return_value
    expected_return_value_2 = "other expected return value."
    mock_registry_store.update_registered_model.return_value = expected_return_value_2
    res = MlflowClient(registry_uri="sqlite:///somedb.db").update_registered_model(
        name="orig name", description="new description"
    )
    assert expected_return_value_2 == res
    mock_registry_store.update_registered_model.assert_called_once_with(
        name="orig name", description="new description", deployment_job_id=None
    )
    mock_registry_store.rename_registered_model.assert_not_called()


def test_create_model_version(mock_registry_store):
    """
    Basic test for create model version.
    """
    mock_registry_store.create_model_version.return_value = _default_model_version()
    res = MlflowClient(registry_uri="sqlite:///somedb.db").create_model_version(
        "orig name", "source", "run-id", tags={"key": "value"}, description="desc"
    )
    assert res == _default_model_version()
    mock_registry_store.create_model_version.assert_called_once_with(
        "orig name",
        "source",
        "run-id",
        [ModelVersionTag(key="key", value="value")],
        None,
        "desc",
        local_model_path=None,
        model_id=None,
    )


def test_update_model_version(mock_registry_store):
    """
    Update model version no longer supports state changes.
    """
    mock_registry_store.update_model_version.return_value = _default_model_version()
    res = MlflowClient(registry_uri="sqlite:///somedb.db").update_model_version(
        name="orig name", version="1", description="desc"
    )
    assert _default_model_version() == res
    mock_registry_store.update_model_version.assert_called_once_with(
        name="orig name", version="1", description="desc"
    )
    mock_registry_store.transition_model_version_stage.assert_not_called()


def test_transition_model_version_stage(mock_registry_store):
    name = "Model 1"
    version = "12"
    stage = "Production"
    expected_result = ModelVersion(name, version, creation_timestamp=123, current_stage=stage)
    mock_registry_store.transition_model_version_stage.return_value = expected_result
    actual_result = MlflowClient(registry_uri="sqlite:///somedb.db").transition_model_version_stage(
        name, version, stage
    )
    mock_registry_store.transition_model_version_stage.assert_called_once_with(
        name=name, version=version, stage=stage, archive_existing_versions=False
    )
    assert expected_result == actual_result


def test_registry_uri_set_as_param():
    uri = "sqlite:///somedb.db"
    client = MlflowClient(tracking_uri="databricks://tracking", registry_uri=uri)
    assert client._registry_uri == uri


def test_registry_uri_from_set_registry_uri():
    uri = "sqlite:///somedb.db"
    set_registry_uri(uri)
    client = MlflowClient(tracking_uri="databricks://tracking")
    assert client._registry_uri == uri
    set_registry_uri(None)


def test_registry_uri_from_tracking_uri_param():
    tracking_uri = "databricks://tracking_vhawoierj"
    with mock.patch(
        "mlflow.tracking._tracking_service.utils.get_tracking_uri",
        return_value=tracking_uri,
    ):
        client = MlflowClient(tracking_uri=tracking_uri)
        # For databricks tracking URIs, registry URI defaults to Unity Catalog with profile
        assert client._registry_uri == "databricks-uc://tracking_vhawoierj"


def test_registry_uri_from_implicit_tracking_uri():
    tracking_uri = "databricks://tracking_wierojasdf"
    with mock.patch(
"mlflow.tracking._tracking_service.utils.get_tracking_uri", 1496 return_value=tracking_uri, 1497 ): 1498 client = MlflowClient() 1499 # For databricks tracking URIs, registry URI defaults to Unity Catalog with profile 1500 assert client._registry_uri == "databricks-uc://tracking_wierojasdf" 1501 1502 1503 def test_create_model_version_nondatabricks_source_no_runlink(mock_registry_store): 1504 run_id = "runid" 1505 client = MlflowClient(tracking_uri="http://10.123.1231.11") 1506 mock_registry_store.create_model_version.return_value = ModelVersion( 1507 "name", 1508 1, 1509 0, 1510 1, 1511 source="source", 1512 run_id=run_id, 1513 ) 1514 model_version = client.create_model_version("name", "source", "runid") 1515 assert model_version.name == "name" 1516 assert model_version.source == "source" 1517 assert model_version.run_id == "runid" 1518 # verify that the store was not provided a run link 1519 mock_registry_store.create_model_version.assert_called_once_with( 1520 "name", "source", "runid", [], None, None, local_model_path=None, model_id=None 1521 ) 1522 1523 1524 def test_create_model_version_nondatabricks_source_no_run_id(mock_registry_store): 1525 client = MlflowClient(tracking_uri="http://10.123.1231.11") 1526 mock_registry_store.create_model_version.return_value = ModelVersion( 1527 "name", 1, 0, 1, source="source" 1528 ) 1529 model_version = client.create_model_version("name", "source") 1530 assert model_version.name == "name" 1531 assert model_version.source == "source" 1532 assert model_version.run_id is None 1533 # verify that the store was not provided a run id 1534 mock_registry_store.create_model_version.assert_called_once_with( 1535 "name", "source", None, [], None, None, local_model_path=None, model_id=None 1536 ) 1537 1538 1539 def test_create_model_version_explicitly_set_run_link( 1540 mock_registry_store, mock_databricks_tracking_store 1541 ): 1542 run_link = "my-run-link" 1543 hostname = "https://workspace.databricks.com/" 1544 workspace_id = "10002" 1545 mock_registry_store.create_model_version.return_value = ModelVersion( 1546 "name", 1547 1, 1548 0, 1549 1, 1550 source="source", 1551 run_id=mock_databricks_tracking_store.run_id, 1552 run_link=run_link, 1553 ) 1554 1555 # mocks to make sure that even if you're in a notebook, this setting is respected. 
    with (
        mock.patch("mlflow.utils.databricks_utils.is_in_databricks_notebook", return_value=True),
        mock.patch(
            "mlflow.utils.databricks_utils.get_workspace_info_from_dbutils",
            return_value=(hostname, workspace_id),
        ),
    ):
        client = MlflowClient(tracking_uri="databricks", registry_uri="otherplace")
        model_version = client.create_model_version("name", "source", "runid", run_link=run_link)
        assert model_version.run_link == run_link
        # verify that the store was provided with the explicitly passed in run link
        mock_registry_store.create_model_version.assert_called_once_with(
            "name", "source", "runid", [], run_link, None, local_model_path=None, model_id=None
        )


def test_create_model_version_run_link_in_notebook_with_default_profile(
    mock_registry_store, mock_databricks_tracking_store
):
    hostname = "https://workspace.databricks.com/"
    workspace_id = "10002"
    workspace_url = _construct_databricks_run_url(
        hostname,
        mock_databricks_tracking_store.experiment_id,
        mock_databricks_tracking_store.run_id,
        workspace_id,
    )

    with (
        mock.patch("mlflow.utils.databricks_utils.is_in_databricks_notebook", return_value=True),
        mock.patch(
            "mlflow.utils.databricks_utils.get_workspace_info_from_dbutils",
            return_value=(hostname, workspace_id),
        ),
    ):
        client = MlflowClient(tracking_uri="databricks", registry_uri="otherplace")
        mock_registry_store.create_model_version.return_value = ModelVersion(
            "name",
            1,
            0,
            1,
            source="source",
            run_id=mock_databricks_tracking_store.run_id,
            run_link=workspace_url,
        )
        model_version = client.create_model_version("name", "source", "runid")
        assert model_version.run_link == workspace_url
        # verify that the client generated the right URL
        mock_registry_store.create_model_version.assert_called_once_with(
            "name", "source", "runid", [], workspace_url, None, local_model_path=None, model_id=None
        )


def test_create_model_version_with_source(mock_registry_store, mock_databricks_tracking_store):
    model_id = "model_id"
    mock_registry_store.create_model_version.return_value = ModelVersion(
        "name",
        1,
        0,
        1,
        source="/path/to/source",
        run_id=mock_databricks_tracking_store.run_id,
        run_link=None,
        model_id=model_id,
    )
    mock_logged_model = LoggedModel(
        experiment_id="exp_id",
        model_id="model_id",
        name="name",
        artifact_location="/path/to/source",
        creation_timestamp=0,
        last_updated_timestamp=0,
    )

    with mock.patch(
        "mlflow.tracking.client.MlflowClient.get_logged_model", return_value=mock_logged_model
    ):
        client = MlflowClient(tracking_uri="databricks")
        model_version = client.create_model_version(
            "name", f"models:/{model_id}", "runid", run_link=None, model_id=model_id
        )
        assert model_version.model_id == model_id
        mock_registry_store.create_model_version.assert_called_once_with(
            "name",
            f"models:/{model_id}",
            "runid",
            [],
            None,
            None,
            local_model_path=None,
            model_id="model_id",
        )

    mock_registry_store.create_model_version.reset_mock()
    with mock.patch(
        "mlflow.tracking.client.MlflowClient.get_logged_model", return_value=mock_logged_model
    ):
        client = MlflowClient(tracking_uri="databricks", registry_uri="databricks-uc")
        model_version = client.create_model_version(
client.create_model_version( 1655 "name", f"models:/{model_id}", "runid", run_link=None, model_id=model_id 1656 ) 1657 assert model_version.model_id == model_id 1658 mock_registry_store.create_model_version.assert_called_once_with( 1659 "name", 1660 f"models:/{model_id}", 1661 "runid", 1662 [], 1663 None, 1664 None, 1665 local_model_path=None, 1666 model_id="model_id", 1667 ) 1668 1669 1670 def test_create_model_version_with_nondatabricks_source_uc_registry(mock_registry_store): 1671 name = "name" 1672 model_id = "model_id" 1673 run_id = "runid" 1674 source = "/path/to/source" 1675 model_uri = f"models:/{model_id}" 1676 mock_registry_store.create_model_version.return_value = ModelVersion( 1677 "name", 1678 1, 1679 0, 1680 1, 1681 source=source, 1682 run_id=run_id, 1683 run_link=None, 1684 model_id=model_id, 1685 ) 1686 mock_logged_model = LoggedModel( 1687 experiment_id="exp_id", 1688 model_id=model_id, 1689 name=name, 1690 artifact_location=source, 1691 creation_timestamp=0, 1692 last_updated_timestamp=0, 1693 ) 1694 1695 with mock.patch( 1696 "mlflow.tracking.client.MlflowClient.get_logged_model", return_value=mock_logged_model 1697 ): 1698 client = MlflowClient(tracking_uri="http://10.123.1231.11", registry_uri="databricks-uc") 1699 model_version = client.create_model_version( 1700 name, model_uri, run_id, run_link=None, model_id=model_id 1701 ) 1702 assert model_version.model_id == model_id 1703 mock_registry_store.create_model_version.assert_called_once_with( 1704 name, 1705 source, 1706 run_id, 1707 [], 1708 None, 1709 None, 1710 local_model_path=None, 1711 model_id=None, 1712 ) 1713 1714 1715 def test_creation_default_values_in_unity_catalog(mock_registry_store): 1716 client = MlflowClient(tracking_uri="databricks", registry_uri="databricks-uc") 1717 mock_registry_store.create_model_version.return_value = ModelVersion( 1718 "name", 1719 1, 1720 0, 1721 1, 1722 source="source", 1723 run_id="runid", 1724 ) 1725 client.create_model_version("name", "source", "runid") 1726 # verify that registry store was called with tags=[] and run_link=None 1727 mock_registry_store.create_model_version.assert_called_once_with( 1728 "name", "source", "runid", [], None, None, local_model_path=None, model_id=None 1729 ) 1730 client.create_registered_model(name="name", description="description") 1731 # verify that registry store was called with tags=[] 1732 mock_registry_store.create_registered_model.assert_called_once_with( 1733 "name", [], "description", None 1734 ) 1735 1736 1737 def test_await_model_version_creation(mock_registry_store): 1738 mv = ModelVersion( 1739 name="name", 1740 version=1, 1741 creation_timestamp=123, 1742 status=ModelVersionStatus.to_string(ModelVersionStatus.FAILED_REGISTRATION), 1743 ) 1744 mock_registry_store.create_model_version.return_value = mv 1745 1746 client = MlflowClient(tracking_uri="http://10.123.1231.11") 1747 1748 client.create_model_version("name", "source") 1749 mock_registry_store._await_model_version_creation.assert_called_once_with( 1750 mv, DEFAULT_AWAIT_MAX_SLEEP_SECONDS 1751 ) 1752 1753 1754 def test_create_model_version_run_link_with_configured_profile( 1755 mock_registry_store, mock_databricks_tracking_store 1756 ): 1757 hostname = "https://workspace.databricks.com/" 1758 workspace_id = "10002" 1759 workspace_url = _construct_databricks_run_url( 1760 hostname, 1761 mock_databricks_tracking_store.experiment_id, 1762 mock_databricks_tracking_store.run_id, 1763 workspace_id, 1764 ) 1765 1766 with ( 1767 
mock.patch("mlflow.utils.databricks_utils.is_in_databricks_notebook", return_value=False), 1768 mock.patch( 1769 "mlflow.utils.databricks_utils.get_workspace_info_from_databricks_secrets", 1770 return_value=(hostname, workspace_id), 1771 ), 1772 ): 1773 client = MlflowClient(tracking_uri="databricks", registry_uri="otherplace") 1774 mock_registry_store.create_model_version.return_value = ModelVersion( 1775 "name", 1776 1, 1777 0, 1778 1, 1779 source="source", 1780 run_id=mock_databricks_tracking_store.run_id, 1781 run_link=workspace_url, 1782 ) 1783 model_version = client.create_model_version("name", "source", "runid") 1784 assert model_version.run_link == workspace_url 1785 # verify that the client generated the right URL 1786 mock_registry_store.create_model_version.assert_called_once_with( 1787 "name", "source", "runid", [], workspace_url, None, local_model_path=None, model_id=None 1788 ) 1789 1790 1791 def test_create_model_version_copy_called_db_to_db(mock_registry_store): 1792 client = MlflowClient( 1793 tracking_uri="databricks://tracking", 1794 registry_uri="databricks://registry:workspace", 1795 ) 1796 mock_registry_store.create_model_version.return_value = _default_model_version() 1797 with mock.patch("mlflow.tracking.client._upload_artifacts_to_databricks") as upload_mock: 1798 client.create_model_version( 1799 "model name", 1800 "dbfs:/source", 1801 "run_12345", 1802 run_link="not:/important/for/test", 1803 ) 1804 upload_mock.assert_called_once_with( 1805 "dbfs:/source", 1806 "run_12345", 1807 "databricks://tracking", 1808 "databricks://registry:workspace", 1809 ) 1810 1811 1812 def test_create_model_version_copy_called_nondb_to_db(mock_registry_store): 1813 client = MlflowClient( 1814 tracking_uri="https://tracking", registry_uri="databricks://registry:workspace" 1815 ) 1816 mock_registry_store.create_model_version.return_value = _default_model_version() 1817 with mock.patch("mlflow.tracking.client._upload_artifacts_to_databricks") as upload_mock: 1818 client.create_model_version( 1819 "model name", "s3:/source", "run_12345", run_link="not:/important/for/test" 1820 ) 1821 upload_mock.assert_called_once_with( 1822 "s3:/source", 1823 "run_12345", 1824 "https://tracking", 1825 "databricks://registry:workspace", 1826 ) 1827 1828 1829 def test_create_model_version_copy_not_called_to_db(mock_registry_store): 1830 client = MlflowClient( 1831 tracking_uri="databricks://registry:workspace", 1832 registry_uri="databricks://registry:workspace", 1833 ) 1834 mock_registry_store.create_model_version.return_value = _default_model_version() 1835 with mock.patch("mlflow.tracking.client._upload_artifacts_to_databricks") as upload_mock: 1836 client.create_model_version( 1837 "model name", 1838 "dbfs:/source", 1839 "run_12345", 1840 run_link="not:/important/for/test", 1841 ) 1842 upload_mock.assert_not_called() 1843 1844 1845 def test_create_model_version_copy_not_called_to_nondb(mock_registry_store): 1846 client = MlflowClient(tracking_uri="databricks://tracking", registry_uri="https://registry") 1847 mock_registry_store.create_model_version.return_value = _default_model_version() 1848 with mock.patch("mlflow.tracking.client._upload_artifacts_to_databricks") as upload_mock: 1849 client.create_model_version( 1850 "model name", 1851 "dbfs:/source", 1852 "run_12345", 1853 run_link="not:/important/for/test", 1854 ) 1855 upload_mock.assert_not_called() 1856 1857 1858 def _default_model_version(): 1859 return ModelVersion("model name", 1, creation_timestamp=123, status="READY") 1860 1861 1862 def 
test_client_can_be_serialized_with_pickle(tmp_path): 1863 """ 1864 Verifies that instances of `MlflowClient` can be serialized using pickle, even if the underlying 1865 Tracking and Model Registry stores used by the client are not serializable using pickle 1866 """ 1867 1868 class MockUnpickleableTrackingStore(SqlAlchemyTrackingStore): 1869 pass 1870 1871 class MockUnpickleableModelRegistryStore(SqlAlchemyModelRegistryStore): 1872 pass 1873 1874 backend_store_path = tmp_path.joinpath("test.db") 1875 artifact_store_path = tmp_path.joinpath("artifacts") 1876 1877 mock_tracking_store = MockUnpickleableTrackingStore( 1878 f"sqlite:///{backend_store_path}", str(artifact_store_path) 1879 ) 1880 mock_model_registry_store = MockUnpickleableModelRegistryStore( 1881 f"sqlite:///{backend_store_path}" 1882 ) 1883 1884 # Verify that the mock stores cannot be pickled because they are defined within a function 1885 # (i.e. the test function) 1886 with pytest.raises(AttributeError, match="<locals>.MockUnpickleableTrackingStore'"): 1887 pickle.dumps(mock_tracking_store) 1888 1889 with pytest.raises(AttributeError, match="<locals>.MockUnpickleableModelRegistryStore'"): 1890 pickle.dumps(mock_model_registry_store) 1891 1892 _register("pickle", lambda *args, **kwargs: mock_tracking_store) 1893 _get_model_registry_store_registry().register( 1894 "pickle", lambda *args, **kwargs: mock_model_registry_store 1895 ) 1896 1897 # Create an MlflowClient with the store that cannot be pickled, perform 1898 # tracking & model registry operations, and verify that the client can still be pickled 1899 client = MlflowClient("pickle://foo") 1900 client.create_experiment("test_experiment") 1901 client.create_registered_model("test_model") 1902 pickle.dumps(client) 1903 1904 1905 @pytest.fixture 1906 def mock_registry_store_with_get_latest_version(mock_registry_store): 1907 mock_get_latest_versions = mock.Mock() 1908 mock_get_latest_versions.return_value = [ 1909 ModelVersion( 1910 "model_name", 1911 1, 1912 0, 1913 ) 1914 ] 1915 1916 mock_registry_store.get_latest_versions = mock_get_latest_versions 1917 return mock_registry_store 1918 1919 1920 def test_set_model_version_tag(mock_registry_store_with_get_latest_version): 1921 # set_model_version_tag using version 1922 MlflowClient().set_model_version_tag("model_name", 1, "tag1", "foobar") 1923 mock_registry_store_with_get_latest_version.set_model_version_tag.assert_called_once_with( 1924 "model_name", 1, ModelVersionTag(key="tag1", value="foobar") 1925 ) 1926 1927 mock_registry_store_with_get_latest_version.set_model_version_tag.reset_mock() 1928 1929 # set_model_version_tag using stage 1930 MlflowClient().set_model_version_tag("model_name", key="tag1", value="foobar", stage="Staging") 1931 mock_registry_store_with_get_latest_version.set_model_version_tag.assert_called_once_with( 1932 "model_name", 1, ModelVersionTag(key="tag1", value="foobar") 1933 ) 1934 1935 # set_model_version_tag with version and stage set 1936 with pytest.raises(MlflowException, match="version and stage cannot be set together"): 1937 MlflowClient().set_model_version_tag("model_name", 1, "tag1", "foobar", stage="Staging") 1938 1939 # set_model_version_tag with version and stage not set 1940 with pytest.raises(MlflowException, match="version or stage must be set"): 1941 MlflowClient().set_model_version_tag("model_name", key="tag1", value="foobar") 1942 1943 1944 def test_delete_model_version_tag(mock_registry_store_with_get_latest_version): 1945 # delete_model_version_tag using version 1946 
    MlflowClient().delete_model_version_tag("model_name", 1, "tag1")
    mock_registry_store_with_get_latest_version.delete_model_version_tag.assert_called_once_with(
        "model_name", 1, "tag1"
    )

    mock_registry_store_with_get_latest_version.delete_model_version_tag.reset_mock()

    # delete_model_version_tag using stage
    MlflowClient().delete_model_version_tag("model_name", key="tag1", stage="Staging")
    mock_registry_store_with_get_latest_version.delete_model_version_tag.assert_called_once_with(
        "model_name", 1, "tag1"
    )

    # delete_model_version_tag with version and stage set
    with pytest.raises(MlflowException, match="version and stage cannot be set together"):
        MlflowClient().delete_model_version_tag(
            "model_name", version=1, key="tag1", stage="staging"
        )

    # delete_model_version_tag with version and stage not set
    with pytest.raises(MlflowException, match="version or stage must be set"):
        MlflowClient().delete_model_version_tag("model_name", key="tag1")


def test_set_registered_model_alias(mock_registry_store):
    MlflowClient().set_registered_model_alias("model_name", "test_alias", 1)
    mock_registry_store.set_registered_model_alias.assert_called_once_with(
        "model_name", "test_alias", 1
    )


def test_delete_registered_model_alias(mock_registry_store):
    MlflowClient().delete_registered_model_alias("model_name", "test_alias")
    mock_registry_store.delete_registered_model_alias.assert_called_once_with(
        "model_name", "test_alias"
    )


def test_get_model_version_by_alias(mock_registry_store):
    mock_registry_store.get_model_version_by_alias.return_value = _default_model_version()
    res = MlflowClient().get_model_version_by_alias("model_name", "test_alias")
    assert res == _default_model_version()
    mock_registry_store.get_model_version_by_alias.assert_called_once_with(
        "model_name", "test_alias"
    )


def test_update_run(mock_store):
    MlflowClient().update_run(run_id="run_id", status="FINISHED", name="my name")
    mock_store.update_run_info.assert_called_once_with(
        run_id="run_id",
        run_status=RunStatus.from_string("FINISHED"),
        end_time=mock.ANY,
        run_name="my name",
    )


def test_client_log_metric_params_tags_overrides(mock_store):
    experiment_id = mock.Mock()
    start_time = mock.Mock()
    run_name = mock.Mock()
    run = MlflowClient().create_run(experiment_id, start_time, tags={}, run_name=run_name)
    run_id = run.info.run_id

    run_operation = MlflowClient().log_metric(run_id, "m1", 0.87, 123456789, 1, synchronous=False)
    run_operation.wait()

    run_operation = MlflowClient().log_param(run_id, "p1", "pv1", synchronous=False)
    run_operation.wait()

    run_operation = MlflowClient().set_tag(run_id, "t1", "tv1", synchronous=False)
    run_operation.wait()

    mock_store.log_metric_async.assert_called_once_with(run_id, Metric("m1", 0.87, 123456789, 1))
    mock_store.log_param_async.assert_called_once_with(run_id, Param("p1", "pv1"))
    mock_store.set_tag_async.assert_called_once_with(run_id, RunTag("t1", "tv1"))

    mock_store.reset_mock()

    # log_batch_async
    MlflowClient().create_run(experiment_id, start_time, {})
    metrics = [Metric("m1", 0.87, 123456789, 1), Metric("m2", 0.87, 123456789, 1)]
    tags = [RunTag("t1", "tv1"), RunTag("t2", "tv2")]
    params = [Param("p1", "pv1"), Param("p2", "pv2")]
"pv2")] 2030 run_operation = MlflowClient().log_batch(run_id, metrics, params, tags, synchronous=False) 2031 run_operation.wait() 2032 2033 mock_store.log_batch_async.assert_called_once_with( 2034 run_id=run_id, metrics=metrics, params=params, tags=tags 2035 ) 2036 2037 2038 def test_invalid_run_id_log_artifact(): 2039 with pytest.raises( 2040 MlflowException, 2041 match=r"Invalid run id.*", 2042 ): 2043 MlflowClient().log_artifact("tr-123", "path") 2044 2045 2046 def test_enable_async_logging(mock_store, setup_async_logging): 2047 MlflowClient().log_param(run_id="run_id", key="key", value="val") 2048 mock_store.log_param_async.assert_called_once_with("run_id", Param("key", "val")) 2049 2050 MlflowClient().log_metric(run_id="run_id", key="key", value="val", step=1, timestamp=1) 2051 mock_store.log_metric_async.assert_called_once_with("run_id", Metric("key", "val", 1, 1)) 2052 2053 2054 def test_file_store_download_upload_trace_data(tmp_path): 2055 pytest.skip("FileStore is no longer supported.") 2056 with _use_tracking_uri(tmp_path.joinpath("mlruns").as_uri()): 2057 client = MlflowClient() 2058 span = client.start_trace("test", inputs={"test": 1}) 2059 client.end_trace(span.trace_id, outputs={"result": 2}) 2060 trace = mlflow.get_trace(span.trace_id, flush=True) 2061 trace_data = client.get_trace(span.trace_id, flush=True).data 2062 assert trace_data.request == trace.data.request 2063 assert trace_data.response == trace.data.response 2064 2065 2066 def test_get_trace_throw_if_trace_id_is_online_trace_id(db_uri): 2067 client = MlflowClient("databricks") 2068 trace_id = "3a3c3b56-910a-4721-8d02-0333eda5f37e" 2069 with pytest.raises(MlflowException, match="Traces from inference tables can only be loaded"): 2070 client.get_trace(trace_id) 2071 2072 another_client = MlflowClient(db_uri) 2073 with pytest.raises(MlflowException, match=r"Trace with ID '[\w-]+' not found"): 2074 another_client.get_trace(trace_id) 2075 2076 2077 @pytest.fixture(params=["file", "sqlalchemy"]) 2078 def registry_uri(request, tmp_path, db_uri): 2079 """Set an MLflow Model Registry URI with different type of backend.""" 2080 if request.param == "file": 2081 pytest.skip("FileStore is no longer supported.") 2082 if "MLFLOW_SKINNY" in os.environ and request.param == "sqlalchemy": 2083 pytest.skip("SQLAlchemy store is not available in skinny.") 2084 2085 original_registry_uri = mlflow.get_registry_uri() 2086 2087 if request.param == "file": 2088 registry_uri = tmp_path.joinpath("file").as_uri() 2089 elif request.param == "sqlalchemy": 2090 registry_uri = db_uri 2091 2092 yield registry_uri 2093 2094 # Reset tracking URI 2095 mlflow.set_tracking_uri(original_registry_uri) 2096 2097 2098 def test_crud_prompts(tracking_uri): 2099 client = MlflowClient(tracking_uri=tracking_uri) 2100 2101 client.register_prompt( 2102 name="prompt_1", 2103 template="Hi, {{title}} {{name}}! How are you today?", 2104 commit_message="A friendly greeting", 2105 ) 2106 2107 prompt = client.load_prompt("prompt_1", version=1) 2108 assert prompt.name == "prompt_1" 2109 assert prompt.template == "Hi, {{title}} {{name}}! How are you today?" 2110 assert prompt.commit_message == "A friendly greeting" 2111 2112 client.register_prompt( 2113 name="prompt_1", 2114 template="Hi, {{title}} {{name}}! What's up?", 2115 commit_message="New greeting", 2116 ) 2117 2118 prompt = client.load_prompt("prompt_1", version=2) 2119 assert prompt.template == "Hi, {{title}} {{name}}! What's up?" 

    prompt = client.load_prompt("prompt_1", version=1)
    assert prompt.template == "Hi, {{title}} {{name}}! How are you today?"

    prompt = client.load_prompt("prompts:/prompt_1/2")
    assert prompt.template == "Hi, {{title}} {{name}}! What's up?"

    # Test loading non-existent prompts
    assert mlflow.load_prompt("does_not_exist", version=1, allow_missing=True) is None


def test_create_prompt_with_tags_and_metadata(tracking_uri, disable_prompt_cache):
    def wait_for_prompt_linking():
        """Wait for background prompt linking threads to complete."""
        for t in threading.enumerate():
            if t.name.startswith("link_prompt_to_experiment_thread"):
                t.join(timeout=5.0)
                if t.is_alive():
                    raise TimeoutError(f"Thread {t.name} did not complete within timeout.")

    client = MlflowClient(tracking_uri=tracking_uri)

    # Create prompt with version-specific tags
    client.register_prompt(
        name="prompt_1",
        template="Hi, {{name}}!",
        tags={"author": "Alice"},  # These are version-level tags now
    )

    # Wait for the background linking thread to complete
    wait_for_prompt_linking()

    # Set some prompt-level tags separately
    client.set_prompt_tag("prompt_1", "application", "greeting")
    client.set_prompt_tag("prompt_1", "language", "en")

    # Test version 1
    prompt_v1 = client.load_prompt("prompt_1", version=1)
    assert prompt_v1.template == "Hi, {{name}}!"
    # Version tags are separate from prompt tags
    assert prompt_v1.tags == {"author": "Alice"}

    # Wait for the background linking thread from load_prompt
    wait_for_prompt_linking()

    # Test prompt-level tags (separate from version)
    prompt_entity = client.get_prompt("prompt_1")
    # Note: Currently includes the version tags too, but we expect this behavior to change
    assert prompt_entity.tags == {
        "author": "Alice",  # This appears due to current implementation
        "application": "greeting",
        "language": "en",
        "_mlflow_experiment_ids": ",0,",  # Linked to Default experiment
    }

    # Create version 2 with different version-level tags
    client.register_prompt(
        name="prompt_1",
        template="こんにちは、{{name}}!",
        tags={"author": "Bob", "date": "2022-01-01"},  # Version-level tags
    )

    # Wait for the background linking thread from register_prompt
    wait_for_prompt_linking()

    # Update some prompt-level tags
    client.set_prompt_tag("prompt_1", "project", "toy")
    client.set_prompt_tag("prompt_1", "language", "ja")

    # Test version 2
    prompt_v2 = client.load_prompt("prompt_1", version=2)
    assert prompt_v2.template == "こんにちは、{{name}}!"
2192 # Version 2 has its own version tags (decoupled from prompt and version 1) 2193 assert prompt_v2.tags == {"author": "Bob", "date": "2022-01-01"} 2194 2195 # Wait for the background linking thread from load_prompt 2196 wait_for_prompt_linking() 2197 2198 # Verify prompt-level tags are updated and separate 2199 prompt_entity_updated = client.get_prompt("prompt_1") 2200 # Note: Currently the prompt tags get overwritten by the newest version's tags 2201 assert prompt_entity_updated.tags == { 2202 "author": "Bob", # This appears due to current implementation 2203 "date": "2022-01-01", # This appears due to current implementation 2204 "application": "greeting", 2205 "project": "toy", 2206 "language": "ja", 2207 "_mlflow_experiment_ids": ",0,", # Linked to Default experiment 2208 } 2209 2210 # Version 1 tags should be unchanged (decoupled from prompt tags) 2211 prompt_v1_after_update = client.load_prompt("prompt_1", version=1) 2212 assert prompt_v1_after_update.tags == {"author": "Alice"} # Unchanged 2213 2214 2215 def test_create_prompt_error_handling(tracking_uri, disable_prompt_cache): 2216 client = MlflowClient(tracking_uri=tracking_uri) 2217 2218 # Exceeds the max length 2219 with pytest.raises(MlflowException, match=r"Prompt text exceeds max length of"): 2220 client.register_prompt(name="prompt_1", template="a" * 100_001) 2221 2222 # When the first version creation fails, RegisteredModel should not be created 2223 with pytest.raises(MlflowException, match=r"Prompt with name=prompt_1 not found"): 2224 client.load_prompt("prompt_1", version=1) 2225 2226 client.register_prompt("prompt_1", template="Hi, {{title}} {{name}}!") 2227 assert client.load_prompt("prompt_1", version=1) is not None 2228 2229 # When the subsequent version creation fails, RegisteredModel should remain 2230 with pytest.raises(MlflowException, match=r"Prompt text exceeds max length of"): 2231 client.register_prompt(name="prompt_1", template="a" * 100_001) 2232 2233 assert client.load_prompt("prompt_1", version=1) is not None 2234 2235 2236 def test_create_prompt_with_invalid_name(tracking_uri): 2237 client = MlflowClient(tracking_uri=tracking_uri) 2238 2239 with pytest.raises(MlflowException, match=r"Prompt name must be a non-empty string"): 2240 client.register_prompt(name="", template="Hi, {{name}}!") 2241 2242 with pytest.raises(MlflowException, match=r"Prompt name must be a non-empty string"): 2243 client.register_prompt(name=123, template="Hi, {{name}}!") 2244 2245 for invalid_pattern in [ 2246 "prompt_1/2", 2247 "m%6fdel", 2248 "prompt?!?", 2249 "prompt with space", 2250 ]: 2251 with pytest.raises(MlflowException, match=r"Prompt name can only contain alphanumeric"): 2252 client.register_prompt(name=invalid_pattern, template="Hi, {{name}}!") 2253 2254 # Name conflicts with a model 2255 client.create_registered_model("model") 2256 with pytest.raises(MlflowException, match=r"Model 'model' exists with the same name."): 2257 client.register_prompt(name="model", template="Hi, {{name}}!") 2258 2259 2260 def test_load_prompt_error(tracking_uri): 2261 client = MlflowClient(tracking_uri=tracking_uri) 2262 2263 with pytest.raises(MlflowException, match=r"Prompt with name=test not found"): 2264 client.load_prompt("test", version=1) 2265 2266 # Both file and sqlalchemy return the same error format now 2267 error_msg = r"Prompt with name=test not found" 2268 2269 with pytest.raises(MlflowException, match=error_msg): 2270 client.load_prompt("test", version=2) 2271 2272 with pytest.raises(MlflowException, match=error_msg): 2273 
client.load_prompt("test", version=2, allow_missing=False) 2274 2275 # Load prompt with a model name 2276 client.create_registered_model("model") 2277 client.create_model_version("model", "source") 2278 2279 with pytest.raises(MlflowException, match=r"Name `model` is registered as a model"): 2280 client.load_prompt("model", version=1) 2281 2282 with pytest.raises(MlflowException, match=r"Name `model` is registered as a model"): 2283 client.load_prompt("model", version=1) 2284 2285 with pytest.raises(MlflowException, match=r"Name `model` is registered as a model"): 2286 client.load_prompt("model", version=1, allow_missing=False) 2287 2288 with pytest.raises(MlflowException, match=r"Name `model` is registered as a model"): 2289 client.load_prompt("model", version=1, allow_missing=False) 2290 2291 2292 def test_link_prompt_version_to_run(tracking_uri): 2293 client = MlflowClient(tracking_uri=tracking_uri) 2294 2295 prompt = client.register_prompt("prompt", template="Hi, {{name}}!") 2296 2297 # Create actual runs to link to 2298 run1 = client.create_run(experiment_id="0").info.run_id 2299 run2 = client.create_run(experiment_id="0").info.run_id 2300 2301 # Test that the method can be called without error 2302 client.link_prompt_version_to_run(run1, prompt) 2303 client.link_prompt_version_to_run(run2, prompt) 2304 2305 # Verify tag was set by checking the run data 2306 run_data = client.get_run(run1) 2307 linked_prompts_tag = run_data.data.tags.get("mlflow.linkedPrompts") 2308 assert linked_prompts_tag is not None 2309 2310 # Verify the JSON structure 2311 linked_prompts = json.loads(linked_prompts_tag) 2312 assert any(p["name"] == "prompt" and p["version"] == "1" for p in linked_prompts) 2313 2314 # Test error case 2315 with pytest.raises(MlflowException, match=r"The `prompt` argument must be"): 2316 client.link_prompt_version_to_run(run1, 123) 2317 2318 2319 @pytest.mark.parametrize("registry_uri", ["databricks"]) 2320 def test_crud_prompt_on_unsupported_registry(registry_uri): 2321 client = MlflowClient(registry_uri=registry_uri) 2322 2323 with pytest.raises(MlflowException, match=r"The 'register_prompt' API is not supported"): 2324 client.register_prompt( 2325 name="prompt_1", 2326 template="Hi, {{title}} {{name}}! How are you today?", 2327 commit_message="A friendly greeting", 2328 tags={"model": "my-model"}, 2329 ) 2330 2331 with pytest.raises(MlflowException, match=r"The 'load_prompt' API is not supported"): 2332 client.load_prompt("prompt_1") 2333 2334 2335 def test_block_create_model_with_prompt_tag(tracking_uri): 2336 client = MlflowClient(tracking_uri=tracking_uri) 2337 2338 with pytest.raises(MlflowException, match=r"Prompts cannot be registered"): 2339 client.create_registered_model( 2340 name="model", 2341 tags={IS_PROMPT_TAG_KEY: "true"}, 2342 ) 2343 2344 with pytest.raises(MlflowException, match=r"Prompts cannot be registered"): 2345 client.create_model_version( 2346 name="model", 2347 source="source", 2348 tags={IS_PROMPT_TAG_KEY: "false"}, 2349 ) 2350 2351 2352 def test_block_create_prompt_with_existing_model_name(tracking_uri): 2353 client = MlflowClient(tracking_uri=tracking_uri) 2354 2355 client.create_registered_model("model") 2356 2357 with pytest.raises(MlflowException, match=r"Model 'model' exists with"): 2358 client.register_prompt( 2359 name="model", 2360 template="Hi, {{title}} {{name}}! 
How are you today?", 2361 commit_message="A friendly greeting", 2362 tags={"model": "my-model"}, 2363 ) 2364 2365 2366 def test_block_handling_prompt_with_model_apis(tracking_uri): 2367 client = MlflowClient(tracking_uri=tracking_uri) 2368 client.register_prompt("prompt", template="Hi, {{name}}!") 2369 client.set_prompt_alias("prompt", alias="alias", version=1) 2370 # Validate the prompt is registered 2371 prompt = client.load_prompt("prompt", version=1) 2372 assert prompt.name == "prompt" 2373 assert prompt.aliases == ["alias"] 2374 2375 apis_to_args = [ 2376 (client.rename_registered_model, ["prompt", "new_name"]), 2377 (client.update_registered_model, ["prompt", "new_description"]), 2378 (client.delete_registered_model, ["prompt"]), 2379 (client.get_registered_model, ["prompt"]), 2380 (client.get_latest_versions, ["prompt"]), 2381 (client.set_registered_model_tag, ["prompt", "tag", "value"]), 2382 (client.delete_registered_model_tag, ["prompt", "tag"]), 2383 (client.update_model_version, ["prompt", 1, "new_description"]), 2384 (client.transition_model_version_stage, ["prompt", 1, "Production"]), 2385 (client.delete_model_version, ["prompt", 1]), 2386 (client.get_model_version, ["prompt", 1]), 2387 (client.get_model_version_download_uri, ["prompt", 1]), 2388 (client.set_model_version_tag, ["prompt", 1, "tag", "value"]), 2389 (client.delete_model_version_tag, ["prompt", 1, "tag"]), 2390 (client.set_registered_model_alias, ["prompt", "alias", 1]), 2391 (client.delete_registered_model_alias, ["prompt", "alias"]), 2392 (client.get_model_version_by_alias, ["prompt", "alias"]), 2393 ] 2394 2395 for api, args in apis_to_args: 2396 with pytest.raises(MlflowException, match=r"Registered Model with name='prompt' not found"): 2397 api(*args) 2398 2399 with pytest.raises(MlflowException, match=r"Model with uri 'models:/prompt/1' not found"): 2400 client.copy_model_version("models:/prompt/1", "new_model") 2401 2402 2403 def test_log_and_detach_prompt(tracking_uri): 2404 client = MlflowClient(tracking_uri=tracking_uri) 2405 2406 client.register_prompt(name="p1", template="Hi, there!") 2407 time.sleep(0.001) # To avoid timestamp precision issue in Windows 2408 client.register_prompt(name="p2", template="Hi, {{name}}!") 2409 2410 run_id = client.create_run(experiment_id="0").info.run_id 2411 2412 # Check that initially no prompts are linked to the run 2413 run = client.get_run(run_id) 2414 linked_prompts_tag = run.data.tags.get(TraceTagKey.LINKED_PROMPTS) 2415 assert linked_prompts_tag is None 2416 2417 client.link_prompt_version_to_run(run_id, "prompts:/p1/1") 2418 run = client.get_run(run_id) 2419 linked_prompts_tag = run.data.tags.get(TraceTagKey.LINKED_PROMPTS) 2420 assert linked_prompts_tag is not None 2421 prompts = json.loads(linked_prompts_tag) 2422 assert len(prompts) == 1 2423 assert prompts[0]["name"] == "p1" 2424 2425 client.link_prompt_version_to_run(run_id, "prompts:/p2/1") 2426 run = client.get_run(run_id) 2427 linked_prompts_tag = run.data.tags.get(TraceTagKey.LINKED_PROMPTS) 2428 prompts = json.loads(linked_prompts_tag) 2429 assert len(prompts) == 2 2430 prompt_names = [p["name"] for p in prompts] 2431 assert "p1" in prompt_names 2432 assert "p2" in prompt_names 2433 2434 2435 def test_search_prompt(tracking_uri): 2436 client = MlflowClient(tracking_uri=tracking_uri) 2437 2438 client.register_prompt(name="prompt_1", template="Hi, {{name}}!") 2439 client.register_prompt(name="prompt_2", template="Hello, {{name}}!") 2440 client.register_prompt(name="prompt_3", template="Greetings, 
{{name}}!") 2441 client.register_prompt(name="prompt_4", template="Howdy, {{name}}!") 2442 client.register_prompt(name="prompt_5", template="Salutations, {{name}}!") 2443 client.register_prompt(name="prompt_6", template="Bonjour, {{name}}!") 2444 client.register_prompt(name="test", template="Test Template") 2445 client.register_prompt(name="new", template="Bonjour, {{name}}!") 2446 2447 prompts = client.search_prompts(filter_string="name='prompt_1'") 2448 assert len(prompts) == 1 2449 assert prompts[0].name == "prompt_1" 2450 2451 prompts = client.search_prompts(filter_string="name LIKE '%prompt%'") 2452 assert len(prompts) == 6 2453 assert all("prompt" in prompt.name for prompt in prompts) 2454 2455 prompts = client.search_prompts() 2456 assert len(prompts) == 8 2457 2458 prompts = client.search_prompts(max_results=3) 2459 assert len(prompts) == 3 2460 2461 2462 def test_delete_prompt_version_no_auto_cleanup(tracking_uri): 2463 client = MlflowClient(tracking_uri=tracking_uri) 2464 2465 # Create prompt and version 2466 client.register_prompt(name="test_prompt", template="Hello {{name}}!") 2467 2468 # Verify prompt and version exist 2469 prompt = client.get_prompt("test_prompt") 2470 assert prompt is not None 2471 assert prompt.name == "test_prompt" 2472 2473 prompt_version = client.get_prompt_version("test_prompt", 1) 2474 assert prompt_version is not None 2475 assert prompt_version.version == 1 2476 2477 # Delete the version - prompt should remain 2478 client.delete_prompt_version("test_prompt", "1") 2479 2480 # Prompt should still exist even though it has no versions 2481 prompt = client.get_prompt("test_prompt") 2482 assert prompt is not None 2483 assert prompt.name == "test_prompt" 2484 2485 # Version should be gone 2486 with pytest.raises(MlflowException, match=r"Prompt.*name=test_prompt.*version=1.*not found"): 2487 client.get_prompt_version("test_prompt", 1) 2488 2489 2490 def test_delete_prompt_version_invalidates_cached_load_prompt(tracking_uri): 2491 client = MlflowClient(tracking_uri=tracking_uri) 2492 2493 prompt_ver = client.register_prompt(name="test_prompt", template="Version 1") 2494 loaded = client.load_prompt(prompt_ver.name, version=prompt_ver.version) 2495 assert loaded.template == "Version 1" 2496 2497 client.delete_prompt_version(prompt_ver.name, str(prompt_ver.version)) 2498 2499 with pytest.raises( 2500 MlflowException, 2501 match=rf"Prompt.*name={prompt_ver.name}.*version={prompt_ver.version}.*not found", 2502 ): 2503 client.get_prompt_version(prompt_ver.name, prompt_ver.version) 2504 2505 with pytest.raises( 2506 MlflowException, 2507 match=rf"Prompt.*name={prompt_ver.name}.*version={prompt_ver.version}.*not found", 2508 ): 2509 client.load_prompt(prompt_ver.name, version=prompt_ver.version) 2510 2511 2512 def test_delete_prompt_version_invalidates_latest_cache(tracking_uri): 2513 client = MlflowClient(tracking_uri=tracking_uri) 2514 2515 prompt_v1 = client.register_prompt(name="test_prompt", template="Version 1") 2516 prompt_v2 = client.register_prompt(name=prompt_v1.name, template="Version 2") 2517 2518 latest_prompt = client.load_prompt(f"prompts:/{prompt_v1.name}@latest") 2519 assert latest_prompt.version == prompt_v2.version 2520 assert latest_prompt.template == prompt_v2.template 2521 2522 client.delete_prompt_version(prompt_v2.name, str(prompt_v2.version)) 2523 2524 latest_prompt_after_delete = client.load_prompt(f"prompts:/{prompt_v1.name}@latest") 2525 assert latest_prompt_after_delete.version == prompt_v1.version 2526 assert 
latest_prompt_after_delete.template == prompt_v1.template 2527 2528 2529 def test_set_prompt_model_config_invalidates_latest_cache(tracking_uri): 2530 client = MlflowClient(tracking_uri=tracking_uri) 2531 2532 cache_ttl_seconds = 60 2533 prompt = client.register_prompt(name="test_prompt", template="test") 2534 prompt_before_update = client.load_prompt(prompt.name, cache_ttl_seconds=cache_ttl_seconds) 2535 assert prompt_before_update.model_config is None 2536 2537 model_config = {"model_name": "gpt-4", "temperature": 0.7} 2538 mlflow.genai.set_prompt_model_config( 2539 name=prompt.name, 2540 version=prompt.version, 2541 model_config=model_config, 2542 ) 2543 2544 prompt_after_update = client.load_prompt(prompt.name, cache_ttl_seconds=cache_ttl_seconds) 2545 assert prompt_after_update.model_config == model_config 2546 2547 2548 def test_delete_prompt_model_config_invalidates_latest_cache(tracking_uri): 2549 client = MlflowClient(tracking_uri=tracking_uri) 2550 2551 cache_ttl_seconds = 60 2552 model_config = {"model_name": "gpt-4", "temperature": 0.7} 2553 prompt = client.register_prompt( 2554 name="test_prompt", 2555 template="test", 2556 model_config=model_config, 2557 ) 2558 prompt_before_delete = client.load_prompt(prompt.name, cache_ttl_seconds=cache_ttl_seconds) 2559 assert prompt_before_delete.model_config == model_config 2560 2561 mlflow.genai.delete_prompt_model_config(name=prompt.name, version=prompt.version) 2562 2563 prompt_after_delete = client.load_prompt(prompt.name, cache_ttl_seconds=cache_ttl_seconds) 2564 assert prompt_after_delete.model_config is None 2565 2566 2567 def test_delete_prompt_version_invalidates_alias_cache(tracking_uri): 2568 client = MlflowClient(tracking_uri=tracking_uri) 2569 2570 prompt_v1 = client.register_prompt(name="test_prompt", template="Version 1") 2571 client.register_prompt(name=prompt_v1.name, template="Version 2") 2572 client.set_prompt_alias(prompt_v1.name, alias="production", version=prompt_v1.version) 2573 2574 aliased_prompt = client.load_prompt(f"prompts:/{prompt_v1.name}@production") 2575 assert aliased_prompt.version == prompt_v1.version 2576 assert aliased_prompt.template == prompt_v1.template 2577 2578 client.delete_prompt_version(prompt_v1.name, str(prompt_v1.version)) 2579 2580 with pytest.raises( 2581 MlflowException, 2582 match=( 2583 r"Prompt (.*) does not exist.|Prompt alias (.*) not found.|" 2584 rf"Prompt.*version={prompt_v1.version}.*not found" 2585 ), 2586 ): 2587 client.load_prompt(f"prompts:/{prompt_v1.name}@production") 2588 2589 2590 def test_delete_prompt_with_no_versions(tracking_uri): 2591 client = MlflowClient(tracking_uri=tracking_uri) 2592 mlflow.set_experiment("test_delete_prompt_with_no_versions") 2593 2594 # Create prompt and version, then delete version 2595 client.register_prompt(name="empty_prompt", template="Hello {{name}}!") 2596 client.delete_prompt_version("empty_prompt", "1") 2597 2598 # Verify prompt exists but has no versions 2599 prompt = client.get_prompt("empty_prompt") 2600 assert prompt is not None 2601 2602 # Delete the prompt - should work regardless of registry type 2603 client.delete_prompt("empty_prompt") 2604 2605 # Prompt should be gone 2606 prompt = client.get_prompt("empty_prompt") 2607 assert prompt is None 2608 2609 2610 def test_delete_prompt_invalidates_cached_load_prompt(tracking_uri): 2611 client = MlflowClient(tracking_uri=tracking_uri) 2612 2613 prompt_ver = client.register_prompt(name="test_prompt", template="Version 1") 2614 loaded = client.load_prompt(prompt_ver.name, 
version=prompt_ver.version) 2615 assert loaded.template == "Version 1" 2616 2617 client.delete_prompt(prompt_ver.name) 2618 2619 assert client.get_prompt(prompt_ver.name) is None 2620 2621 with pytest.raises(MlflowException, match=rf"Prompt.*name={prompt_ver.name}.*not found"): 2622 client.load_prompt(prompt_ver.name, version=prompt_ver.version) 2623 2624 2625 def test_delete_prompt_complete_workflow(tracking_uri): 2626 client = MlflowClient(tracking_uri=tracking_uri) 2627 2628 # Create prompt with multiple versions 2629 client.register_prompt(name="workflow_prompt", template="Version 1: {{name}}") 2630 client.register_prompt(name="workflow_prompt", template="Version 2: {{name}}") 2631 client.register_prompt(name="workflow_prompt", template="Version 3: {{name}}") 2632 2633 # Verify all versions exist 2634 v1 = client.get_prompt_version("workflow_prompt", 1) 2635 v2 = client.get_prompt_version("workflow_prompt", 2) 2636 v3 = client.get_prompt_version("workflow_prompt", 3) 2637 assert v1.template == "Version 1: {{name}}" 2638 assert v2.template == "Version 2: {{name}}" 2639 assert v3.template == "Version 3: {{name}}" 2640 2641 # Delete versions one by one 2642 client.delete_prompt_version("workflow_prompt", "1") 2643 client.delete_prompt_version("workflow_prompt", "2") 2644 client.delete_prompt_version("workflow_prompt", "3") 2645 2646 # Prompt should still exist with no versions 2647 prompt = client.get_prompt("workflow_prompt") 2648 assert prompt is not None 2649 2650 # Now delete the prompt itself 2651 client.delete_prompt("workflow_prompt") 2652 2653 # Prompt should be completely gone 2654 prompt = client.get_prompt("workflow_prompt") 2655 assert prompt is None 2656 2657 2658 def test_delete_prompt_error_handling(tracking_uri): 2659 client = MlflowClient(tracking_uri=tracking_uri) 2660 2661 # Test deleting non-existent prompt 2662 with pytest.raises(MlflowException, match=r"Prompt with name=nonexistent not found"): 2663 client.delete_prompt("nonexistent") 2664 2665 # Test deleting non-existent version 2666 client.register_prompt(name="test_errors", template="Hello {{name}}!") 2667 with pytest.raises(MlflowException, match=r"Prompt.*name=test_errors.*version=999.*not found"): 2668 client.delete_prompt_version("test_errors", "999") 2669 2670 2671 def test_delete_prompt_version_behavior_consistency(tracking_uri): 2672 client = MlflowClient(tracking_uri=tracking_uri) 2673 2674 # Create multiple prompts with versions 2675 for i in range(3): 2676 prompt_name = f"consistency_test_{i}" 2677 client.register_prompt(name=prompt_name, template=f"Template {i}: {{{{name}}}}") 2678 2679 # Delete the version immediately 2680 client.delete_prompt_version(prompt_name, "1") 2681 2682 # Prompt should remain but have no versions 2683 prompt = client.get_prompt(prompt_name) 2684 assert prompt is not None 2685 assert prompt.name == prompt_name 2686 2687 # Version should be gone 2688 with pytest.raises(MlflowException, match=r"Prompt.*version.*not found"): 2689 client.get_prompt_version(prompt_name, 1) 2690 2691 # Clean up - delete all prompts 2692 for i in range(3): 2693 client.delete_prompt(f"consistency_test_{i}") 2694 prompt = client.get_prompt(f"consistency_test_{i}") 2695 assert prompt is None 2696 2697 2698 @pytest.mark.parametrize("registry_uri", ["databricks-uc"]) 2699 def test_delete_prompt_with_versions_unity_catalog_error(registry_uri): 2700 # Mock Unity Catalog behavior 2701 client = MlflowClient(registry_uri=registry_uri) 2702 2703 # Mock the search_prompt_versions to return a PagedList with 
versions 2704 mock_versions = PagedList([Mock(version="1")], None) 2705 2706 with ( 2707 patch.object(client, "search_prompt_versions", return_value=mock_versions), 2708 patch.object(client, "_registry_uri", registry_uri), 2709 ): 2710 with pytest.raises( 2711 MlflowException, match=r"Cannot delete prompt .* because it still has undeleted" 2712 ): 2713 client.delete_prompt("test_prompt") 2714 2715 2716 def test_link_prompt_version_to_model_smoke_test(tracking_uri): 2717 client = MlflowClient(tracking_uri=tracking_uri) 2718 2719 # Create an experiment and a run to have a proper context 2720 experiment_id = client.create_experiment("test_experiment") 2721 with mlflow.start_run(experiment_id=experiment_id): 2722 # Create a model with a run context 2723 model = client.create_logged_model(experiment_id=experiment_id) 2724 2725 # Register a prompt 2726 client.register_prompt(name="test_prompt", template="Hello, {{name}}!") 2727 2728 # Link the prompt version to the model (this should not raise an exception) 2729 # This is the main assertion - that the method call succeeds 2730 client.link_prompt_version_to_model( 2731 name="test_prompt", version="1", model_id=model.model_id 2732 ) 2733 2734 2735 def test_link_prompts_to_trace_smoke_test(tracking_uri): 2736 client = MlflowClient(tracking_uri=tracking_uri) 2737 2738 # Create an experiment and a run to have a proper context 2739 experiment_id = client.create_experiment("test_experiment") 2740 with mlflow.start_run(experiment_id=experiment_id): 2741 # Create a simple trace for testing 2742 trace_info = client.start_trace("test_trace") 2743 trace_id = trace_info.request_id 2744 2745 # Register a prompt 2746 client.register_prompt(name="test_prompt", template="Hello, {{name}}!") 2747 2748 # Get the prompt version and link to the trace (this should not raise an exception) 2749 # This is the main assertion - that the method call succeeds 2750 prompt_version = client.get_prompt_version("test_prompt", "1") 2751 client.link_prompt_versions_to_trace(prompt_versions=[prompt_version], trace_id=trace_id) 2752 2753 2754 def test_log_model_artifact(tmp_path: Path, tracking_uri: str) -> None: 2755 client = MlflowClient(tracking_uri=tracking_uri) 2756 experiment_id = client.create_experiment("test") 2757 model = client.create_logged_model(experiment_id=experiment_id) 2758 tmp_path = tmp_path.joinpath("artifacts") 2759 tmp_path.mkdir() 2760 tmp_file = tmp_path.joinpath("file") 2761 tmp_file.write_text("a") 2762 client.log_model_artifact(model_id=model.model_id, local_path=str(tmp_file)) 2763 artifacts = client.list_logged_model_artifacts(model_id=model.model_id) 2764 assert artifacts == [FileInfo(path="file", is_dir=False, file_size=1)] 2765 another_tmp_file = tmp_path.joinpath("another_file") 2766 another_tmp_file.write_text("aa") 2767 client.log_model_artifact(model_id=model.model_id, local_path=str(another_tmp_file)) 2768 artifacts = client.list_logged_model_artifacts(model_id=model.model_id) 2769 artifacts = sorted(artifacts, key=lambda x: x.path) 2770 assert artifacts == [ 2771 FileInfo(path="another_file", is_dir=False, file_size=2), 2772 FileInfo(path="file", is_dir=False, file_size=1), 2773 ] 2774 2775 2776 def test_log_model_artifacts(tmp_path: Path, tracking_uri: str) -> None: 2777 client = MlflowClient(tracking_uri=tracking_uri) 2778 experiment_id = client.create_experiment("test") 2779 model = client.create_logged_model(experiment_id=experiment_id) 2780 tmp_path = tmp_path.joinpath("artifacts") 2781 tmp_path.mkdir() 2782 tmp_file = 
tmp_path.joinpath("file") 2783 tmp_file.write_text("a") 2784 tmp_dir = tmp_path.joinpath("dir") 2785 tmp_dir.mkdir() 2786 another_file = tmp_dir.joinpath("another_file") 2787 another_file.write_text("aa") 2788 client.log_model_artifacts(model_id=model.model_id, local_dir=str(tmp_path)) 2789 artifacts = client.list_logged_model_artifacts(model_id=model.model_id) 2790 artifacts = sorted(artifacts, key=lambda x: x.path) 2791 assert artifacts == [ 2792 FileInfo(path="dir", is_dir=True, file_size=None), 2793 FileInfo(path="file", is_dir=False, file_size=1), 2794 ] 2795 artifacts = client.list_logged_model_artifacts(model_id=model.model_id, path="dir") 2796 assert artifacts == [FileInfo(path="dir/another_file", is_dir=False, file_size=2)] 2797 2798 2799 def test_logged_model_model_id_required(tracking_uri): 2800 client = MlflowClient(tracking_uri=tracking_uri) 2801 2802 with pytest.raises(MlflowException, match="`model_id` must be a non-empty string, but got ''"): 2803 client.finalize_logged_model("", LoggedModelStatus.READY) 2804 2805 with pytest.raises(MlflowException, match="`model_id` must be a non-empty string, but got ''"): 2806 client.get_logged_model("") 2807 2808 with pytest.raises(MlflowException, match="`model_id` must be a non-empty string, but got ''"): 2809 client.delete_logged_model("") 2810 2811 with pytest.raises(MlflowException, match="`model_id` must be a non-empty string, but got ''"): 2812 client.set_logged_model_tags("", {}) 2813 2814 with pytest.raises(MlflowException, match="`model_id` must be a non-empty string, but got ''"): 2815 client.delete_logged_model_tag("", "") 2816 2817 2818 @pytest.mark.skipif( 2819 "MLFLOW_SKINNY" in os.environ, 2820 reason="Skinny client does not support the np or pandas dependencies", 2821 ) 2822 def test_log_metric_link_to_active_model(tracking_uri): 2823 model = mlflow.create_external_model(name="test_model") 2824 mlflow.set_active_model(name=model.name) 2825 client = MlflowClient(tracking_uri=tracking_uri) 2826 with mlflow.start_run() as run: 2827 client.log_metric(run.info.run_id, "metric", 1) 2828 logged_model = mlflow.get_logged_model(model_id=model.model_id) 2829 assert logged_model.name == model.name 2830 assert logged_model.model_id == model.model_id 2831 assert logged_model.metrics[0].key == "metric" 2832 assert logged_model.metrics[0].value == 1 2833 2834 2835 @pytest.mark.skipif( 2836 "MLFLOW_SKINNY" in os.environ, 2837 reason="Skinny client does not support the np or pandas dependencies", 2838 ) 2839 def test_log_batch_link_to_active_model(tracking_uri): 2840 model = mlflow.create_external_model(name="test_model") 2841 mlflow.set_active_model(name=model.name) 2842 client = MlflowClient(tracking_uri=tracking_uri) 2843 with mlflow.start_run() as run: 2844 client.log_batch(run.info.run_id, [Metric("metric1", 1, 0, 0), Metric("metric2", 2, 0, 0)]) 2845 logged_model = mlflow.get_logged_model(model_id=model.model_id) 2846 assert logged_model.name == model.name 2847 assert logged_model.model_id == model.model_id 2848 assert {m.key: m.value for m in logged_model.metrics} == { 2849 "metric1": 1, 2850 "metric2": 2, 2851 } 2852 2853 2854 def test_load_prompt_with_alias_uri(tracking_uri, disable_prompt_cache): 2855 client = MlflowClient(tracking_uri=tracking_uri) 2856 2857 # Register two versions of a prompt 2858 client.register_prompt(name="alias_prompt", template="Hello, world!") 2859 client.register_prompt(name="alias_prompt", template="Hello, {{name}}!") 2860 2861 # Assign alias to version 1 2862 client.set_prompt_alias("alias_prompt", 
alias="production", version=1) 2863 prompt = client.load_prompt("prompts:/alias_prompt@production") 2864 assert prompt.template == "Hello, world!" 2865 assert "production" in prompt.aliases 2866 2867 # Reassign alias to version 2 2868 client.set_prompt_alias("alias_prompt", alias="production", version=2) 2869 prompt = client.load_prompt("prompts:/alias_prompt@production") 2870 assert prompt.template == "Hello, {{name}}!" 2871 assert "production" in prompt.aliases 2872 2873 # Delete alias and verify loading fails 2874 client.delete_prompt_alias("alias_prompt", alias="production") 2875 with pytest.raises( 2876 MlflowException, match=r"Prompt (.*) does not exist.|Prompt alias (.*) not found." 2877 ): 2878 client.load_prompt("prompts:/alias_prompt@production") 2879 2880 # Loading with the 'latest' alias 2881 prompt = client.load_prompt("prompts:/alias_prompt@latest") 2882 assert prompt.template == "Hello, {{name}}!" 2883 2884 2885 def test_load_prompt_allow_missing_name_version(tracking_uri): 2886 client = MlflowClient(tracking_uri=tracking_uri) 2887 2888 # Non-existent prompt by name+version should return None when allow_missing=True 2889 result = client.load_prompt("nonexistent_prompt", version=1, allow_missing=True) 2890 assert result is None 2891 2892 # Non-existent prompt by name+version should raise exception when allow_missing=False 2893 with pytest.raises(MlflowException, match="Prompt with name=nonexistent_prompt not found"): 2894 client.load_prompt("nonexistent_prompt", version=1, allow_missing=False) 2895 2896 # Existing prompt with non-existent version should return None when allow_missing=True 2897 client.register_prompt(name="existing_prompt", template="Hello, world!") 2898 result = client.load_prompt("existing_prompt", version=999, allow_missing=True) 2899 assert result is None 2900 2901 # Existing prompt with non-existent version should raise exception when allow_missing=False 2902 with pytest.raises( 2903 MlflowException, match=r"Prompt \(name=existing_prompt, version=999\) not found" 2904 ): 2905 client.load_prompt("existing_prompt", version=999, allow_missing=False) 2906 2907 2908 def test_load_prompt_allow_missing_uri_version(tracking_uri): 2909 client = MlflowClient(tracking_uri=tracking_uri) 2910 2911 # Non-existent prompt by URI+version should return None when allow_missing=True 2912 result = client.load_prompt("prompts:/nonexistent_prompt/1", allow_missing=True) 2913 assert result is None 2914 2915 # Non-existent prompt by URI+version should raise exception when allow_missing=False 2916 with pytest.raises(MlflowException, match="Prompt with name=nonexistent_prompt not found"): 2917 client.load_prompt("prompts:/nonexistent_prompt/1", allow_missing=False) 2918 2919 # Existing prompt with non-existent version via URI should return None when allow_missing=True 2920 client.register_prompt(name="existing_prompt", template="Hello, world!") 2921 result = client.load_prompt("prompts:/existing_prompt/999", allow_missing=True) 2922 assert result is None 2923 2924 # Existing prompt with non-existent version via URI should raise when allow_missing=False 2925 with pytest.raises( 2926 MlflowException, match=r"Prompt \(name=existing_prompt, version=999\) not found" 2927 ): 2928 client.load_prompt("prompts:/existing_prompt/999", allow_missing=False) 2929 2930 2931 def test_load_prompt_allow_missing_uri_alias(tracking_uri): 2932 client = MlflowClient(tracking_uri=tracking_uri) 2933 2934 # Non-existent prompt with alias should return None when allow_missing=True 2935 result = 
client.load_prompt("prompts:/nonexistent_prompt@production", allow_missing=True) 2936 assert result is None 2937 2938 # Non-existent prompt with alias should raise exception when allow_missing=False 2939 with pytest.raises(MlflowException, match="Prompt with name=nonexistent_prompt not found"): 2940 client.load_prompt("prompts:/nonexistent_prompt@production", allow_missing=False) 2941 2942 # Existing prompt with non-existent alias should return None when allow_missing=True 2943 client.register_prompt(name="existing_prompt", template="Hello, world!") 2944 result = client.load_prompt("prompts:/existing_prompt@nonexistent_alias", allow_missing=True) 2945 assert result is None 2946 2947 # Existing prompt with non-existent alias should raise exception when allow_missing=False 2948 with pytest.raises(MlflowException, match="Prompt alias nonexistent_alias not found"): 2949 client.load_prompt("prompts:/existing_prompt@nonexistent_alias", allow_missing=False) 2950 2951 2952 def test_create_prompt_chat_format_client_integration(): 2953 chat_template = [ 2954 {"role": "system", "content": "You are a {{style}} assistant."}, 2955 {"role": "user", "content": "{{question}}"}, 2956 ] 2957 2958 response_format = {"type": "string"} 2959 2960 # Use client to create prompt 2961 client = MlflowClient() 2962 prompt = client.register_prompt( 2963 name="test_chat_client", 2964 template=chat_template, 2965 response_format=response_format, 2966 commit_message="Test chat prompt via client", 2967 ) 2968 2969 assert prompt.template == chat_template 2970 assert prompt.response_format == response_format 2971 2972 # Load via client 2973 loaded_prompt = client.get_prompt_version("test_chat_client", 1) 2974 assert not loaded_prompt.is_text_prompt 2975 assert loaded_prompt.template == chat_template 2976 assert loaded_prompt.response_format == response_format 2977 2978 2979 def test_link_chat_prompt_version_to_run(): 2980 chat_template = [ 2981 {"role": "system", "content": "You are a helpful assistant."}, 2982 {"role": "user", "content": "Hello {{name}}!"}, 2983 ] 2984 2985 client = MlflowClient() 2986 prompt = client.register_prompt(name="test_chat_link", template=chat_template) 2987 2988 # Create run and link prompt 2989 run = client.create_run(client.create_experiment("test_exp")) 2990 client.link_prompt_version_to_run(run.info.run_id, prompt) 2991 2992 # Verify linking 2993 run_data = client.get_run(run.info.run_id) 2994 linked_prompts_tag = run_data.data.tags.get(TraceTagKey.LINKED_PROMPTS) 2995 assert linked_prompts_tag is not None 2996 2997 linked_prompts = json.loads(linked_prompts_tag) 2998 assert len(linked_prompts) == 1 2999 assert linked_prompts[0]["name"] == "test_chat_link" 3000 assert linked_prompts[0]["version"] == "1" 3001 3002 3003 def test_create_prompt_with_pydantic_response_format_client(): 3004 class ResponseSchema(BaseModel): 3005 answer: str 3006 confidence: float 3007 3008 client = MlflowClient() 3009 prompt = client.register_prompt( 3010 name="test_pydantic_client", 3011 template="What is {{question}}?", 3012 response_format=ResponseSchema, 3013 commit_message="Test Pydantic response format via client", 3014 ) 3015 3016 assert prompt.response_format == ResponseSchema.model_json_schema() 3017 assert prompt.commit_message == "Test Pydantic response format via client" 3018 3019 # Load and verify 3020 loaded_prompt = client.get_prompt_version("test_pydantic_client", 1) 3021 assert loaded_prompt.response_format == ResponseSchema.model_json_schema() 3022 3023 3024 def 
test_create_prompt_with_dict_response_format_client(): 3025 response_format = { 3026 "type": "object", 3027 "properties": { 3028 "summary": {"type": "string"}, 3029 "key_points": {"type": "array", "items": {"type": "string"}}, 3030 }, 3031 } 3032 3033 client = MlflowClient() 3034 prompt = client.register_prompt( 3035 name="test_dict_response_client", 3036 template="Analyze this: {{text}}", 3037 response_format=response_format, 3038 tags={"analysis_type": "text"}, 3039 ) 3040 3041 assert prompt.response_format == response_format 3042 assert prompt.tags["analysis_type"] == "text" 3043 3044 # Load and verify 3045 loaded_prompt = client.get_prompt_version("test_dict_response_client", 1) 3046 assert loaded_prompt.response_format == response_format 3047 3048 3049 def test_create_prompt_text_backward_compatibility_client(): 3050 client = MlflowClient() 3051 prompt = client.register_prompt( 3052 name="test_text_backward_client", 3053 template="Hello {{name}}!", 3054 commit_message="Test backward compatibility via client", 3055 ) 3056 3057 assert prompt.is_text_prompt 3058 assert prompt.template == "Hello {{name}}!" 3059 assert prompt.commit_message == "Test backward compatibility via client" 3060 3061 # Load and verify 3062 loaded_prompt = client.get_prompt_version("test_text_backward_client", 1) 3063 assert loaded_prompt.is_text_prompt 3064 assert loaded_prompt.template == "Hello {{name}}!" 3065 3066 3067 def test_create_prompt_complex_chat_template_client(): 3068 chat_template = [ 3069 { 3070 "role": "system", 3071 "content": "You are a {{style}} assistant named {{name}}.", 3072 }, 3073 {"role": "user", "content": "{{greeting}}! {{question}}"}, 3074 { 3075 "role": "assistant", 3076 "content": "I understand you're asking about {{topic}}.", 3077 }, 3078 ] 3079 3080 client = MlflowClient() 3081 prompt = client.register_prompt( 3082 name="test_complex_chat_client", 3083 template=chat_template, 3084 tags={"complexity": "high"}, 3085 ) 3086 3087 assert prompt.template == chat_template 3088 assert prompt.tags["complexity"] == "high" 3089 3090 # Load and verify 3091 loaded_prompt = client.get_prompt_version("test_complex_chat_client", 1) 3092 assert not loaded_prompt.is_text_prompt 3093 assert loaded_prompt.template == chat_template 3094 3095 3096 def test_create_prompt_with_none_response_format_client(): 3097 client = MlflowClient() 3098 prompt = client.register_prompt( 3099 name="test_none_response_client", 3100 template="Hello {{name}}!", 3101 response_format=None, 3102 ) 3103 3104 assert prompt.response_format is None 3105 3106 # Load and verify 3107 loaded_prompt = client.get_prompt_version("test_none_response_client", 1) 3108 assert loaded_prompt.response_format is None 3109 3110 3111 def test_create_prompt_with_single_message_chat_client(): 3112 chat_template = [{"role": "user", "content": "Hello {{name}}!"}] 3113 3114 client = MlflowClient() 3115 prompt = client.register_prompt(name="test_single_message_client", template=chat_template) 3116 3117 assert prompt.template == chat_template 3118 assert prompt.variables == {"name"} 3119 3120 # Load and verify 3121 loaded_prompt = client.get_prompt_version("test_single_message_client", 1) 3122 assert not loaded_prompt.is_text_prompt 3123 assert loaded_prompt.template == chat_template 3124 3125 3126 def test_create_prompt_with_multiple_variables_in_chat_client(): 3127 chat_template = [ 3128 { 3129 "role": "system", 3130 "content": "You are a {{style}} assistant named {{name}}.", 3131 }, 3132 {"role": "user", "content": "{{greeting}}! 
{{question}}"}, 3133 { 3134 "role": "assistant", 3135 "content": "I understand you're asking about {{topic}}.", 3136 }, 3137 ] 3138 3139 client = MlflowClient() 3140 prompt = client.register_prompt(name="test_multiple_variables_client", template=chat_template) 3141 3142 expected_variables = {"style", "name", "greeting", "question", "topic"} 3143 assert prompt.variables == expected_variables 3144 3145 # Load and verify 3146 loaded_prompt = client.get_prompt_version("test_multiple_variables_client", 1) 3147 assert loaded_prompt.variables == expected_variables 3148 3149 3150 def test_create_prompt_with_mixed_content_types_client(): 3151 chat_template = [ 3152 {"role": "system", "content": "You are a helpful assistant."}, 3153 {"role": "user", "content": "Hello {{name}}!"}, 3154 {"role": "assistant", "content": "Hi there! How can I help you today?"}, 3155 ] 3156 3157 client = MlflowClient() 3158 prompt = client.register_prompt(name="test_mixed_content_client", template=chat_template) 3159 3160 assert prompt.template == chat_template 3161 assert prompt.variables == {"name"} 3162 3163 # Load and verify 3164 loaded_prompt = client.get_prompt_version("test_mixed_content_client", 1) 3165 assert not loaded_prompt.is_text_prompt 3166 assert loaded_prompt.template == chat_template 3167 3168 3169 def test_create_prompt_with_nested_variables_client(): 3170 chat_template = [ 3171 { 3172 "role": "system", 3173 "content": "You are a {{user.preferences.style}} assistant.", 3174 }, 3175 { 3176 "role": "user", 3177 "content": "Hello {{user.name}}! {{user.preferences.greeting}}", 3178 }, 3179 ] 3180 3181 client = MlflowClient() 3182 prompt = client.register_prompt(name="test_nested_variables_client", template=chat_template) 3183 3184 expected_variables = { 3185 "user.preferences.style", 3186 "user.name", 3187 "user.preferences.greeting", 3188 } 3189 assert prompt.variables == expected_variables 3190 3191 # Load and verify 3192 loaded_prompt = client.get_prompt_version("test_nested_variables_client", 1) 3193 assert loaded_prompt.variables == expected_variables 3194 3195 3196 def test_link_prompt_with_response_format_to_run(): 3197 response_format = { 3198 "type": "object", 3199 "properties": {"answer": {"type": "string"}}, 3200 } 3201 client = MlflowClient() 3202 prompt = client.register_prompt( 3203 name="test_response_link", 3204 template="What is {{question}}?", 3205 response_format=response_format, 3206 ) 3207 3208 # Create run and link prompt 3209 run = client.create_run(client.create_experiment("test_exp")) 3210 client.link_prompt_version_to_run(run.info.run_id, prompt) 3211 3212 # Verify linking 3213 run_data = client.get_run(run.info.run_id) 3214 linked_prompts_tag = run_data.data.tags.get(TraceTagKey.LINKED_PROMPTS) 3215 assert linked_prompts_tag is not None 3216 3217 linked_prompts = json.loads(linked_prompts_tag) 3218 assert len(linked_prompts) == 1 3219 assert linked_prompts[0]["name"] == "test_response_link" 3220 assert linked_prompts[0]["version"] == "1" 3221 3222 3223 def test_link_multiple_prompt_types_to_run(): 3224 client = MlflowClient() 3225 3226 # Create text prompt 3227 text_prompt = client.register_prompt(name="test_text_link", template="Hello {{name}}!") 3228 3229 # Create chat prompt 3230 chat_template = [ 3231 {"role": "system", "content": "You are a helpful assistant."}, 3232 {"role": "user", "content": "{{question}}"}, 3233 ] 3234 chat_prompt = client.register_prompt(name="test_chat_link_multiple", template=chat_template) 3235 3236 # Create run and link both prompts 3237 run = 
client.create_run(client.create_experiment("test_exp")) 3238 client.link_prompt_version_to_run(run.info.run_id, text_prompt) 3239 client.link_prompt_version_to_run(run.info.run_id, chat_prompt) 3240 3241 # Verify linking 3242 run_data = client.get_run(run.info.run_id) 3243 linked_prompts_tag = run_data.data.tags.get(TraceTagKey.LINKED_PROMPTS) 3244 assert linked_prompts_tag is not None 3245 3246 linked_prompts = json.loads(linked_prompts_tag) 3247 assert len(linked_prompts) == 2 3248 3249 expected_prompts = [ 3250 {"name": "test_text_link", "version": "1"}, 3251 {"name": "test_chat_link_multiple", "version": "1"}, 3252 ] 3253 for expected_prompt in expected_prompts: 3254 assert expected_prompt in linked_prompts 3255 3256 3257 def test_mlflow_client_create_dataset(mock_store): 3258 created_dataset = EvaluationDataset( 3259 dataset_id="test_dataset_id", 3260 name="test_dataset", 3261 digest="abcdef123456", 3262 created_time=1234567890, 3263 last_update_time=1234567890, 3264 tags={"environment": "production", "version": "1.0"}, 3265 ) 3266 created_dataset.experiment_ids = ["exp1", "exp2"] 3267 mock_store.create_dataset.return_value = created_dataset 3268 3269 # Mock context registry to return empty tags so mlflow.user is not auto-added 3270 with mock.patch( 3271 "mlflow.tracking._tracking_service.client.context_registry.resolve_tags", return_value={} 3272 ): 3273 dataset = MlflowClient().create_dataset( 3274 name="qa_evaluation", 3275 experiment_id=["exp1", "exp2"], 3276 tags={"environment": "production", "version": "1.0"}, 3277 ) 3278 3279 assert dataset.dataset_id == "test_dataset_id" 3280 assert dataset.name == "test_dataset" 3281 assert dataset.tags == {"environment": "production", "version": "1.0"} 3282 3283 mock_store.create_dataset.assert_called_once_with( 3284 name="qa_evaluation", 3285 tags={"environment": "production", "version": "1.0"}, 3286 experiment_ids=["exp1", "exp2"], 3287 ) 3288 3289 3290 def test_mlflow_client_create_evaluation_dataset_minimal(mock_store): 3291 created_dataset = EvaluationDataset( 3292 dataset_id="test_dataset_id", 3293 name="test_dataset", 3294 digest="abcdef123456", 3295 created_time=1234567890, 3296 last_update_time=1234567890, 3297 ) 3298 mock_store.create_dataset.return_value = created_dataset 3299 3300 # Mock context registry to return empty tags so mlflow.user is not auto-added 3301 with mock.patch( 3302 "mlflow.tracking._tracking_service.client.context_registry.resolve_tags", return_value={} 3303 ): 3304 dataset = MlflowClient().create_dataset(name="test_dataset") 3305 3306 assert dataset.dataset_id == "test_dataset_id" 3307 assert dataset.name == "test_dataset" 3308 3309 mock_store.create_dataset.assert_called_once_with( 3310 name="test_dataset", 3311 tags=None, 3312 experiment_ids=None, 3313 ) 3314 3315 3316 def test_mlflow_client_get_dataset(mock_store): 3317 mock_store.get_dataset.return_value = EvaluationDataset( 3318 dataset_id="dataset_123", 3319 name="test_dataset", 3320 digest="abcdef123456", 3321 created_time=1234567890, 3322 last_update_time=1234567890, 3323 tags={"source": "human-annotated"}, 3324 ) 3325 3326 dataset = MlflowClient().get_dataset("dataset_123") 3327 3328 assert dataset.dataset_id == "dataset_123" 3329 assert dataset.name == "test_dataset" 3330 assert dataset.tags == {"source": "human-annotated"} 3331 3332 mock_store.get_dataset.assert_called_once_with("dataset_123") 3333 3334 3335 def test_mlflow_client_delete_dataset(mock_store): 3336 MlflowClient().delete_dataset("dataset_123") 3337 3338 
mock_store.delete_dataset.assert_called_once_with("dataset_123") 3339 3340 3341 def test_mlflow_client_search_datasets(mock_store): 3342 mock_store.search_datasets.return_value = PagedList( 3343 [ 3344 EvaluationDataset( 3345 dataset_id="dataset_1", 3346 name="dataset_1", 3347 digest="digest1", 3348 created_time=1234567890, 3349 last_update_time=1234567890, 3350 ), 3351 EvaluationDataset( 3352 dataset_id="dataset_2", 3353 name="dataset_2", 3354 digest="digest2", 3355 created_time=1234567890, 3356 last_update_time=1234567890, 3357 ), 3358 ], 3359 "next_token", 3360 ) 3361 3362 result = MlflowClient().search_datasets( 3363 experiment_ids=["exp1", "exp2"], 3364 filter_string="name LIKE 'qa_%'", 3365 max_results=100, 3366 order_by=["created_time DESC"], 3367 page_token="page_token_123", 3368 ) 3369 3370 assert len(result) == 2 3371 assert result[0].dataset_id == "dataset_1" 3372 assert result[1].dataset_id == "dataset_2" 3373 assert result.token == "next_token" 3374 3375 mock_store.search_datasets.assert_called_once_with( 3376 experiment_ids=["exp1", "exp2"], 3377 filter_string="name LIKE 'qa_%'", 3378 max_results=100, 3379 order_by=["created_time DESC"], 3380 page_token="page_token_123", 3381 ) 3382 3383 3384 def test_mlflow_client_search_datasets_empty_results(mock_store): 3385 mock_store.search_datasets.return_value = PagedList([], None) 3386 3387 result = MlflowClient().search_datasets( 3388 experiment_ids=["exp1"], filter_string="name = 'nonexistent'" 3389 ) 3390 3391 assert len(result) == 0 3392 assert result.token is None 3393 3394 3395 def test_mlflow_client_search_datasets_defaults(mock_store): 3396 mock_store.search_datasets.return_value = PagedList([], None) 3397 3398 result = MlflowClient().search_datasets() 3399 3400 assert len(result) == 0 3401 assert result.token is None 3402 3403 mock_store.search_datasets.assert_called_once_with( 3404 experiment_ids=None, 3405 filter_string=None, 3406 max_results=SEARCH_EVALUATION_DATASETS_MAX_RESULTS, 3407 order_by=None, 3408 page_token=None, 3409 ) 3410 3411 3412 @pytest.mark.skipif(is_windows(), reason="FileStore URI handling issues on Windows") 3413 def test_mlflow_client_datasets_filestore_not_supported(tmp_path): 3414 pytest.skip("FileStore is no longer supported.") 3415 file_store_uri = str(tmp_path) 3416 client = MlflowClient(tracking_uri=file_store_uri) 3417 3418 with pytest.raises(MlflowException, match="is not supported with FileStore") as exc_info: 3419 client.create_dataset(name="test_dataset") 3420 assert exc_info.value.error_code == "FEATURE_DISABLED" 3421 3422 with pytest.raises(MlflowException, match="is not supported with FileStore") as exc_info: 3423 client.get_dataset("dataset_123") 3424 assert exc_info.value.error_code == "FEATURE_DISABLED" 3425 3426 with pytest.raises(MlflowException, match="is not supported with FileStore") as exc_info: 3427 client.delete_dataset("dataset_123") 3428 assert exc_info.value.error_code == "FEATURE_DISABLED" 3429 3430 with pytest.raises(MlflowException, match="is not supported with FileStore") as exc_info: 3431 client.search_datasets() 3432 assert exc_info.value.error_code == "FEATURE_DISABLED" 3433 3434 with pytest.raises(MlflowException, match="is not supported with FileStore") as exc_info: 3435 client.set_dataset_tags("dataset_123", {"tag1": "value1"}) 3436 assert exc_info.value.error_code == "FEATURE_DISABLED" 3437 3438 with pytest.raises(MlflowException, match="is not supported with FileStore") as exc_info: 3439 client.delete_dataset_tag("dataset_123", "tag1") 3440 assert 
exc_info.value.error_code == "FEATURE_DISABLED" 3441 3442 with pytest.raises(MlflowException, match="is not supported with FileStore") as exc_info: 3443 client.add_dataset_to_experiments("dataset_123", ["1", "2"]) 3444 assert exc_info.value.error_code == "FEATURE_DISABLED" 3445 3446 with pytest.raises(MlflowException, match="is not supported with FileStore") as exc_info: 3447 client.remove_dataset_from_experiments("dataset_123", ["1", "2"]) 3448 assert exc_info.value.error_code == "FEATURE_DISABLED" 3449 3450 3451 def test_mlflow_client_set_dataset_tags(mock_store): 3452 MlflowClient().set_dataset_tags( 3453 dataset_id="dataset_123", 3454 tags={"env": "prod", "version": "2.0"}, 3455 ) 3456 3457 mock_store.set_dataset_tags.assert_called_once_with( 3458 dataset_id="dataset_123", 3459 tags={"env": "prod", "version": "2.0"}, 3460 ) 3461 3462 3463 def test_mlflow_client_delete_dataset_tag(mock_store): 3464 MlflowClient().delete_dataset_tag( 3465 dataset_id="dataset_123", 3466 key="deprecated", 3467 ) 3468 3469 mock_store.delete_dataset_tag.assert_called_once_with( 3470 dataset_id="dataset_123", 3471 key="deprecated", 3472 ) 3473 3474 3475 def test_mlflow_client_add_dataset_to_experiments(mock_store): 3476 mock_dataset = Mock(spec=EvaluationDataset) 3477 mock_dataset.dataset_id = "dataset_123" 3478 mock_dataset.experiment_ids = ["1", "2", "3"] 3479 mock_store.add_dataset_to_experiments.return_value = mock_dataset 3480 3481 client = MlflowClient() 3482 result = client.add_dataset_to_experiments( 3483 dataset_id="dataset_123", 3484 experiment_ids=["2", "3"], 3485 ) 3486 3487 assert result == mock_dataset 3488 assert result.experiment_ids == ["1", "2", "3"] 3489 mock_store.add_dataset_to_experiments.assert_called_once_with("dataset_123", ["2", "3"]) 3490 3491 3492 def test_mlflow_client_remove_dataset_from_experiments(mock_store): 3493 mock_dataset = Mock(spec=EvaluationDataset) 3494 mock_dataset.dataset_id = "dataset_123" 3495 mock_dataset.experiment_ids = ["1"] 3496 mock_store.remove_dataset_from_experiments.return_value = mock_dataset 3497 3498 client = MlflowClient() 3499 result = client.remove_dataset_from_experiments( 3500 dataset_id="dataset_123", 3501 experiment_ids=["2", "3"], 3502 ) 3503 3504 assert result == mock_dataset 3505 assert result.experiment_ids == ["1"] 3506 mock_store.remove_dataset_from_experiments.assert_called_once_with("dataset_123", ["2", "3"]) 3507 3508 3509 def test_mlflow_client_dataset_associations_databricks_blocking(mock_store): 3510 with mock.patch("mlflow.utils.databricks_utils.is_databricks_uri") as mock_is_dbx: 3511 mock_is_dbx.return_value = True 3512 client = MlflowClient(tracking_uri="databricks") 3513 3514 with pytest.raises( 3515 MlflowException, match="not supported when tracking URI is 'databricks'" 3516 ) as exc_info: 3517 client.add_dataset_to_experiments("dataset_123", ["1", "2"]) 3518 assert exc_info.value.error_code == "INVALID_PARAMETER_VALUE" 3519 3520 with pytest.raises( 3521 MlflowException, match="not supported when tracking URI is 'databricks'" 3522 ) as exc_info: 3523 client.remove_dataset_from_experiments("dataset_123", ["1", "2"]) 3524 assert exc_info.value.error_code == "INVALID_PARAMETER_VALUE" 3525 3526 3527 def test_log_spans_and_get_trace_with_sqlalchemy_store(tmp_path: Path) -> None: 3528 tracking_uri = f"sqlite:///{tmp_path}/test.db" 3529 3530 with _use_tracking_uri(tracking_uri): 3531 client = MlflowClient() 3532 3533 assert isinstance(client._tracking_client.store, SqlAlchemyTrackingStore) 3534 3535 experiment_id = 
client.create_experiment("test_log_spans_get_trace") 3536 trace_id = f"tr-{uuid.uuid4().hex}" 3537 3538 # Create test spans using OpenTelemetry format 3539 otel_span1 = OTelReadableSpan( 3540 name="parent_span", 3541 context=trace_api.SpanContext( 3542 trace_id=12345, 3543 span_id=111, 3544 is_remote=False, 3545 trace_flags=trace_api.TraceFlags(1), 3546 ), 3547 parent=None, 3548 attributes={ 3549 "mlflow.traceRequestId": json.dumps(trace_id, cls=TraceJSONEncoder), 3550 "llm.model_name": "test-model", 3551 "custom.attribute": "parent-value", 3552 }, 3553 start_time=1_000_000_000, 3554 end_time=2_000_000_000, 3555 resource=None, 3556 ) 3557 3558 otel_span2 = OTelReadableSpan( 3559 name="child_span", 3560 context=trace_api.SpanContext( 3561 trace_id=12345, 3562 span_id=222, 3563 is_remote=False, 3564 trace_flags=trace_api.TraceFlags(1), 3565 ), 3566 parent=trace_api.SpanContext( 3567 trace_id=12345, 3568 span_id=111, 3569 is_remote=False, 3570 trace_flags=trace_api.TraceFlags(1), 3571 ), 3572 attributes={ 3573 "mlflow.traceRequestId": json.dumps(trace_id, cls=TraceJSONEncoder), 3574 "operation.type": "database_query", 3575 "custom.attribute": "child-value", 3576 }, 3577 start_time=1_200_000_000, 3578 end_time=1_800_000_000, 3579 resource=None, 3580 ) 3581 3582 # Convert to MLflow spans 3583 mlflow_spans = [ 3584 create_mlflow_span(otel_span1, trace_id, "LLM"), 3585 create_mlflow_span(otel_span2, trace_id, "LLM"), 3586 ] 3587 3588 # Log spans directly to the store (simulating OTLP endpoint) 3589 store = client._tracking_client.store 3590 logged_spans = store.log_spans(experiment_id, mlflow_spans) 3591 3592 # Verify spans were logged 3593 assert len(logged_spans) == 2 3594 3595 # Verify the trace has the spans location tag set 3596 trace_info = store.get_trace_info(trace_id) 3597 assert trace_info.tags.get(TraceTagKey.SPANS_LOCATION) == SpansLocation.TRACKING_STORE 3598 3599 # Now test that mlflow.get_trace() works and loads spans from the database 3600 trace = mlflow.get_trace(trace_id) 3601 3602 # Verify trace structure 3603 assert trace.info.trace_id == trace_id 3604 assert trace.info.tags.get(TraceTagKey.SPANS_LOCATION) == SpansLocation.TRACKING_STORE 3605 3606 # Verify spans were loaded from database 3607 assert len(trace.data.spans) == 2 3608 3609 # Sort spans by start time for consistent testing 3610 spans_by_start_time = sorted(trace.data.spans, key=lambda s: s.start_time_ns) 3611 3612 # Verify parent span 3613 parent_span = spans_by_start_time[0] 3614 assert parent_span.name == "parent_span" 3615 assert parent_span.trace_id == trace_id 3616 assert parent_span.start_time_ns == 1_000_000_000 3617 assert parent_span.end_time_ns == 2_000_000_000 3618 assert parent_span.attributes.get("llm.model_name") == "test-model" 3619 assert parent_span.attributes.get("custom.attribute") == "parent-value" 3620 3621 # Verify child span 3622 child_span = spans_by_start_time[1] 3623 assert child_span.name == "child_span" 3624 assert child_span.trace_id == trace_id 3625 assert child_span.start_time_ns == 1_200_000_000 3626 assert child_span.end_time_ns == 1_800_000_000 3627 assert child_span.attributes.get("operation.type") == "database_query" 3628 assert child_span.attributes.get("custom.attribute") == "child-value" 3629 3630 3631 def test_mlflow_get_trace_with_sqlalchemy_store(tmp_path: Path) -> None: 3632 tracking_uri = f"sqlite:///{tmp_path}/test.db" 3633 3634 with _use_tracking_uri(tracking_uri): 3635 client = MlflowClient() 3636 3637 assert isinstance(client._tracking_client.store, 
        assert isinstance(client._tracking_client.store, SqlAlchemyTrackingStore)

        with mlflow.start_span() as span:
            pass

        trace_id = span.trace_id
        mlflow.flush_trace_async_logging()
        sql_alchemy_store_module = "mlflow.store.tracking.sqlalchemy_store.SqlAlchemyStore"
        with (
            mock.patch(f"{sql_alchemy_store_module}.get_trace") as mock_get_trace,
        ):
            mlflow.get_trace(trace_id)

        mock_get_trace.assert_called_once_with(trace_id)

        with (
            mock.patch(
                f"{sql_alchemy_store_module}.get_trace",
                side_effect=MlflowNotImplementedException,
            ) as mock_get_trace,
            mock.patch(f"{sql_alchemy_store_module}.batch_get_traces") as mock_batch_get_traces,
        ):
            mlflow.get_trace(trace_id)

        mock_get_trace.assert_called_once_with(trace_id)
        mock_batch_get_traces.assert_called_once_with([trace_id])


def test_create_issue_basic(tmp_path: Path):
    tracking_uri = f"sqlite:///{tmp_path}/test.db"

    with _use_tracking_uri(tracking_uri):
        client = MlflowClient()
        exp_id = client.create_experiment("test_create_issue")
        tracing_client = client._tracing_client

        issue = tracing_client._create_issue(
            experiment_id=exp_id,
            name="Test issue",
            description="This is a test issue",
        )

        assert issue.issue_id.startswith("iss-")
        assert issue.experiment_id == exp_id
        assert issue.name == "Test issue"
        assert issue.description == "This is a test issue"
        assert issue.status == IssueStatus.PENDING
        assert issue.severity is None
        assert issue.root_causes is None
        assert issue.source_run_id is None
        assert issue.created_by is None
        assert issue.created_timestamp > 0
        assert issue.last_updated_timestamp == issue.created_timestamp


def test_create_issue_with_all_fields(tmp_path: Path):
    tracking_uri = f"sqlite:///{tmp_path}/test.db"

    with _use_tracking_uri(tracking_uri):
        client = MlflowClient()
        exp_id = client.create_experiment("test_create_issue_all_fields")
        tracing_client = client._tracing_client
        with mlflow.start_run(experiment_id=exp_id) as run:
            issue = tracing_client._create_issue(
                experiment_id=exp_id,
                name="High latency",
                description="API response times exceed threshold",
                status=IssueStatus.RESOLVED,
                severity=IssueSeverity.HIGH,
                root_causes=["Database query slow", "Network congestion"],
                source_run_id=run.info.run_id,
                created_by="monitoring_system",
            )

        assert issue.issue_id.startswith("iss-")
        assert issue.experiment_id == exp_id
        assert issue.name == "High latency"
        assert issue.description == "API response times exceed threshold"
        assert issue.status == IssueStatus.RESOLVED
        assert issue.severity == IssueSeverity.HIGH
        assert issue.root_causes == ["Database query slow", "Network congestion"]
        assert issue.source_run_id == run.info.run_id
        assert issue.created_by == "monitoring_system"
        assert issue.created_timestamp > 0