test_tracking.py
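"""
Tests for the MLflow fluent and client tracking APIs: experiment creation and
selection, run lifecycle, metric/param/tag logging (single and batch), artifact
and table logging, run search, and async logging configuration.

Note: fixtures referenced via ``@pytest.mark.usefixtures`` (e.g.
``reset_active_experiment``) are assumed to be provided by the surrounding test
suite's conftest; they are not defined in this module.
"""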
import filecmp
import io
import json
import os
import pathlib
import posixpath
import random
import re
from datetime import datetime, timezone
from typing import NamedTuple
from unittest import mock

import pytest
import yaml

import mlflow
from mlflow import MlflowClient, tracking
from mlflow.entities import LifecycleStage, Metric, Param, RunStatus, RunTag, ViewType
from mlflow.environment_variables import (
    MLFLOW_ASYNC_LOGGING_THREADPOOL_SIZE,
    MLFLOW_RUN_ID,
)
from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import (
    INVALID_PARAMETER_VALUE,
    RESOURCE_DOES_NOT_EXIST,
    ErrorCode,
)
from mlflow.store.tracking.file_store import FileStore
from mlflow.tracking._tracking_service.client import TrackingServiceClient
from mlflow.tracking.fluent import start_run
from mlflow.utils.file_utils import local_file_uri_to_path
from mlflow.utils.mlflow_tags import (
    MLFLOW_PARENT_RUN_ID,
    MLFLOW_RUN_NAME,
    MLFLOW_SOURCE_NAME,
    MLFLOW_SOURCE_TYPE,
    MLFLOW_USER,
)
from mlflow.utils.os import is_windows
from mlflow.utils.time import get_current_time_millis
from mlflow.utils.validation import (
    MAX_METRICS_PER_BATCH,
    MAX_PARAMS_TAGS_PER_BATCH,
)


class MockExperiment(NamedTuple):
    experiment_id: str
    lifecycle_stage: str
    tags: dict[str, str] = {}


def test_create_experiment():
    with pytest.raises(MlflowException, match="Invalid experiment name"):
        mlflow.create_experiment(None)

    with pytest.raises(MlflowException, match="Invalid experiment name"):
        mlflow.create_experiment("")

    exp_id = mlflow.create_experiment(f"Some random experiment name {random.randint(1, int(1e6))}")
    assert exp_id is not None


def test_create_experiment_with_duplicate_name():
    name = "popular_name"
    exp_id = mlflow.create_experiment(name)

    with pytest.raises(MlflowException, match=re.escape(f"Experiment(name={name}) already exists")):
        mlflow.create_experiment(name)

    tracking.MlflowClient().delete_experiment(exp_id)
    with pytest.raises(MlflowException, match=re.escape(f"Experiment(name={name}) already exists")):
        mlflow.create_experiment(name)


def test_create_experiments_with_bad_names():
    # None for name
    with pytest.raises(MlflowException, match="Invalid experiment name: 'None'"):
        mlflow.create_experiment(None)

    # empty string name
    with pytest.raises(MlflowException, match="Invalid experiment name: ''"):
        mlflow.create_experiment("")


@pytest.mark.parametrize("name", [123, 0, -1.2, [], ["A"], {1: 2}])
def test_create_experiments_with_bad_name_types(name):
    with pytest.raises(
        MlflowException,
        match=re.escape(f"Invalid experiment name: {name}. Expects a string."),
    ):
        mlflow.create_experiment(name)


@pytest.mark.usefixtures("reset_active_experiment")
def test_set_experiment_by_name():
    name = "random_exp"
    exp_id = mlflow.create_experiment(name)
    exp1 = mlflow.set_experiment(name)
    assert exp1.experiment_id == exp_id
    with start_run() as run:
        assert run.info.experiment_id == exp_id

    another_name = "another_experiment"
    exp2 = mlflow.set_experiment(another_name)
    with start_run() as another_run:
        assert another_run.info.experiment_id == exp2.experiment_id


@pytest.mark.usefixtures("reset_active_experiment")
def test_set_experiment_by_id():
    name = "random_exp"
    exp_id = mlflow.create_experiment(name)
    active_exp = mlflow.set_experiment(experiment_id=exp_id)
    assert active_exp.experiment_id == exp_id
    with start_run() as run:
        assert run.info.experiment_id == exp_id

    nonexistent_id = "-1337"
    with pytest.raises(MlflowException, match="No Experiment with id=-1337 exists") as exc:
        mlflow.set_experiment(experiment_id=nonexistent_id)
    assert exc.value.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)
    with start_run() as run:
        assert run.info.experiment_id == exp_id


def test_set_experiment_parameter_validation():
    with pytest.raises(MlflowException, match="Must specify exactly one") as exc:
        mlflow.set_experiment()
    assert exc.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)

    with pytest.raises(MlflowException, match="Must specify exactly one") as exc:
        mlflow.set_experiment(None)
    assert exc.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)

    with pytest.raises(MlflowException, match="Must specify exactly one") as exc:
        mlflow.set_experiment(None, None)
    assert exc.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)

    with pytest.raises(MlflowException, match="Must specify exactly one") as exc:
        mlflow.set_experiment("name", "id")
    assert exc.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)


def test_set_experiment_with_deleted_experiment():
    name = "dead_exp"
    mlflow.set_experiment(name)
    with start_run() as run:
        exp_id = run.info.experiment_id

    tracking.MlflowClient().delete_experiment(exp_id)

    with pytest.raises(MlflowException, match="Cannot set a deleted experiment") as exc:
        mlflow.set_experiment(name)
    assert exc.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)

    with pytest.raises(MlflowException, match="Cannot set a deleted experiment") as exc:
        mlflow.set_experiment(experiment_id=exp_id)
    assert exc.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)


@pytest.mark.usefixtures("reset_active_experiment")
def test_set_experiment_with_zero_id():
    mock_experiment = MockExperiment(experiment_id=0, lifecycle_stage=LifecycleStage.ACTIVE)
    with (
        mock.patch.object(
            TrackingServiceClient,
            "get_experiment_by_name",
            mock.Mock(return_value=mock_experiment),
        ) as get_experiment_by_name_mock,
        mock.patch.object(TrackingServiceClient, "create_experiment") as create_experiment_mock,
    ):
        mlflow.set_experiment("my_exp")
        get_experiment_by_name_mock.assert_called_once()
        create_experiment_mock.assert_not_called()


def test_start_run_context_manager():
    with start_run() as first_run:
        first_uuid = first_run.info.run_id
        # Check that start_run() causes the run information to be persisted in the store
        persisted_run = tracking.MlflowClient().get_run(first_uuid)
        assert persisted_run is not None
        assert persisted_run.info == first_run.info
    finished_run = tracking.MlflowClient().get_run(first_uuid)
    assert finished_run.info.status == RunStatus.to_string(RunStatus.FINISHED)
    # Launch a separate run that fails, verify the run status is FAILED and the run UUID is
    # different
    with pytest.raises(Exception, match="Failing run!"):
        with start_run() as second_run:
            raise Exception("Failing run!")
    second_run_id = second_run.info.run_id
    assert second_run_id != first_uuid
    finished_run2 = tracking.MlflowClient().get_run(second_run_id)
    assert finished_run2.info.status == RunStatus.to_string(RunStatus.FAILED)


def test_start_and_end_run():
    # Use the start_run() and end_run() APIs without a `with` block, verify they work.

    with start_run() as active_run:
        mlflow.log_metric("name_1", 25)
    finished_run = tracking.MlflowClient().get_run(active_run.info.run_id)
    # Validate metrics
    assert len(finished_run.data.metrics) == 1
    assert finished_run.data.metrics["name_1"] == 25


def test_metric_timestamp():
    with mlflow.start_run() as active_run:
        mlflow.log_metric("name_1", 25)
        mlflow.log_metric("name_1", 30)
        run_id = active_run.info.run_id
    # Check that metric timestamps are between run start and finish
    client = MlflowClient()
    history = client.get_metric_history(run_id, "name_1")
    finished_run = client.get_run(run_id)
    assert len(history) == 2
    assert all(
        m.timestamp >= finished_run.info.start_time and m.timestamp <= finished_run.info.end_time
        for m in history
    )


def test_log_batch():
    expected_metrics = {"metric-key0": 1.0, "metric-key1": 4.0}
    expected_params = {
        "param-key0": "param-val0",
        "param-key1": 123,
        "param-key2": None,
    }
    exact_expected_tags = {"tag-key0": "tag-val0", "tag-key1": "tag-val1"}
    approx_expected_tags = {
        MLFLOW_USER,
        MLFLOW_SOURCE_NAME,
        MLFLOW_SOURCE_TYPE,
        MLFLOW_RUN_NAME,
    }

    t = get_current_time_millis()
    sorted_expected_metrics = sorted(expected_metrics.items(), key=lambda kv: kv[0])
    metrics = [
        Metric(key=key, value=value, timestamp=t, step=i)
        for i, (key, value) in enumerate(sorted_expected_metrics)
    ]
    params = [Param(key=key, value=value) for key, value in expected_params.items()]
    tags = [RunTag(key=key, value=value) for key, value in exact_expected_tags.items()]

    with start_run() as active_run:
        run_id = active_run.info.run_id
        MlflowClient().log_batch(run_id=run_id, metrics=metrics, params=params, tags=tags)
    client = tracking.MlflowClient()
    finished_run = client.get_run(run_id)
    # Validate metrics
    assert len(finished_run.data.metrics) == 2
    for key, value in finished_run.data.metrics.items():
        assert expected_metrics[key] == value
    metric_history0 = client.get_metric_history(run_id, "metric-key0")
    assert {(m.value, m.timestamp, m.step) for m in metric_history0} == {(1.0, t, 0)}
    metric_history1 = client.get_metric_history(run_id, "metric-key1")
    assert {(m.value, m.timestamp, m.step) for m in metric_history1} == {(4.0, t, 1)}

    # Validate tags (for automatically-set tags)
    assert len(finished_run.data.tags) == len(exact_expected_tags) + len(approx_expected_tags)
    for tag_key, tag_value in finished_run.data.tags.items():
        if tag_key in approx_expected_tags:
            pass
        else:
            assert exact_expected_tags[tag_key] == tag_value
    # Validate params
    assert finished_run.data.params == {k: str(v) for k, v in expected_params.items()}
    # test that log_batch works with fewer params
    new_tags = {"1": "2", "3": "4", "5": "6"}
    tags = [RunTag(key=key, value=value) for key, value in new_tags.items()]
    client.log_batch(run_id=run_id, tags=tags)
    finished_run_2 = client.get_run(run_id)
    # Validate tags (for automatically-set tags)
    assert len(finished_run_2.data.tags) == len(finished_run.data.tags) + 3
    for tag_key, tag_value in finished_run_2.data.tags.items():
        if tag_key in new_tags:
            assert new_tags[tag_key] == tag_value


def test_log_batch_with_many_elements():
    num_metrics = MAX_METRICS_PER_BATCH * 2
    num_params = num_tags = MAX_PARAMS_TAGS_PER_BATCH * 2
    expected_metrics = {f"metric-key{i}": float(i) for i in range(num_metrics)}
    expected_params = {f"param-key{i}": f"param-val{i}" for i in range(num_params)}
    exact_expected_tags = {f"tag-key{i}": f"tag-val{i}" for i in range(num_tags)}

    t = get_current_time_millis()
    sorted_expected_metrics = sorted(expected_metrics.items(), key=lambda kv: kv[1])
    metrics = [
        Metric(key=key, value=value, timestamp=t, step=i)
        for i, (key, value) in enumerate(sorted_expected_metrics)
    ]
    params = [Param(key=key, value=value) for key, value in expected_params.items()]
    tags = [RunTag(key=key, value=value) for key, value in exact_expected_tags.items()]

    with start_run() as active_run:
        run_id = active_run.info.run_id
        mlflow.tracking.MlflowClient().log_batch(
            run_id=run_id, metrics=metrics, params=params, tags=tags
        )
    client = tracking.MlflowClient()
    finished_run = client.get_run(run_id)
    # Validate metrics
    assert expected_metrics == finished_run.data.metrics
    for i in range(num_metrics):
        metric_history = client.get_metric_history(run_id, f"metric-key{i}")
        assert {(m.value, m.timestamp, m.step) for m in metric_history} == {(float(i), t, i)}

    # Validate tags
    logged_tags = finished_run.data.tags
    for tag_key, tag_value in exact_expected_tags.items():
        assert logged_tags[tag_key] == tag_value

    # Validate params
    assert finished_run.data.params == expected_params


def test_log_metric():
    with start_run() as active_run, mock.patch("time.time", return_value=123):
        run_id = active_run.info.run_id
        mlflow.log_metric("name_1", 25)
        mlflow.log_metric("name_2", -3)
        mlflow.log_metric("name_1", 30, 5)
        mlflow.log_metric("name_1", 40, -2)
        mlflow.log_metric("nested/nested/name", 40)
    finished_run = tracking.MlflowClient().get_run(run_id)
    # Validate metrics
    assert len(finished_run.data.metrics) == 3
    expected_pairs = {"name_1": 30, "name_2": -3, "nested/nested/name": 40}
    for key, value in finished_run.data.metrics.items():
        assert expected_pairs[key] == value
    client = tracking.MlflowClient()
    metric_history_name1 = client.get_metric_history(run_id, "name_1")
    assert {(m.value, m.timestamp, m.step) for m in metric_history_name1} == {
        (25, 123 * 1000, 0),
        (30, 123 * 1000, 5),
        (40, 123 * 1000, -2),
    }
    metric_history_name2 = client.get_metric_history(run_id, "name_2")
    assert {(m.value, m.timestamp, m.step) for m in metric_history_name2} == {(-3, 123 * 1000, 0)}


def test_log_metrics_uses_millisecond_timestamp_resolution_fluent():
    with start_run() as active_run, mock.patch("time.time") as time_mock:
        time_mock.side_effect = lambda: 123
        mlflow.log_metrics({"name_1": 25, "name_2": -3})
        mlflow.log_metrics({"name_1": 30})
        mlflow.log_metrics({"name_1": 40})
        run_id = active_run.info.run_id

    client = tracking.MlflowClient()
    metric_history_name1 = client.get_metric_history(run_id, "name_1")
    assert {(m.value, m.timestamp) for m in metric_history_name1} == {
        (25, 123 * 1000),
        (30, 123 * 1000),
        (40, 123 * 1000),
    }
    metric_history_name2 = client.get_metric_history(run_id, "name_2")
    assert {(m.value, m.timestamp) for m in metric_history_name2} == {(-3, 123 * 1000)}


def test_log_metrics_uses_millisecond_timestamp_resolution_client():
    with start_run() as active_run, mock.patch("time.time") as time_mock:
        time_mock.side_effect = lambda: 123
        mlflow_client = tracking.MlflowClient()
        run_id = active_run.info.run_id

        mlflow_client.log_metric(run_id=run_id, key="name_1", value=25)
        mlflow_client.log_metric(run_id=run_id, key="name_2", value=-3)
        mlflow_client.log_metric(run_id=run_id, key="name_1", value=30)
        mlflow_client.log_metric(run_id=run_id, key="name_1", value=40)

    metric_history_name1 = mlflow_client.get_metric_history(run_id, "name_1")
    assert {(m.value, m.timestamp) for m in metric_history_name1} == {
        (25, 123 * 1000),
        (30, 123 * 1000),
        (40, 123 * 1000),
    }

    metric_history_name2 = mlflow_client.get_metric_history(run_id, "name_2")
    assert {(m.value, m.timestamp) for m in metric_history_name2} == {(-3, 123 * 1000)}


@pytest.mark.parametrize("step_kwarg", [None, -10, 5])
def test_log_metrics_uses_common_timestamp_and_step_per_invocation(step_kwarg):
    expected_metrics = {"name_1": 30, "name_2": -3, "nested/nested/name": 40}
    with start_run() as active_run:
        run_id = active_run.info.run_id
        mlflow.log_metrics(expected_metrics, step=step_kwarg)
    finished_run = tracking.MlflowClient().get_run(run_id)
    # Validate metric key/values match what we expect, and that all metrics have the same timestamp
    assert len(finished_run.data.metrics) == len(expected_metrics)
    for key, value in finished_run.data.metrics.items():
        assert expected_metrics[key] == value
    common_timestamp = finished_run.data._metric_objs[0].timestamp
    expected_step = step_kwarg if step_kwarg is not None else 0
    for metric_obj in finished_run.data._metric_objs:
        assert metric_obj.timestamp == common_timestamp
        assert metric_obj.step == expected_step


@pytest.fixture
def get_store_mock():
    with mock.patch("mlflow.store.file_store.FileStore.log_batch") as _get_store_mock:
        yield _get_store_mock


def test_set_tags():
    exact_expected_tags = {"name_1": "c", "name_2": "b", "nested/nested/name": 5}
    approx_expected_tags = {
        MLFLOW_USER,
        MLFLOW_SOURCE_NAME,
        MLFLOW_SOURCE_TYPE,
        MLFLOW_RUN_NAME,
    }
    with start_run() as active_run:
        run_id = active_run.info.run_id
        mlflow.set_tags(exact_expected_tags)
    finished_run = tracking.MlflowClient().get_run(run_id)
    # Validate tags
    assert len(finished_run.data.tags) == len(exact_expected_tags) + len(approx_expected_tags)
    for tag_key, tag_val in finished_run.data.tags.items():
        if tag_key in approx_expected_tags:
            pass
        else:
            assert str(exact_expected_tags[tag_key]) == tag_val


def test_log_metric_validation():
    with start_run() as active_run:
        run_id = active_run.info.run_id
        with pytest.raises(
            MlflowException,
            match="Invalid value \"apple\" for parameter 'value' supplied",
        ) as e:
            mlflow.log_metric("name_1", "apple")
        assert e.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
    finished_run = tracking.MlflowClient().get_run(run_id)
    assert len(finished_run.data.metrics) == 0


def test_log_param():
    with start_run() as active_run:
        run_id = active_run.info.run_id
        assert mlflow.log_param("name_1", "a") == "a"
        assert mlflow.log_param("name_2", "b") == "b"
        assert mlflow.log_param("nested/nested/name", 5) == 5
    finished_run = tracking.MlflowClient().get_run(run_id)
    # Validate params
    assert finished_run.data.params == {
        "name_1": "a",
        "name_2": "b",
        "nested/nested/name": "5",
    }


def test_log_params():
    expected_params = {"name_1": "c", "name_2": "b", "nested/nested/name": 5}
    with start_run() as active_run:
        run_id = active_run.info.run_id
        mlflow.log_params(expected_params)
    finished_run = tracking.MlflowClient().get_run(run_id)
    # Validate params
    assert finished_run.data.params == {
        "name_1": "c",
        "name_2": "b",
        "nested/nested/name": "5",
    }


def test_log_params_duplicate_keys_raises():
    params = {"a": "1", "b": "2"}
    with start_run() as active_run:
        run_id = active_run.info.run_id
        mlflow.log_params(params)
        with pytest.raises(
            expected_exception=MlflowException,
            match=r"Changing param values is not allowed. Param with key=",
        ) as e:
            mlflow.log_param("a", "3")
        assert e.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
    finished_run = tracking.MlflowClient().get_run(run_id)
    assert finished_run.data.params == params


@pytest.mark.skipif(is_windows(), reason="Windows does not support colons in params and metrics")
def test_param_metric_with_colon():
    with start_run() as active_run:
        run_id = active_run.info.run_id
        mlflow.log_param("a:b", 3)
        mlflow.log_metric("c:d", 4)
    finished_run = tracking.MlflowClient().get_run(run_id)

    # Validate param
    assert len(finished_run.data.params) == 1
    assert finished_run.data.params == {"a:b": "3"}

    # Validate metric
    assert len(finished_run.data.metrics) == 1
    assert finished_run.data.metrics["c:d"] == 4


def test_log_batch_duplicate_entries_raises():
    with start_run() as active_run:
        run_id = active_run.info.run_id
        with pytest.raises(
            MlflowException, match=r"Duplicate parameter keys have been submitted."
        ) as e:
            tracking.MlflowClient().log_batch(
                run_id=run_id, params=[Param("a", "1"), Param("a", "2")]
            )
        assert e.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)


def test_log_batch_validates_entity_names_and_values():
    with start_run() as active_run:
        run_id = active_run.info.run_id

        metrics = [Metric(key="../bad/metric/name", value=0.3, timestamp=3, step=0)]
        with pytest.raises(
            MlflowException,
            match=r"Invalid value \"../bad/metric/name\" for parameter \'metrics\[0\].name\'",
        ) as e:
            tracking.MlflowClient().log_batch(run_id, metrics=metrics)
        assert e.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)

        metrics = [Metric(key="ok-name", value="non-numerical-value", timestamp=3, step=0)]
        with pytest.raises(
            MlflowException,
            match=r"Invalid value \"non-numerical-value\" "
            + r"for parameter \'metrics\[0\].value\' supplied",
        ) as e:
            tracking.MlflowClient().log_batch(run_id, metrics=metrics)
        assert e.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)

        metrics = [Metric(key="ok-name", value=0.3, timestamp="non-numerical-timestamp", step=0)]
        with pytest.raises(
            MlflowException,
            match=r"Invalid value \"non-numerical-timestamp\" for "
            + r"parameter \'metrics\[0\].timestamp\' supplied",
        ) as e:
            tracking.MlflowClient().log_batch(run_id, metrics=metrics)
        assert e.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)

        params = [Param(key="../bad/param/name", value="my-val")]
        with pytest.raises(
            MlflowException,
            match=r"Invalid value \"../bad/param/name\" for parameter \'params\[0\].key\' supplied",
        ) as e:
            tracking.MlflowClient().log_batch(run_id, params=params)
        assert e.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)

        tags = [Param(key="../bad/tag/name", value="my-val")]
        with pytest.raises(
            MlflowException,
            match=r"Invalid value \"../bad/tag/name\" for parameter \'tags\[0\].key\' supplied",
        ) as e:
            tracking.MlflowClient().log_batch(run_id, tags=tags)
        assert e.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)

        metrics = [Metric(key=None, value=42.0, timestamp=4, step=1)]
        with pytest.raises(
            MlflowException,
            match="Metric name cannot be None. A key name must be provided.",
        ) as e:
            tracking.MlflowClient().log_batch(run_id, metrics=metrics)
        assert e.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)


def test_log_artifact_with_dirs(tmp_path):
    # Test log artifact with a directory
    art_dir = tmp_path / "parent"
    art_dir.mkdir()
    file0 = art_dir.joinpath("file0")
    file0.write_text("something")
    file1 = art_dir.joinpath("file1")
    file1.write_text("something")
    sub_dir = art_dir / "child"
    sub_dir.mkdir()
    with start_run():
        artifact_uri = mlflow.get_artifact_uri()
        run_artifact_dir = local_file_uri_to_path(artifact_uri)
        mlflow.log_artifact(str(art_dir))
        base = os.path.basename(str(art_dir))
        assert os.listdir(run_artifact_dir) == [base]
        assert set(os.listdir(os.path.join(run_artifact_dir, base))) == {
            "child",
            "file0",
            "file1",
        }
        with open(os.path.join(run_artifact_dir, base, "file0")) as f:
            assert f.read() == "something"
    # Test log artifact with directory and specified parent folder

    art_dir = tmp_path / "dir"
    art_dir.mkdir()
    with start_run():
        artifact_uri = mlflow.get_artifact_uri()
        run_artifact_dir = local_file_uri_to_path(artifact_uri)
        mlflow.log_artifact(str(art_dir), "some_parent")
        assert os.listdir(run_artifact_dir) == [os.path.basename("some_parent")]
        assert os.listdir(os.path.join(run_artifact_dir, "some_parent")) == [
            os.path.basename(str(art_dir))
        ]
    sub_dir = art_dir.joinpath("another_dir")
    sub_dir.mkdir()
    with start_run():
        artifact_uri = mlflow.get_artifact_uri()
        run_artifact_dir = local_file_uri_to_path(artifact_uri)
        mlflow.log_artifact(str(art_dir), "parent/and_child")
        assert os.listdir(os.path.join(run_artifact_dir, "parent", "and_child")) == [
            os.path.basename(str(art_dir))
        ]
        assert set(
            os.listdir(
                os.path.join(
                    run_artifact_dir,
                    "parent",
                    "and_child",
                    os.path.basename(str(art_dir)),
                )
            )
        ) == {os.path.basename(str(sub_dir))}


def test_log_artifact(tmp_path):
    # Create artifacts
    artifact_dir = tmp_path.joinpath("artifact_dir")
    artifact_dir.mkdir()
    path0 = artifact_dir.joinpath("file0")
    path1 = artifact_dir.joinpath("file1")
    path0.write_text("0")
    path1.write_text("1")
    # Log an artifact, verify it exists in the directory returned by get_artifact_uri
    # after the run finishes
    artifact_parent_dirs = ["some_parent_dir", None]
    for parent_dir in artifact_parent_dirs:
        with start_run():
            artifact_uri = mlflow.get_artifact_uri()
            run_artifact_dir = local_file_uri_to_path(artifact_uri)
            mlflow.log_artifact(path0, parent_dir)
        expected_dir = (
            os.path.join(run_artifact_dir, parent_dir)
            if parent_dir is not None
            else run_artifact_dir
        )
        assert os.listdir(expected_dir) == [os.path.basename(path0)]
        logged_artifact_path = os.path.join(expected_dir, path0)
        assert filecmp.cmp(logged_artifact_path, path0, shallow=False)
    # Log multiple artifacts, verify they exist in the directory returned by get_artifact_uri
    for parent_dir in artifact_parent_dirs:
        with start_run():
            artifact_uri = mlflow.get_artifact_uri()
            run_artifact_dir = local_file_uri_to_path(artifact_uri)

            mlflow.log_artifacts(artifact_dir, parent_dir)
        # Check that the logged artifacts match
        expected_artifact_output_dir = (
            os.path.join(run_artifact_dir, parent_dir)
            if parent_dir is not None
            else run_artifact_dir
        )
        dir_comparison = filecmp.dircmp(artifact_dir, expected_artifact_output_dir)
        assert len(dir_comparison.left_only) == 0
        assert len(dir_comparison.right_only) == 0
        assert len(dir_comparison.diff_files) == 0
        assert len(dir_comparison.funny_files) == 0


@pytest.mark.parametrize("subdir", [None, ".", "dir", "dir1/dir2", "dir/.."])
def test_log_text(subdir):
    filename = "file.txt"
    text = "a"
    artifact_file = filename if subdir is None else posixpath.join(subdir, filename)

    with mlflow.start_run():
        mlflow.log_text(text, artifact_file)

        artifact_path = None if subdir is None else posixpath.normpath(subdir)
        artifact_uri = mlflow.get_artifact_uri(artifact_path)
        run_artifact_dir = local_file_uri_to_path(artifact_uri)
        assert os.listdir(run_artifact_dir) == [filename]

        filepath = os.path.join(run_artifact_dir, filename)
        with open(filepath) as f:
            assert f.read() == text


@pytest.mark.parametrize("subdir", [None, ".", "dir", "dir1/dir2", "dir/.."])
@pytest.mark.parametrize("extension", [".json", ".yml", ".yaml", ".txt", ""])
def test_log_dict(subdir, extension):
    dictionary = {"k": "v"}
    filename = "data" + extension
    artifact_file = filename if subdir is None else posixpath.join(subdir, filename)

    with mlflow.start_run():
        mlflow.log_dict(dictionary, artifact_file)

        artifact_path = None if subdir is None else posixpath.normpath(subdir)
        artifact_uri = mlflow.get_artifact_uri(artifact_path)
        run_artifact_dir = local_file_uri_to_path(artifact_uri)
        assert os.listdir(run_artifact_dir) == [filename]

        filepath = os.path.join(run_artifact_dir, filename)
        extension = os.path.splitext(filename)[1]
        with open(filepath) as f:
            loaded = (
                # Specify `Loader` to suppress the following deprecation warning:
                # https://github.com/yaml/pyyaml/wiki/PyYAML-yaml.load(input)-Deprecation
                yaml.load(f, Loader=yaml.SafeLoader)
                if (extension in [".yml", ".yaml"])
                else json.load(f)
            )
        assert loaded == dictionary


@pytest.mark.parametrize("subdir", [None, ".", "dir", "dir1/dir2", "dir/.."])
def test_log_stream_bytes(subdir):
    filename = "file.bin"
    content = b"binary content"
    artifact_file = filename if subdir is None else posixpath.join(subdir, filename)

    with mlflow.start_run():
        stream = io.BytesIO(content)
        mlflow.log_stream(stream, artifact_file)

        artifact_path = None if subdir is None else posixpath.normpath(subdir)
        artifact_uri = mlflow.get_artifact_uri(artifact_path)
        run_artifact_dir = pathlib.Path(local_file_uri_to_path(artifact_uri))
        assert list(run_artifact_dir.iterdir()) == [run_artifact_dir / filename]
        assert (run_artifact_dir / filename).read_bytes() == content


def test_log_stream_empty():
    with mlflow.start_run():
        artifact_uri = mlflow.get_artifact_uri()
        run_artifact_dir = pathlib.Path(local_file_uri_to_path(artifact_uri))

        stream = io.BytesIO(b"")
        mlflow.log_stream(stream, "empty.bin")
        assert (run_artifact_dir / "empty.bin").read_bytes() == b""


def test_log_stream_large_content():
    with mlflow.start_run():
        # Large binary content (larger than chunk size of 8192)
        large_content = b"x" * 100000
        stream = io.BytesIO(large_content)
        mlflow.log_stream(stream, "large.bin")

        artifact_uri = mlflow.get_artifact_uri()
        run_artifact_dir = pathlib.Path(local_file_uri_to_path(artifact_uri))
        assert (run_artifact_dir / "large.bin").read_bytes() == large_content


def test_with_startrun():
    run_id = None
    t0 = get_current_time_millis()
    with mlflow.start_run() as active_run:
        assert mlflow.active_run() == active_run
        run_id = active_run.info.run_id
    t1 = get_current_time_millis()
    run_info = mlflow.tracking._get_store().get_run(run_id).info
    assert run_info.status == "FINISHED"
    assert t0 <= run_info.end_time
    assert run_info.end_time <= t1
    assert mlflow.active_run() is None


def test_parent_create_run(monkeypatch):
    with mlflow.start_run() as parent_run:
        parent_run_id = parent_run.info.run_id
    monkeypatch.setenv(MLFLOW_RUN_ID.name, parent_run_id)
    with mlflow.start_run() as parent_run:
        assert parent_run.info.run_id == parent_run_id
        with pytest.raises(Exception, match="To start a nested run"):
            mlflow.start_run()
        with mlflow.start_run(nested=True) as child_run:
            assert child_run.info.run_id != parent_run_id
            with mlflow.start_run(nested=True) as grand_child_run:
                pass

    def verify_has_parent_id_tag(child_id, expected_parent_id):
        tags = tracking.MlflowClient().get_run(child_id).data.tags
        assert tags[MLFLOW_PARENT_RUN_ID] == expected_parent_id

    verify_has_parent_id_tag(child_run.info.run_id, parent_run.info.run_id)
    verify_has_parent_id_tag(grand_child_run.info.run_id, child_run.info.run_id)
    assert mlflow.active_run() is None


def test_start_deleted_run():
    run_id = None
    with mlflow.start_run() as active_run:
        run_id = active_run.info.run_id
    tracking.MlflowClient().delete_run(run_id)
    with pytest.raises(MlflowException, match="because it is in the deleted state."):
        with mlflow.start_run(run_id=run_id):
            pass
    assert mlflow.active_run() is None


@pytest.mark.usefixtures("reset_active_experiment")
def test_start_run_exp_id_0():
    mlflow.set_experiment("some-experiment")
    # Create a run and verify that the current active experiment is the one we just set
    with mlflow.start_run() as active_run:
        exp_id = active_run.info.experiment_id
        assert exp_id != FileStore.DEFAULT_EXPERIMENT_ID
        assert MlflowClient().get_experiment(exp_id).name == "some-experiment"
    # Set experiment ID to 0 when creating a run, verify that the specified experiment ID is honored
    with mlflow.start_run(experiment_id=0) as active_run:
        assert active_run.info.experiment_id == FileStore.DEFAULT_EXPERIMENT_ID


def test_get_artifact_uri_with_artifact_path_unspecified_returns_artifact_root_dir():
    with mlflow.start_run() as active_run:
        assert mlflow.get_artifact_uri(artifact_path=None) == active_run.info.artifact_uri


def test_get_artifact_uri_uses_currently_active_run_id():
    artifact_path = "artifact"
    with mlflow.start_run() as active_run:
        assert mlflow.get_artifact_uri(
            artifact_path=artifact_path
        ) == tracking.artifact_utils.get_artifact_uri(
            run_id=active_run.info.run_id, artifact_path=artifact_path
        )


def _assert_get_artifact_uri_appends_to_uri_path_component_correctly(
    artifact_location, expected_uri_format
):
    client = MlflowClient()
    client.create_experiment("get-artifact-uri-test", artifact_location=artifact_location)
    mlflow.set_experiment("get-artifact-uri-test")
    with mlflow.start_run():
        run_id = mlflow.active_run().info.run_id
        for artifact_path in ["path/to/artifact", "/artifact/path", "arty.txt"]:
            artifact_uri = mlflow.get_artifact_uri(artifact_path)
            assert artifact_uri == tracking.artifact_utils.get_artifact_uri(run_id, artifact_path)
            assert artifact_uri == expected_uri_format.format(
                run_id=run_id,
                path=artifact_path.lstrip("/"),
                drive=pathlib.Path.cwd().drive,
            )


@pytest.mark.parametrize(
    ("artifact_location", "expected_uri_format"),
    [
        (
            "mysql://user:password@host:port/dbname?driver=mydriver",
            "mysql://user:password@host:port/dbname/{run_id}/artifacts/{path}?driver=mydriver",
        ),
        (
            "mysql+driver://user:pass@host:port/dbname/subpath/#fragment",
            "mysql+driver://user:pass@host:port/dbname/subpath/{run_id}/artifacts/{path}#fragment",
        ),
        (
            "s3://bucketname/rootpath",
            "s3://bucketname/rootpath/{run_id}/artifacts/{path}",
        ),
    ],
)
def test_get_artifact_uri_appends_to_uri_path_component_correctly(
    artifact_location, expected_uri_format
):
    _assert_get_artifact_uri_appends_to_uri_path_component_correctly(
        artifact_location, expected_uri_format
    )


@pytest.mark.skipif(not is_windows(), reason="This test only passes on Windows")
def test_get_artifact_uri_appends_to_local_path_component_correctly_on_windows():
    _assert_get_artifact_uri_appends_to_uri_path_component_correctly(
        "/dirname/rootpa#th?",
        "file:///{drive}/dirname/rootpa/{run_id}/artifacts/{path}#th?",
    )


@pytest.mark.skipif(is_windows(), reason="This test fails on Windows")
def test_get_artifact_uri_appends_to_local_path_component_correctly():
    _assert_get_artifact_uri_appends_to_uri_path_component_correctly(
        "/dirname/rootpa#th?", "{drive}/dirname/rootpa#th?/{run_id}/artifacts/{path}"
    )


@pytest.mark.usefixtures("reset_active_experiment")
def test_search_runs():
    mlflow.set_experiment("exp-for-search")
    # Create a run and verify that the current active experiment is the one we just set
    logged_runs = {}
    with mlflow.start_run() as active_run:
        logged_runs["first"] = active_run.info.run_id
        mlflow.log_metric("m1", 0.001)
        mlflow.log_metric("m2", 0.002)
        mlflow.log_metric("m1", 0.002)
        mlflow.log_param("p1", "a")
        mlflow.set_tag("t1", "first-tag-val")
    with mlflow.start_run() as active_run:
        logged_runs["second"] = active_run.info.run_id
        mlflow.log_metric("m1", 0.008)
        mlflow.log_param("p2", "aa")
        mlflow.set_tag("t2", "second-tag-val")

    def verify_runs(runs, expected_set):
        assert {r.info.run_id for r in runs} == {logged_runs[r] for r in expected_set}

    experiment_id = MlflowClient().get_experiment_by_name("exp-for-search").experiment_id

    # 2 runs in this experiment
    assert len(MlflowClient().search_runs([experiment_id], run_view_type=ViewType.ACTIVE_ONLY)) == 2

    # 2 runs that have metric "m1" > 0.001
    runs = MlflowClient().search_runs([experiment_id], "metrics.m1 > 0.0001")
    verify_runs(runs, ["first", "second"])

    # 1 run that has metric "m1" > 0.002
    runs = MlflowClient().search_runs([experiment_id], "metrics.m1 > 0.002")
    verify_runs(runs, ["second"])

    # no runs with metric "m1" > 0.1
    runs = MlflowClient().search_runs([experiment_id], "metrics.m1 > 0.1")
    verify_runs(runs, [])

    # 1 run with metric "m2" > 0
    runs = MlflowClient().search_runs([experiment_id], "metrics.m2 > 0")
    verify_runs(runs, ["first"])

    # 1 run each with param "p1" and "p2"
    runs = MlflowClient().search_runs([experiment_id], "params.p1 = 'a'", ViewType.ALL)
    verify_runs(runs, ["first"])
    runs = MlflowClient().search_runs([experiment_id], "params.p2 != 'a'", ViewType.ALL)
    verify_runs(runs, ["second"])
    runs = MlflowClient().search_runs([experiment_id], "params.p2 = 'aa'", ViewType.ALL)
    verify_runs(runs, ["second"])

    # 1 run each with tag "t1" and "t2"
    runs = MlflowClient().search_runs([experiment_id], "tags.t1 = 'first-tag-val'", ViewType.ALL)
    verify_runs(runs, ["first"])
    runs = MlflowClient().search_runs([experiment_id], "tags.t2 != 'qwerty'", ViewType.ALL)
    verify_runs(runs, ["second"])
    runs = MlflowClient().search_runs([experiment_id], "tags.t2 = 'second-tag-val'", ViewType.ALL)
    verify_runs(runs, ["second"])

    # delete "first" run
    MlflowClient().delete_run(logged_runs["first"])
    runs = MlflowClient().search_runs([experiment_id], "params.p1 = 'a'", ViewType.ALL)
    verify_runs(runs, ["first"])

    runs = MlflowClient().search_runs([experiment_id], "params.p1 = 'a'", ViewType.DELETED_ONLY)
    verify_runs(runs, ["first"])

    runs = MlflowClient().search_runs([experiment_id], "params.p1 = 'a'", ViewType.ACTIVE_ONLY)
    verify_runs(runs, [])


@pytest.mark.usefixtures("reset_active_experiment")
def test_search_runs_multiple_experiments():
    experiment_ids = [mlflow.create_experiment(f"exp__{exp_id}") for exp_id in range(1, 4)]
    for eid in experiment_ids:
        with mlflow.start_run(experiment_id=eid):
            mlflow.log_metric("m0", 1)
            mlflow.log_metric(f"m_{eid}", 2)

    assert len(MlflowClient().search_runs(experiment_ids, "metrics.m0 > 0", ViewType.ALL)) == 3

    assert len(MlflowClient().search_runs(experiment_ids, "metrics.m_1 > 0", ViewType.ALL)) == 1
    assert len(MlflowClient().search_runs(experiment_ids, "metrics.m_2 = 2", ViewType.ALL)) == 1
    assert len(MlflowClient().search_runs(experiment_ids, "metrics.m_3 < 4", ViewType.ALL)) == 1


def read_data(artifact_path):
    import pandas as pd

    if artifact_path.endswith(".json"):
        return pd.read_json(artifact_path, orient="split")
    if artifact_path.endswith(".parquet"):
        return pd.read_parquet(artifact_path)
    raise ValueError(f"Unsupported file type in {artifact_path}. Expected .json or .parquet")


@pytest.mark.skipif(
    "MLFLOW_SKINNY" in os.environ,
    reason="Skinny client does not support the np or pandas dependencies",
)
@pytest.mark.parametrize("file_type", ["json", "parquet"])
def test_log_table(file_type):
    import pandas as pd

    table_dict = {
        "inputs": ["What is MLflow?", "What is Databricks?"],
        "outputs": ["MLflow is ...", "Databricks is ..."],
        "toxicity": [0.0, 0.0],
    }
    artifact_file = f"qabot_eval_results.{file_type}"
    TAG_NAME = "mlflow.loggedArtifacts"
    run_id = None

    with pytest.raises(
        MlflowException, match="data must be a pandas.DataFrame or a dictionary"
    ) as e:
        with mlflow.start_run() as run:
            # Log the incorrect data format as a table
            mlflow.log_table(data="incorrect-data-format", artifact_file=artifact_file)
    assert e.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)

    with mlflow.start_run() as run:
        # Log the dictionary as a table
        mlflow.log_table(data=table_dict, artifact_file=artifact_file)
        run_id = run.info.run_id

    run = mlflow.get_run(run_id)
    artifact_path = mlflow.artifacts.download_artifacts(run_id=run_id, artifact_path=artifact_file)
    table_data = read_data(artifact_path)
    assert table_data.shape[0] == 2
    assert table_data.shape[1] == 3

    # Get the current value of the tag
    current_tag_value = json.loads(run.data.tags.get(TAG_NAME, "[]"))
    assert {"path": artifact_file, "type": "table"} in current_tag_value
    assert len(current_tag_value) == 1

    table_df = pd.DataFrame.from_dict(table_dict)
    with mlflow.start_run(run_id=run_id):
        # Log the dataframe as a table
        mlflow.log_table(data=table_df, artifact_file=artifact_file)

    run = mlflow.get_run(run_id)
    artifact_path = mlflow.artifacts.download_artifacts(run_id=run_id, artifact_path=artifact_file)
    table_data = read_data(artifact_path)
    assert table_data.shape[0] == 4
    assert table_data.shape[1] == 3
    # Get the current value of the tag
    current_tag_value = json.loads(run.data.tags.get(TAG_NAME, "[]"))
    assert {"path": artifact_file, "type": "table"} in current_tag_value
    assert len(current_tag_value) == 1

    artifact_file_new = f"qabot_eval_results_new.{file_type}"
    with mlflow.start_run(run_id=run_id):
        # Log the dataframe as a table to new artifact file
        mlflow.log_table(data=table_df, artifact_file=artifact_file_new)

    run = mlflow.get_run(run_id)
    artifact_path = mlflow.artifacts.download_artifacts(
        run_id=run_id, artifact_path=artifact_file_new
    )
    table_data = read_data(artifact_path)
    assert table_data.shape[0] == 2
    assert table_data.shape[1] == 3
    # Get the current value of the tag
    current_tag_value = json.loads(run.data.tags.get(TAG_NAME, "[]"))
    assert {"path": artifact_file_new, "type": "table"} in current_tag_value
    assert len(current_tag_value) == 2


@pytest.mark.skipif(
    "MLFLOW_SKINNY" in os.environ,
    reason="Skinny client does not support the np or pandas dependencies",
)
@pytest.mark.parametrize("file_type", ["json", "parquet"])
def test_log_table_with_subdirectory(file_type):
    import pandas as pd

    table_dict = {
        "inputs": ["What is MLflow?", "What is Databricks?"],
        "outputs": ["MLflow is ...", "Databricks is ..."],
        "toxicity": [0.0, 0.0],
    }
    artifact_file = f"dir/foo.{file_type}"
    TAG_NAME = "mlflow.loggedArtifacts"
    run_id = None

    with mlflow.start_run() as run:
        # Log the dictionary as a table
        mlflow.log_table(data=table_dict, artifact_file=artifact_file)
        run_id = run.info.run_id

    run = mlflow.get_run(run_id)
    artifact_path = mlflow.artifacts.download_artifacts(run_id=run_id, artifact_path=artifact_file)
    table_data = read_data(artifact_path)
    assert table_data.shape[0] == 2
    assert table_data.shape[1] == 3

    # Get the current value of the tag
    current_tag_value = json.loads(run.data.tags.get(TAG_NAME, "[]"))
    assert {"path": artifact_file, "type": "table"} in current_tag_value
    assert len(current_tag_value) == 1

    table_df = pd.DataFrame.from_dict(table_dict)
    with mlflow.start_run(run_id=run_id):
        # Log the dataframe as a table
        mlflow.log_table(data=table_df, artifact_file=artifact_file)

    run = mlflow.get_run(run_id)
    artifact_path = mlflow.artifacts.download_artifacts(run_id=run_id, artifact_path=artifact_file)
    table_data = read_data(artifact_path)
    assert table_data.shape[0] == 4
    assert table_data.shape[1] == 3
    # Get the current value of the tag
    current_tag_value = json.loads(run.data.tags.get(TAG_NAME, "[]"))
    assert {"path": artifact_file, "type": "table"} in current_tag_value
    assert len(current_tag_value) == 1


@pytest.mark.skipif(
    "MLFLOW_SKINNY" in os.environ,
    reason="Skinny client does not support the np or pandas dependencies",
)
@pytest.mark.parametrize("file_type", ["json", "parquet"])
def test_load_table(file_type):
    table_dict = {
        "inputs": ["What is MLflow?", "What is Databricks?"],
        "outputs": ["MLflow is ...", "Databricks is ..."],
        "toxicity": [0.0, 0.0],
    }
    artifact_file = f"qabot_eval_results.{file_type}"
    artifact_file_2 = f"qabot_eval_results_2.{file_type}"
    run_id_2 = None

    with mlflow.start_run() as run:
        # Log the dictionary as a table
        mlflow.log_table(data=table_dict, artifact_file=artifact_file)
        mlflow.log_table(data=table_dict, artifact_file=artifact_file_2)

    with mlflow.start_run() as run:
        # Log the dictionary as a table
        mlflow.log_table(data=table_dict, artifact_file=artifact_file)
        run_id_2 = run.info.run_id

    with mlflow.start_run() as run:
        # Log the dictionary as a table
        mlflow.log_table(data=table_dict, artifact_file=artifact_file)
        run_id_3 = run.info.run_id

    extra_columns = ["run_id", "tags.mlflow.loggedArtifacts"]

    # test 1: load table with extra columns
    output_df = mlflow.load_table(artifact_file=artifact_file, extra_columns=extra_columns)

    assert output_df.shape[0] == 6
    assert output_df.shape[1] == 5
    assert output_df["run_id"].nunique() == 3
    assert output_df["tags.mlflow.loggedArtifacts"].nunique() == 2

    # test 2: load table with extra columns and single run_id
    output_df = mlflow.load_table(
        artifact_file=artifact_file, run_ids=[run_id_2], extra_columns=extra_columns
    )

    assert output_df.shape[0] == 2
    assert output_df.shape[1] == 5
    assert output_df["run_id"].nunique() == 1
    assert output_df["tags.mlflow.loggedArtifacts"].nunique() == 1

    # test 3: load table with extra columns and multiple run_ids
    output_df = mlflow.load_table(
        artifact_file=artifact_file,
        run_ids=[run_id_2, run_id_3],
        extra_columns=extra_columns,
    )

    assert output_df.shape[0] == 4
    assert output_df.shape[1] == 5
    assert output_df["run_id"].nunique() == 2
    assert output_df["tags.mlflow.loggedArtifacts"].nunique() == 1

    # test 4: load table with no extra columns and run_ids specified but different artifact file
    output_df = mlflow.load_table(artifact_file=artifact_file_2)
    import pandas as pd

    pd.testing.assert_frame_equal(output_df, pd.DataFrame(table_dict), check_dtype=False)

    # test 5: load table with no extra columns and run_ids specified
    output_df = mlflow.load_table(artifact_file=artifact_file)

    assert output_df.shape[0] == 6
    assert output_df.shape[1] == 3

    # test 6: load table with no matching results found. Error case
    with pytest.raises(
        MlflowException, match="No runs found with the corresponding table artifact"
    ):
        mlflow.load_table(artifact_file=f"error_case.{file_type}")

    # test 7: load table with no matching extra_column found. Error case
    with pytest.raises(KeyError, match="error_column"):
        mlflow.load_table(artifact_file=artifact_file, extra_columns=["error_column"])


@pytest.mark.skipif(
    "MLFLOW_SKINNY" in os.environ,
    reason="Skinny client does not support the np or pandas dependencies",
)
@pytest.mark.parametrize("file_type", ["json", "parquet"])
def test_log_table_with_datetime_columns(file_type):
    import pandas as pd

    start_time = str(datetime.now(timezone.utc))
    table_dict = {
        "inputs": ["What is MLflow?", "What is Databricks?"],
        "outputs": ["MLflow is ...", "Databricks is ..."],
        "start_time": [start_time, start_time],
    }
    artifact_file = f"test_time.{file_type}"

    with mlflow.start_run() as run:
        # Log the dictionary as a table
        mlflow.log_table(data=table_dict, artifact_file=artifact_file)
        run_id = run.info.run_id

    artifact_path = mlflow.artifacts.download_artifacts(run_id=run_id, artifact_path=artifact_file)
    if file_type == "parquet":
        table_data = pd.read_parquet(artifact_path)
    else:
        table_data = pd.read_json(artifact_path, orient="split", convert_dates=False)
    assert table_data["start_time"][0] == start_time

    # append the same table to the same artifact file
    mlflow.log_table(data=table_dict, artifact_file=artifact_file, run_id=run_id)
    artifact_path = mlflow.artifacts.download_artifacts(run_id=run_id, artifact_path=artifact_file)
    if file_type == "parquet":
        df = pd.read_parquet(artifact_path)
    else:
        df = pd.read_json(artifact_path, orient="split", convert_dates=False)
    assert df["start_time"][2] == start_time


@pytest.mark.skipif(
    "MLFLOW_SKINNY" in os.environ,
    reason="Skinny client does not support the np or pandas dependencies",
)
@pytest.mark.parametrize("file_type", ["json", "parquet"])
def test_log_table_with_image_columns(file_type):
    import numpy as np
    from PIL import Image

    image = mlflow.Image([[1, 2, 3]])
    table_dict = {
        "inputs": ["What is MLflow?", "What is Databricks?"],
        "outputs": ["MLflow is ...", "Databricks is ..."],
        "image": [image, image],
    }
    artifact_file = f"test_time.{file_type}"

    with mlflow.start_run() as run:
        # Log the dictionary as a table
        mlflow.log_table(data=table_dict, artifact_file=artifact_file)
        run_id = run.info.run_id

    artifact_path = mlflow.artifacts.download_artifacts(run_id=run_id, artifact_path=artifact_file)
    table_data = read_data(artifact_path)
    assert table_data["image"][0]["type"] == "image"
    image_path = mlflow.artifacts.download_artifacts(
        run_id=run_id, artifact_path=table_data["image"][0]["filepath"]
    )
    image2 = Image.open(image_path)
    assert np.abs(image.to_array() - np.array(image2)).sum() == 0

    # append the same table to the same artifact file
    mlflow.log_table(data=table_dict, artifact_file=artifact_file, run_id=run_id)
    artifact_path = mlflow.artifacts.download_artifacts(run_id=run_id, artifact_path=artifact_file)
    df = read_data(artifact_path)
    assert df["image"][2]["type"] == "image"


@pytest.mark.skipif(
    "MLFLOW_SKINNY" in os.environ,
    reason="Skinny client does not support the np or pandas dependencies",
)
@pytest.mark.parametrize("file_type", ["json", "parquet"])
def test_log_table_with_pil_image_columns(file_type):
    import numpy as np
    from PIL import Image

    image = Image.fromarray(np.array([[1.0, 2.0, 3.0]]))
    image = image.convert("RGB")

    table_dict = {
        "inputs": ["What is MLflow?", "What is Databricks?"],
        "outputs": ["MLflow is ...", "Databricks is ..."],
        "image": [image, image],
    }
    artifact_file = f"test_time.{file_type}"

    with mlflow.start_run() as run:
        # Log the dictionary as a table
        mlflow.log_table(data=table_dict, artifact_file=artifact_file)
        run_id = run.info.run_id

    artifact_path = mlflow.artifacts.download_artifacts(run_id=run_id, artifact_path=artifact_file)
    table_data = read_data(artifact_path)
    assert table_data["image"][0]["type"] == "image"
    image_path = mlflow.artifacts.download_artifacts(
        run_id=run_id, artifact_path=table_data["image"][0]["filepath"]
    )
    image2 = Image.open(image_path)
    assert np.abs(np.array(image) - np.array(image2)).sum() == 0

    # append the same table to the same artifact file
    mlflow.log_table(data=table_dict, artifact_file=artifact_file, run_id=run_id)
    artifact_path = mlflow.artifacts.download_artifacts(run_id=run_id, artifact_path=artifact_file)
    df = read_data(artifact_path)
    assert df["image"][2]["type"] == "image"


@pytest.mark.skipif(
    "MLFLOW_SKINNY" in os.environ,
    reason="Skinny client does not support the np or pandas dependencies",
)
@pytest.mark.parametrize("file_type", ["json", "parquet"])
def test_log_table_with_invalid_image_columns(file_type):
    image = mlflow.Image([[1, 2, 3]])
    table_dict = {
        "inputs": ["What is MLflow?", "What is Databricks?"],
        "outputs": ["MLflow is ...", "Databricks is ..."],
        "image": [image, "text"],
    }
    artifact_file = f"test_time.{file_type}"
    with pytest.raises(ValueError, match="Column `image` contains a mix of images and non-images"):
        with mlflow.start_run():
            # Log the dictionary as a table
            mlflow.log_table(data=table_dict, artifact_file=artifact_file)


@pytest.mark.skipif(
    "MLFLOW_SKINNY" in os.environ,
    reason="Skinny client does not support the np or pandas dependencies",
)
@pytest.mark.parametrize("file_type", ["json", "parquet"])
def test_log_table_with_valid_image_columns(file_type):
    class ImageObj:
        def __init__(self):
            self.size = (1, 1)

        def resize(self, size):
            return self

        def save(self, path):
            with open(path, "w+") as f:
                f.write("dummy data")

    image_obj = ImageObj()
    image = mlflow.Image([[1, 2, 3]])

    table_dict = {
        "inputs": ["What is MLflow?", "What is Databricks?"],
        "outputs": ["MLflow is ...", "Databricks is ..."],
        "image": [image, image_obj],
    }
    # No error should be raised
    artifact_file = f"test_time.{file_type}"
    with mlflow.start_run():
        # Log the dictionary as a table
        mlflow.log_table(data=table_dict, artifact_file=artifact_file)


def test_set_async_logging_threadpool_size():
    MLFLOW_ASYNC_LOGGING_THREADPOOL_SIZE.set(6)
    assert MLFLOW_ASYNC_LOGGING_THREADPOOL_SIZE.get() == 6

    with mlflow.start_run():
        mlflow.log_param("key", "val", synchronous=False)

    store = mlflow.tracking._get_store()
    async_queue = store._async_logging_queue
    assert async_queue._batch_logging_worker_threadpool._max_workers == 6
    mlflow.flush_async_logging()
    MLFLOW_ASYNC_LOGGING_THREADPOOL_SIZE.unset()