# tests/tracking/test_tracking.py
import filecmp
import io
import json
import os
import pathlib
import posixpath
import random
import re
from datetime import datetime, timezone
from typing import NamedTuple
from unittest import mock

import pytest
import yaml

import mlflow
from mlflow import MlflowClient, tracking
from mlflow.entities import LifecycleStage, Metric, Param, RunStatus, RunTag, ViewType
from mlflow.environment_variables import (
    MLFLOW_ASYNC_LOGGING_THREADPOOL_SIZE,
    MLFLOW_RUN_ID,
)
from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import (
    INVALID_PARAMETER_VALUE,
    RESOURCE_DOES_NOT_EXIST,
    ErrorCode,
)
from mlflow.store.tracking.file_store import FileStore
from mlflow.tracking._tracking_service.client import TrackingServiceClient
from mlflow.tracking.fluent import start_run
from mlflow.utils.file_utils import local_file_uri_to_path
from mlflow.utils.mlflow_tags import (
    MLFLOW_PARENT_RUN_ID,
    MLFLOW_RUN_NAME,
    MLFLOW_SOURCE_NAME,
    MLFLOW_SOURCE_TYPE,
    MLFLOW_USER,
)
from mlflow.utils.os import is_windows
from mlflow.utils.time import get_current_time_millis
from mlflow.utils.validation import (
    MAX_METRICS_PER_BATCH,
    MAX_PARAMS_TAGS_PER_BATCH,
)


class MockExperiment(NamedTuple):
    experiment_id: str
    lifecycle_stage: str
    tags: dict[str, str] = {}


def test_create_experiment():
    with pytest.raises(MlflowException, match="Invalid experiment name"):
        mlflow.create_experiment(None)

    with pytest.raises(MlflowException, match="Invalid experiment name"):
        mlflow.create_experiment("")

    exp_id = mlflow.create_experiment(f"Some random experiment name {random.randint(1, int(1e6))}")
    assert exp_id is not None


def test_create_experiment_with_duplicate_name():
    name = "popular_name"
    exp_id = mlflow.create_experiment(name)

    with pytest.raises(MlflowException, match=re.escape(f"Experiment(name={name}) already exists")):
        mlflow.create_experiment(name)

    tracking.MlflowClient().delete_experiment(exp_id)
    with pytest.raises(MlflowException, match=re.escape(f"Experiment(name={name}) already exists")):
        mlflow.create_experiment(name)


def test_create_experiments_with_bad_names():
    # None for name
    with pytest.raises(MlflowException, match="Invalid experiment name: 'None'"):
        mlflow.create_experiment(None)

    # empty string name
    with pytest.raises(MlflowException, match="Invalid experiment name: ''"):
        mlflow.create_experiment("")


@pytest.mark.parametrize("name", [123, 0, -1.2, [], ["A"], {1: 2}])
def test_create_experiments_with_bad_name_types(name):
    with pytest.raises(
        MlflowException,
        match=re.escape(f"Invalid experiment name: {name}. Expects a string."),
    ):
        mlflow.create_experiment(name)


@pytest.mark.usefixtures("reset_active_experiment")
def test_set_experiment_by_name():
    name = "random_exp"
    exp_id = mlflow.create_experiment(name)
    exp1 = mlflow.set_experiment(name)
    assert exp1.experiment_id == exp_id
    with start_run() as run:
        assert run.info.experiment_id == exp_id

    another_name = "another_experiment"
    exp2 = mlflow.set_experiment(another_name)
    with start_run() as another_run:
        assert another_run.info.experiment_id == exp2.experiment_id


@pytest.mark.usefixtures("reset_active_experiment")
def test_set_experiment_by_id():
    name = "random_exp"
    exp_id = mlflow.create_experiment(name)
    active_exp = mlflow.set_experiment(experiment_id=exp_id)
    assert active_exp.experiment_id == exp_id
    with start_run() as run:
        assert run.info.experiment_id == exp_id

    nonexistent_id = "-1337"
    with pytest.raises(MlflowException, match="No Experiment with id=-1337 exists") as exc:
        mlflow.set_experiment(experiment_id=nonexistent_id)
    assert exc.value.error_code == ErrorCode.Name(RESOURCE_DOES_NOT_EXIST)
    with start_run() as run:
        assert run.info.experiment_id == exp_id


def test_set_experiment_parameter_validation():
    with pytest.raises(MlflowException, match="Must specify exactly one") as exc:
        mlflow.set_experiment()
    assert exc.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)

    with pytest.raises(MlflowException, match="Must specify exactly one") as exc:
        mlflow.set_experiment(None)
    assert exc.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)

    with pytest.raises(MlflowException, match="Must specify exactly one") as exc:
        mlflow.set_experiment(None, None)
    assert exc.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)

    with pytest.raises(MlflowException, match="Must specify exactly one") as exc:
        mlflow.set_experiment("name", "id")
    assert exc.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)


def test_set_experiment_with_deleted_experiment():
    name = "dead_exp"
    mlflow.set_experiment(name)
    with start_run() as run:
        exp_id = run.info.experiment_id

    tracking.MlflowClient().delete_experiment(exp_id)

    with pytest.raises(MlflowException, match="Cannot set a deleted experiment") as exc:
        mlflow.set_experiment(name)
    assert exc.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)

    with pytest.raises(MlflowException, match="Cannot set a deleted experiment") as exc:
        mlflow.set_experiment(experiment_id=exp_id)
    assert exc.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)


@pytest.mark.usefixtures("reset_active_experiment")
def test_set_experiment_with_zero_id():
    mock_experiment = MockExperiment(experiment_id=0, lifecycle_stage=LifecycleStage.ACTIVE)
    with (
        mock.patch.object(
            TrackingServiceClient,
            "get_experiment_by_name",
            mock.Mock(return_value=mock_experiment),
        ) as get_experiment_by_name_mock,
        mock.patch.object(TrackingServiceClient, "create_experiment") as create_experiment_mock,
    ):
        mlflow.set_experiment("my_exp")
        get_experiment_by_name_mock.assert_called_once()
        create_experiment_mock.assert_not_called()


def test_start_run_context_manager():
    with start_run() as first_run:
        first_uuid = first_run.info.run_id
        # Check that start_run() causes the run information to be persisted in the store
        persisted_run = tracking.MlflowClient().get_run(first_uuid)
        assert persisted_run is not None
        assert persisted_run.info == first_run.info
    finished_run = tracking.MlflowClient().get_run(first_uuid)
    assert finished_run.info.status == RunStatus.to_string(RunStatus.FINISHED)
    # Launch a separate run that fails, verify the run status is FAILED and the run UUID is
    # different
    with pytest.raises(Exception, match="Failing run!"):
        with start_run() as second_run:
            raise Exception("Failing run!")
    second_run_id = second_run.info.run_id
    assert second_run_id != first_uuid
    finished_run2 = tracking.MlflowClient().get_run(second_run_id)
    assert finished_run2.info.status == RunStatus.to_string(RunStatus.FAILED)


def test_start_and_end_run():
    # Start a run inside a `with` block and verify its metric is persisted after the run ends.

    with start_run() as active_run:
        mlflow.log_metric("name_1", 25)
    finished_run = tracking.MlflowClient().get_run(active_run.info.run_id)
    # Validate metrics
    assert len(finished_run.data.metrics) == 1
    assert finished_run.data.metrics["name_1"] == 25


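# Hedged companion sketch (not part of the original suite): the explicit
# start_run()/end_run() pair, without a `with` block, which the public fluent API
# also supports; mlflow.end_run() defaults to marking the run FINISHED.
def _sketch_start_and_end_run_without_with_block():
    active_run = mlflow.start_run()
    mlflow.log_metric("name_1", 25)
    mlflow.end_run()
    finished_run = tracking.MlflowClient().get_run(active_run.info.run_id)
    assert finished_run.info.status == RunStatus.to_string(RunStatus.FINISHED)

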
def test_metric_timestamp():
    with mlflow.start_run() as active_run:
        mlflow.log_metric("name_1", 25)
        mlflow.log_metric("name_1", 30)
        run_id = active_run.info.run_id
    # Check that metric timestamps are between run start and finish
    client = MlflowClient()
    history = client.get_metric_history(run_id, "name_1")
    finished_run = client.get_run(run_id)
    assert len(history) == 2
    assert all(
        finished_run.info.start_time <= m.timestamp <= finished_run.info.end_time
        for m in history
    )


def test_log_batch():
    expected_metrics = {"metric-key0": 1.0, "metric-key1": 4.0}
    expected_params = {
        "param-key0": "param-val0",
        "param-key1": 123,
        "param-key2": None,
    }
    exact_expected_tags = {"tag-key0": "tag-val0", "tag-key1": "tag-val1"}
    approx_expected_tags = {
        MLFLOW_USER,
        MLFLOW_SOURCE_NAME,
        MLFLOW_SOURCE_TYPE,
        MLFLOW_RUN_NAME,
    }

    t = get_current_time_millis()
    sorted_expected_metrics = sorted(expected_metrics.items(), key=lambda kv: kv[0])
    metrics = [
        Metric(key=key, value=value, timestamp=t, step=i)
        for i, (key, value) in enumerate(sorted_expected_metrics)
    ]
    params = [Param(key=key, value=value) for key, value in expected_params.items()]
    tags = [RunTag(key=key, value=value) for key, value in exact_expected_tags.items()]

    with start_run() as active_run:
        run_id = active_run.info.run_id
        MlflowClient().log_batch(run_id=run_id, metrics=metrics, params=params, tags=tags)
    client = tracking.MlflowClient()
    finished_run = client.get_run(run_id)
    # Validate metrics
    assert len(finished_run.data.metrics) == 2
    for key, value in finished_run.data.metrics.items():
        assert expected_metrics[key] == value
    metric_history0 = client.get_metric_history(run_id, "metric-key0")
    assert {(m.value, m.timestamp, m.step) for m in metric_history0} == {(1.0, t, 0)}
    metric_history1 = client.get_metric_history(run_id, "metric-key1")
    assert {(m.value, m.timestamp, m.step) for m in metric_history1} == {(4.0, t, 1)}

    # Validate tags (including automatically-set tags)
    assert len(finished_run.data.tags) == len(exact_expected_tags) + len(approx_expected_tags)
    for tag_key, tag_value in finished_run.data.tags.items():
        if tag_key not in approx_expected_tags:
            assert exact_expected_tags[tag_key] == tag_value
    # Validate params
    assert finished_run.data.params == {k: str(v) for k, v in expected_params.items()}
    # Test that log_batch works when only a subset of entity types is supplied
    new_tags = {"1": "2", "3": "4", "5": "6"}
    tags = [RunTag(key=key, value=value) for key, value in new_tags.items()]
    client.log_batch(run_id=run_id, tags=tags)
    finished_run_2 = client.get_run(run_id)
    # Validate the newly logged tags
    assert len(finished_run_2.data.tags) == len(finished_run.data.tags) + 3
    for tag_key, tag_value in finished_run_2.data.tags.items():
        if tag_key in new_tags:
            assert new_tags[tag_key] == tag_value


def test_log_batch_with_many_elements():
    num_metrics = MAX_METRICS_PER_BATCH * 2
    num_params = num_tags = MAX_PARAMS_TAGS_PER_BATCH * 2
    expected_metrics = {f"metric-key{i}": float(i) for i in range(num_metrics)}
    expected_params = {f"param-key{i}": f"param-val{i}" for i in range(num_params)}
    exact_expected_tags = {f"tag-key{i}": f"tag-val{i}" for i in range(num_tags)}

    t = get_current_time_millis()
    # Sort by value (not by key string) so the order is numeric (lexicographically,
    # "metric-key10" would sort before "metric-key2") and each metric's step `i`
    # matches the numeric suffix in its key.
    sorted_expected_metrics = sorted(expected_metrics.items(), key=lambda kv: kv[1])
    metrics = [
        Metric(key=key, value=value, timestamp=t, step=i)
        for i, (key, value) in enumerate(sorted_expected_metrics)
    ]
    params = [Param(key=key, value=value) for key, value in expected_params.items()]
    tags = [RunTag(key=key, value=value) for key, value in exact_expected_tags.items()]

    with start_run() as active_run:
        run_id = active_run.info.run_id
        mlflow.tracking.MlflowClient().log_batch(
            run_id=run_id, metrics=metrics, params=params, tags=tags
        )
    client = tracking.MlflowClient()
    finished_run = client.get_run(run_id)
    # Validate metrics
    assert expected_metrics == finished_run.data.metrics
    for i in range(num_metrics):
        metric_history = client.get_metric_history(run_id, f"metric-key{i}")
        assert {(m.value, m.timestamp, m.step) for m in metric_history} == {(float(i), t, i)}

    # Validate tags
    logged_tags = finished_run.data.tags
    for tag_key, tag_value in exact_expected_tags.items():
        assert logged_tags[tag_key] == tag_value

    # Validate params
    assert finished_run.data.params == expected_params


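# Hedged sketch (illustrative only): the call above passes twice the per-request batch
# limits in one go, relying on MlflowClient.log_batch splitting the payload into
# backend-sized requests. Against a store that did not do that, a caller could chunk
# manually along these lines:
def _sketch_manual_metric_chunking(client, run_id, metrics):
    for i in range(0, len(metrics), MAX_METRICS_PER_BATCH):
        client.log_batch(run_id=run_id, metrics=metrics[i : i + MAX_METRICS_PER_BATCH])

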
def test_log_metric():
    with start_run() as active_run, mock.patch("time.time", return_value=123):
        run_id = active_run.info.run_id
        mlflow.log_metric("name_1", 25)
        mlflow.log_metric("name_2", -3)
        mlflow.log_metric("name_1", 30, 5)
        mlflow.log_metric("name_1", 40, -2)
        mlflow.log_metric("nested/nested/name", 40)
    finished_run = tracking.MlflowClient().get_run(run_id)
    # Validate metrics
    assert len(finished_run.data.metrics) == 3
    expected_pairs = {"name_1": 30, "name_2": -3, "nested/nested/name": 40}
    for key, value in finished_run.data.metrics.items():
        assert expected_pairs[key] == value
    client = tracking.MlflowClient()
    metric_history_name1 = client.get_metric_history(run_id, "name_1")
    assert {(m.value, m.timestamp, m.step) for m in metric_history_name1} == {
        (25, 123 * 1000, 0),
        (30, 123 * 1000, 5),
        (40, 123 * 1000, -2),
    }
    metric_history_name2 = client.get_metric_history(run_id, "name_2")
    assert {(m.value, m.timestamp, m.step) for m in metric_history_name2} == {(-3, 123 * 1000, 0)}


def test_log_metrics_uses_millisecond_timestamp_resolution_fluent():
    with start_run() as active_run, mock.patch("time.time") as time_mock:
        time_mock.side_effect = lambda: 123
        mlflow.log_metrics({"name_1": 25, "name_2": -3})
        mlflow.log_metrics({"name_1": 30})
        mlflow.log_metrics({"name_1": 40})
        run_id = active_run.info.run_id

    client = tracking.MlflowClient()
    metric_history_name1 = client.get_metric_history(run_id, "name_1")
    assert {(m.value, m.timestamp) for m in metric_history_name1} == {
        (25, 123 * 1000),
        (30, 123 * 1000),
        (40, 123 * 1000),
    }
    metric_history_name2 = client.get_metric_history(run_id, "name_2")
    assert {(m.value, m.timestamp) for m in metric_history_name2} == {(-3, 123 * 1000)}


def test_log_metrics_uses_millisecond_timestamp_resolution_client():
    with start_run() as active_run, mock.patch("time.time") as time_mock:
        time_mock.side_effect = lambda: 123
        mlflow_client = tracking.MlflowClient()
        run_id = active_run.info.run_id

        mlflow_client.log_metric(run_id=run_id, key="name_1", value=25)
        mlflow_client.log_metric(run_id=run_id, key="name_2", value=-3)
        mlflow_client.log_metric(run_id=run_id, key="name_1", value=30)
        mlflow_client.log_metric(run_id=run_id, key="name_1", value=40)

    metric_history_name1 = mlflow_client.get_metric_history(run_id, "name_1")
    assert {(m.value, m.timestamp) for m in metric_history_name1} == {
        (25, 123 * 1000),
        (30, 123 * 1000),
        (40, 123 * 1000),
    }

    metric_history_name2 = mlflow_client.get_metric_history(run_id, "name_2")
    assert {(m.value, m.timestamp) for m in metric_history_name2} == {(-3, 123 * 1000)}


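# Hedged cross-check of the millisecond convention the two tests above pin down:
# get_current_time_millis() (imported at the top of this file) is, to my knowledge,
# int(time.time() * 1000), so a mocked clock of 123 seconds yields 123000 everywhere.
def _sketch_millisecond_conversion():
    with mock.patch("time.time", return_value=123):
        assert get_current_time_millis() == 123 * 1000

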
@pytest.mark.parametrize("step_kwarg", [None, -10, 5])
def test_log_metrics_uses_common_timestamp_and_step_per_invocation(step_kwarg):
    expected_metrics = {"name_1": 30, "name_2": -3, "nested/nested/name": 40}
    with start_run() as active_run:
        run_id = active_run.info.run_id
        mlflow.log_metrics(expected_metrics, step=step_kwarg)
    finished_run = tracking.MlflowClient().get_run(run_id)
    # Validate metric key/values match what we expect, and that all metrics have the same timestamp
    assert len(finished_run.data.metrics) == len(expected_metrics)
    for key, value in finished_run.data.metrics.items():
        assert expected_metrics[key] == value
    common_timestamp = finished_run.data._metric_objs[0].timestamp
    expected_step = step_kwarg if step_kwarg is not None else 0
    for metric_obj in finished_run.data._metric_objs:
        assert metric_obj.timestamp == common_timestamp
        assert metric_obj.step == expected_step


@pytest.fixture
def get_store_mock():
    with mock.patch("mlflow.store.tracking.file_store.FileStore.log_batch") as _get_store_mock:
        yield _get_store_mock


def test_set_tags():
    exact_expected_tags = {"name_1": "c", "name_2": "b", "nested/nested/name": 5}
    approx_expected_tags = {
        MLFLOW_USER,
        MLFLOW_SOURCE_NAME,
        MLFLOW_SOURCE_TYPE,
        MLFLOW_RUN_NAME,
    }
    with start_run() as active_run:
        run_id = active_run.info.run_id
        mlflow.set_tags(exact_expected_tags)
    finished_run = tracking.MlflowClient().get_run(run_id)
    # Validate tags
    assert len(finished_run.data.tags) == len(exact_expected_tags) + len(approx_expected_tags)
    for tag_key, tag_val in finished_run.data.tags.items():
        if tag_key not in approx_expected_tags:
            assert str(exact_expected_tags[tag_key]) == tag_val


def test_log_metric_validation():
    with start_run() as active_run:
        run_id = active_run.info.run_id
        with pytest.raises(
            MlflowException,
            match="Invalid value \"apple\" for parameter 'value' supplied",
        ) as e:
            mlflow.log_metric("name_1", "apple")
    assert e.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
    finished_run = tracking.MlflowClient().get_run(run_id)
    assert len(finished_run.data.metrics) == 0


def test_log_param():
    with start_run() as active_run:
        run_id = active_run.info.run_id
        assert mlflow.log_param("name_1", "a") == "a"
        assert mlflow.log_param("name_2", "b") == "b"
        assert mlflow.log_param("nested/nested/name", 5) == 5
    finished_run = tracking.MlflowClient().get_run(run_id)
    # Validate params
    assert finished_run.data.params == {
        "name_1": "a",
        "name_2": "b",
        "nested/nested/name": "5",
    }


def test_log_params():
    expected_params = {"name_1": "c", "name_2": "b", "nested/nested/name": 5}
    with start_run() as active_run:
        run_id = active_run.info.run_id
        mlflow.log_params(expected_params)
    finished_run = tracking.MlflowClient().get_run(run_id)
    # Validate params
    assert finished_run.data.params == {
        "name_1": "c",
        "name_2": "b",
        "nested/nested/name": "5",
    }


def test_log_params_duplicate_keys_raises():
    params = {"a": "1", "b": "2"}
    with start_run() as active_run:
        run_id = active_run.info.run_id
        mlflow.log_params(params)
        with pytest.raises(
            expected_exception=MlflowException,
            match=r"Changing param values is not allowed. Param with key=",
        ) as e:
            mlflow.log_param("a", "3")
        assert e.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
    finished_run = tracking.MlflowClient().get_run(run_id)
    assert finished_run.data.params == params


@pytest.mark.skipif(is_windows(), reason="Windows does not support colons in param and metric names")
def test_param_metric_with_colon():
    with start_run() as active_run:
        run_id = active_run.info.run_id
        mlflow.log_param("a:b", 3)
        mlflow.log_metric("c:d", 4)
    finished_run = tracking.MlflowClient().get_run(run_id)

    # Validate param
    assert len(finished_run.data.params) == 1
    assert finished_run.data.params == {"a:b": "3"}

    # Validate metric
    assert len(finished_run.data.metrics) == 1
    assert finished_run.data.metrics["c:d"] == 4


def test_log_batch_duplicate_entries_raises():
    with start_run() as active_run:
        run_id = active_run.info.run_id
        with pytest.raises(
            MlflowException, match=r"Duplicate parameter keys have been submitted."
        ) as e:
            tracking.MlflowClient().log_batch(
                run_id=run_id, params=[Param("a", "1"), Param("a", "2")]
            )
        assert e.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)


def test_log_batch_validates_entity_names_and_values():
    with start_run() as active_run:
        run_id = active_run.info.run_id

        metrics = [Metric(key="../bad/metric/name", value=0.3, timestamp=3, step=0)]
        with pytest.raises(
            MlflowException,
            match=r"Invalid value \"../bad/metric/name\" for parameter \'metrics\[0\].name\'",
        ) as e:
            tracking.MlflowClient().log_batch(run_id, metrics=metrics)
        assert e.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)

        metrics = [Metric(key="ok-name", value="non-numerical-value", timestamp=3, step=0)]
        with pytest.raises(
            MlflowException,
            match=r"Invalid value \"non-numerical-value\" "
            + r"for parameter \'metrics\[0\].value\' supplied",
        ) as e:
            tracking.MlflowClient().log_batch(run_id, metrics=metrics)
        assert e.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)

        metrics = [Metric(key="ok-name", value=0.3, timestamp="non-numerical-timestamp", step=0)]
        with pytest.raises(
            MlflowException,
            match=r"Invalid value \"non-numerical-timestamp\" for "
            + r"parameter \'metrics\[0\].timestamp\' supplied",
        ) as e:
            tracking.MlflowClient().log_batch(run_id, metrics=metrics)
        assert e.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)

        params = [Param(key="../bad/param/name", value="my-val")]
        with pytest.raises(
            MlflowException,
            match=r"Invalid value \"../bad/param/name\" for parameter \'params\[0\].key\' supplied",
        ) as e:
            tracking.MlflowClient().log_batch(run_id, params=params)
        assert e.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)

        tags = [RunTag(key="../bad/tag/name", value="my-val")]
        with pytest.raises(
            MlflowException,
            match=r"Invalid value \"../bad/tag/name\" for parameter \'tags\[0\].key\' supplied",
        ) as e:
            tracking.MlflowClient().log_batch(run_id, tags=tags)
        assert e.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)

        metrics = [Metric(key=None, value=42.0, timestamp=4, step=1)]
        with pytest.raises(
            MlflowException,
            match="Metric name cannot be None. A key name must be provided.",
        ) as e:
            tracking.MlflowClient().log_batch(run_id, metrics=metrics)
        assert e.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)


def test_log_artifact_with_dirs(tmp_path):
    # Test log artifact with a directory
    art_dir = tmp_path / "parent"
    art_dir.mkdir()
    file0 = art_dir.joinpath("file0")
    file0.write_text("something")
    file1 = art_dir.joinpath("file1")
    file1.write_text("something")
    sub_dir = art_dir / "child"
    sub_dir.mkdir()
    with start_run():
        artifact_uri = mlflow.get_artifact_uri()
        run_artifact_dir = local_file_uri_to_path(artifact_uri)
        mlflow.log_artifact(str(art_dir))
        base = os.path.basename(str(art_dir))
        assert os.listdir(run_artifact_dir) == [base]
        assert set(os.listdir(os.path.join(run_artifact_dir, base))) == {
            "child",
            "file0",
            "file1",
        }
        with open(os.path.join(run_artifact_dir, base, "file0")) as f:
            assert f.read() == "something"
    # Test log artifact with directory and specified parent folder

    art_dir = tmp_path / "dir"
    art_dir.mkdir()
    with start_run():
        artifact_uri = mlflow.get_artifact_uri()
        run_artifact_dir = local_file_uri_to_path(artifact_uri)
        mlflow.log_artifact(str(art_dir), "some_parent")
        assert os.listdir(run_artifact_dir) == ["some_parent"]
        assert os.listdir(os.path.join(run_artifact_dir, "some_parent")) == [
            os.path.basename(str(art_dir))
        ]
    sub_dir = art_dir.joinpath("another_dir")
    sub_dir.mkdir()
    with start_run():
        artifact_uri = mlflow.get_artifact_uri()
        run_artifact_dir = local_file_uri_to_path(artifact_uri)
        mlflow.log_artifact(str(art_dir), "parent/and_child")
        assert os.listdir(os.path.join(run_artifact_dir, "parent", "and_child")) == [
            os.path.basename(str(art_dir))
        ]
        assert set(
            os.listdir(
                os.path.join(
                    run_artifact_dir,
                    "parent",
                    "and_child",
                    os.path.basename(str(art_dir)),
                )
            )
        ) == {os.path.basename(str(sub_dir))}


def test_log_artifact(tmp_path):
    # Create artifacts
    artifact_dir = tmp_path.joinpath("artifact_dir")
    artifact_dir.mkdir()
    path0 = artifact_dir.joinpath("file0")
    path1 = artifact_dir.joinpath("file1")
    path0.write_text("0")
    path1.write_text("1")
    # Log an artifact, verify it exists in the directory returned by get_artifact_uri
    # after the run finishes
    artifact_parent_dirs = ["some_parent_dir", None]
    for parent_dir in artifact_parent_dirs:
        with start_run():
            artifact_uri = mlflow.get_artifact_uri()
            run_artifact_dir = local_file_uri_to_path(artifact_uri)
            mlflow.log_artifact(path0, parent_dir)
        expected_dir = (
            os.path.join(run_artifact_dir, parent_dir)
            if parent_dir is not None
            else run_artifact_dir
        )
        assert os.listdir(expected_dir) == [os.path.basename(path0)]
        # Join with the basename; joining with the absolute `path0` would discard
        # `expected_dir` and compare the source file with itself.
        logged_artifact_path = os.path.join(expected_dir, os.path.basename(path0))
        assert filecmp.cmp(logged_artifact_path, path0, shallow=False)
    # Log multiple artifacts, verify they exist in the directory returned by get_artifact_uri
    for parent_dir in artifact_parent_dirs:
        with start_run():
            artifact_uri = mlflow.get_artifact_uri()
            run_artifact_dir = local_file_uri_to_path(artifact_uri)

            mlflow.log_artifacts(artifact_dir, parent_dir)
        # Check that the logged artifacts match
        expected_artifact_output_dir = (
            os.path.join(run_artifact_dir, parent_dir)
            if parent_dir is not None
            else run_artifact_dir
        )
        dir_comparison = filecmp.dircmp(artifact_dir, expected_artifact_output_dir)
        assert len(dir_comparison.left_only) == 0
        assert len(dir_comparison.right_only) == 0
        assert len(dir_comparison.diff_files) == 0
        assert len(dir_comparison.funny_files) == 0


@pytest.mark.parametrize("subdir", [None, ".", "dir", "dir1/dir2", "dir/.."])
def test_log_text(subdir):
    filename = "file.txt"
    text = "a"
    artifact_file = filename if subdir is None else posixpath.join(subdir, filename)

    with mlflow.start_run():
        mlflow.log_text(text, artifact_file)

        artifact_path = None if subdir is None else posixpath.normpath(subdir)
        artifact_uri = mlflow.get_artifact_uri(artifact_path)
        run_artifact_dir = local_file_uri_to_path(artifact_uri)
        assert os.listdir(run_artifact_dir) == [filename]

        filepath = os.path.join(run_artifact_dir, filename)
        with open(filepath) as f:
            assert f.read() == text


@pytest.mark.parametrize("subdir", [None, ".", "dir", "dir1/dir2", "dir/.."])
@pytest.mark.parametrize("extension", [".json", ".yml", ".yaml", ".txt", ""])
def test_log_dict(subdir, extension):
    dictionary = {"k": "v"}
    filename = "data" + extension
    artifact_file = filename if subdir is None else posixpath.join(subdir, filename)

    with mlflow.start_run():
        mlflow.log_dict(dictionary, artifact_file)

        artifact_path = None if subdir is None else posixpath.normpath(subdir)
        artifact_uri = mlflow.get_artifact_uri(artifact_path)
        run_artifact_dir = local_file_uri_to_path(artifact_uri)
        assert os.listdir(run_artifact_dir) == [filename]

        filepath = os.path.join(run_artifact_dir, filename)
        with open(filepath) as f:
            loaded = (
                # Specify `Loader` to suppress the following deprecation warning:
                # https://github.com/yaml/pyyaml/wiki/PyYAML-yaml.load(input)-Deprecation
                yaml.load(f, Loader=yaml.SafeLoader)
                if extension in [".yml", ".yaml"]
                else json.load(f)
            )
            assert loaded == dictionary


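# Hedged aside: `yaml.load(f, Loader=yaml.SafeLoader)` above is equivalent to the
# shorter `yaml.safe_load(f)`; a minimal round-trip sketch under that assumption:
def _sketch_yaml_safe_round_trip(tmp_path):
    path = tmp_path / "data.yaml"
    path.write_text(yaml.safe_dump({"k": "v"}))
    assert yaml.safe_load(path.read_text()) == {"k": "v"}

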
@pytest.mark.parametrize("subdir", [None, ".", "dir", "dir1/dir2", "dir/.."])
def test_log_stream_bytes(subdir):
    filename = "file.bin"
    content = b"binary content"
    artifact_file = filename if subdir is None else posixpath.join(subdir, filename)

    with mlflow.start_run():
        stream = io.BytesIO(content)
        mlflow.log_stream(stream, artifact_file)

        artifact_path = None if subdir is None else posixpath.normpath(subdir)
        artifact_uri = mlflow.get_artifact_uri(artifact_path)
        run_artifact_dir = pathlib.Path(local_file_uri_to_path(artifact_uri))
        assert list(run_artifact_dir.iterdir()) == [run_artifact_dir / filename]
        assert (run_artifact_dir / filename).read_bytes() == content


def test_log_stream_empty():
    with mlflow.start_run():
        artifact_uri = mlflow.get_artifact_uri()
        run_artifact_dir = pathlib.Path(local_file_uri_to_path(artifact_uri))

        stream = io.BytesIO(b"")
        mlflow.log_stream(stream, "empty.bin")
        assert (run_artifact_dir / "empty.bin").read_bytes() == b""


def test_log_stream_large_content():
    with mlflow.start_run():
        # Large binary content (larger than chunk size of 8192)
        large_content = b"x" * 100000
        stream = io.BytesIO(large_content)
        mlflow.log_stream(stream, "large.bin")

        artifact_uri = mlflow.get_artifact_uri()
        run_artifact_dir = pathlib.Path(local_file_uri_to_path(artifact_uri))
        assert (run_artifact_dir / "large.bin").read_bytes() == large_content


def test_with_startrun():
    run_id = None
    t0 = get_current_time_millis()
    with mlflow.start_run() as active_run:
        assert mlflow.active_run() == active_run
        run_id = active_run.info.run_id
    t1 = get_current_time_millis()
    run_info = mlflow.tracking._get_store().get_run(run_id).info
    assert run_info.status == "FINISHED"
    assert t0 <= run_info.end_time
    assert run_info.end_time <= t1
    assert mlflow.active_run() is None


def test_parent_create_run(monkeypatch):
    with mlflow.start_run() as parent_run:
        parent_run_id = parent_run.info.run_id
    monkeypatch.setenv(MLFLOW_RUN_ID.name, parent_run_id)
    with mlflow.start_run() as parent_run:
        assert parent_run.info.run_id == parent_run_id
        with pytest.raises(Exception, match="To start a nested run"):
            mlflow.start_run()
        with mlflow.start_run(nested=True) as child_run:
            assert child_run.info.run_id != parent_run_id
            with mlflow.start_run(nested=True) as grand_child_run:
                pass

    def verify_has_parent_id_tag(child_id, expected_parent_id):
        tags = tracking.MlflowClient().get_run(child_id).data.tags
        assert tags[MLFLOW_PARENT_RUN_ID] == expected_parent_id

    verify_has_parent_id_tag(child_run.info.run_id, parent_run.info.run_id)
    verify_has_parent_id_tag(grand_child_run.info.run_id, child_run.info.run_id)
    assert mlflow.active_run() is None


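# Hedged sketch: runs nested as above can be recovered later with a tag filter on the
# same parent-run tag that verify_has_parent_id_tag checks, backtick-quoting the
# dotted tag key in the filter string:
def _sketch_find_child_runs(client, experiment_id, parent_run_id):
    return client.search_runs(
        [experiment_id], f"tags.`{MLFLOW_PARENT_RUN_ID}` = '{parent_run_id}'"
    )

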
def test_start_deleted_run():
    run_id = None
    with mlflow.start_run() as active_run:
        run_id = active_run.info.run_id
    tracking.MlflowClient().delete_run(run_id)
    with pytest.raises(MlflowException, match="because it is in the deleted state."):
        with mlflow.start_run(run_id=run_id):
            pass
    assert mlflow.active_run() is None


@pytest.mark.usefixtures("reset_active_experiment")
def test_start_run_exp_id_0():
    mlflow.set_experiment("some-experiment")
    # Create a run and verify that the current active experiment is the one we just set
    with mlflow.start_run() as active_run:
        exp_id = active_run.info.experiment_id
        assert exp_id != FileStore.DEFAULT_EXPERIMENT_ID
        assert MlflowClient().get_experiment(exp_id).name == "some-experiment"
    # Set experiment ID to 0 when creating a run, verify that the specified experiment ID is honored
    with mlflow.start_run(experiment_id=0) as active_run:
        assert active_run.info.experiment_id == FileStore.DEFAULT_EXPERIMENT_ID


def test_get_artifact_uri_with_artifact_path_unspecified_returns_artifact_root_dir():
    with mlflow.start_run() as active_run:
        assert mlflow.get_artifact_uri(artifact_path=None) == active_run.info.artifact_uri


def test_get_artifact_uri_uses_currently_active_run_id():
    artifact_path = "artifact"
    with mlflow.start_run() as active_run:
        assert mlflow.get_artifact_uri(
            artifact_path=artifact_path
        ) == tracking.artifact_utils.get_artifact_uri(
            run_id=active_run.info.run_id, artifact_path=artifact_path
        )


def _assert_get_artifact_uri_appends_to_uri_path_component_correctly(
    artifact_location, expected_uri_format
):
    client = MlflowClient()
    client.create_experiment("get-artifact-uri-test", artifact_location=artifact_location)
    mlflow.set_experiment("get-artifact-uri-test")
    with mlflow.start_run():
        run_id = mlflow.active_run().info.run_id
        for artifact_path in ["path/to/artifact", "/artifact/path", "arty.txt"]:
            artifact_uri = mlflow.get_artifact_uri(artifact_path)
            assert artifact_uri == tracking.artifact_utils.get_artifact_uri(run_id, artifact_path)
            assert artifact_uri == expected_uri_format.format(
                run_id=run_id,
                path=artifact_path.lstrip("/"),
                drive=pathlib.Path.cwd().drive,
            )


@pytest.mark.parametrize(
    ("artifact_location", "expected_uri_format"),
    [
        (
            "mysql://user:password@host:port/dbname?driver=mydriver",
            "mysql://user:password@host:port/dbname/{run_id}/artifacts/{path}?driver=mydriver",
        ),
        (
            "mysql+driver://user:pass@host:port/dbname/subpath/#fragment",
            "mysql+driver://user:pass@host:port/dbname/subpath/{run_id}/artifacts/{path}#fragment",
        ),
        (
            "s3://bucketname/rootpath",
            "s3://bucketname/rootpath/{run_id}/artifacts/{path}",
        ),
    ],
)
def test_get_artifact_uri_appends_to_uri_path_component_correctly(
    artifact_location, expected_uri_format
):
    _assert_get_artifact_uri_appends_to_uri_path_component_correctly(
        artifact_location, expected_uri_format
    )


@pytest.mark.skipif(not is_windows(), reason="This test only passes on Windows")
def test_get_artifact_uri_appends_to_local_path_component_correctly_on_windows():
    _assert_get_artifact_uri_appends_to_uri_path_component_correctly(
        "/dirname/rootpa#th?",
        "file:///{drive}/dirname/rootpa/{run_id}/artifacts/{path}#th?",
    )


@pytest.mark.skipif(is_windows(), reason="This test fails on Windows")
def test_get_artifact_uri_appends_to_local_path_component_correctly():
    _assert_get_artifact_uri_appends_to_uri_path_component_correctly(
        "/dirname/rootpa#th?", "{drive}/dirname/rootpa#th?/{run_id}/artifacts/{path}"
    )


@pytest.mark.usefixtures("reset_active_experiment")
def test_search_runs():
    mlflow.set_experiment("exp-for-search")
    # Log two runs with distinct metrics, params, and tags to search over
    logged_runs = {}
    with mlflow.start_run() as active_run:
        logged_runs["first"] = active_run.info.run_id
        mlflow.log_metric("m1", 0.001)
        mlflow.log_metric("m2", 0.002)
        mlflow.log_metric("m1", 0.002)
        mlflow.log_param("p1", "a")
        mlflow.set_tag("t1", "first-tag-val")
    with mlflow.start_run() as active_run:
        logged_runs["second"] = active_run.info.run_id
        mlflow.log_metric("m1", 0.008)
        mlflow.log_param("p2", "aa")
        mlflow.set_tag("t2", "second-tag-val")

    def verify_runs(runs, expected_set):
        assert {r.info.run_id for r in runs} == {logged_runs[r] for r in expected_set}

    experiment_id = MlflowClient().get_experiment_by_name("exp-for-search").experiment_id

    # 2 runs in this experiment
    assert len(MlflowClient().search_runs([experiment_id], run_view_type=ViewType.ACTIVE_ONLY)) == 2

    # 2 runs have metric "m1" > 0.0001
    runs = MlflowClient().search_runs([experiment_id], "metrics.m1 > 0.0001")
    verify_runs(runs, ["first", "second"])

    # 1 run has metric "m1" > 0.002
    runs = MlflowClient().search_runs([experiment_id], "metrics.m1 > 0.002")
    verify_runs(runs, ["second"])

    # no runs with metric "m1" > 0.1
    runs = MlflowClient().search_runs([experiment_id], "metrics.m1 > 0.1")
    verify_runs(runs, [])

    # 1 run with metric "m2" > 0
    runs = MlflowClient().search_runs([experiment_id], "metrics.m2 > 0")
    verify_runs(runs, ["first"])

    # 1 run each with param "p1" and "p2"
    runs = MlflowClient().search_runs([experiment_id], "params.p1 = 'a'", ViewType.ALL)
    verify_runs(runs, ["first"])
    runs = MlflowClient().search_runs([experiment_id], "params.p2 != 'a'", ViewType.ALL)
    verify_runs(runs, ["second"])
    runs = MlflowClient().search_runs([experiment_id], "params.p2 = 'aa'", ViewType.ALL)
    verify_runs(runs, ["second"])

    # 1 run each with tag "t1" and "t2"
    runs = MlflowClient().search_runs([experiment_id], "tags.t1 = 'first-tag-val'", ViewType.ALL)
    verify_runs(runs, ["first"])
    runs = MlflowClient().search_runs([experiment_id], "tags.t2 != 'qwerty'", ViewType.ALL)
    verify_runs(runs, ["second"])
    runs = MlflowClient().search_runs([experiment_id], "tags.t2 = 'second-tag-val'", ViewType.ALL)
    verify_runs(runs, ["second"])

    # delete the "first" run; it stays visible to the ALL and DELETED_ONLY views only
    MlflowClient().delete_run(logged_runs["first"])
    runs = MlflowClient().search_runs([experiment_id], "params.p1 = 'a'", ViewType.ALL)
    verify_runs(runs, ["first"])

    runs = MlflowClient().search_runs([experiment_id], "params.p1 = 'a'", ViewType.DELETED_ONLY)
    verify_runs(runs, ["first"])

    runs = MlflowClient().search_runs([experiment_id], "params.p1 = 'a'", ViewType.ACTIVE_ONLY)
    verify_runs(runs, [])


@pytest.mark.usefixtures("reset_active_experiment")
def test_search_runs_multiple_experiments():
    experiment_ids = [mlflow.create_experiment(f"exp__{exp_id}") for exp_id in range(1, 4)]
    for eid in experiment_ids:
        with mlflow.start_run(experiment_id=eid):
            mlflow.log_metric("m0", 1)
            mlflow.log_metric(f"m_{eid}", 2)

    assert len(MlflowClient().search_runs(experiment_ids, "metrics.m0 > 0", ViewType.ALL)) == 3

    assert len(MlflowClient().search_runs(experiment_ids, "metrics.m_1 > 0", ViewType.ALL)) == 1
    assert len(MlflowClient().search_runs(experiment_ids, "metrics.m_2 = 2", ViewType.ALL)) == 1
    assert len(MlflowClient().search_runs(experiment_ids, "metrics.m_3 < 4", ViewType.ALL)) == 1


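# Hedged sketch: the fluent mlflow.search_runs API covers the same queries as the
# client calls above and, in a non-skinny environment with pandas installed, returns
# a DataFrame keyed by columns such as "run_id" and "metrics.m0":
def _sketch_fluent_search_runs(experiment_ids):
    runs_df = mlflow.search_runs(experiment_ids, filter_string="metrics.m0 > 0")
    return runs_df[["run_id", "metrics.m0"]]

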
def read_data(artifact_path):
    import pandas as pd

    if artifact_path.endswith(".json"):
        return pd.read_json(artifact_path, orient="split")
    if artifact_path.endswith(".parquet"):
        return pd.read_parquet(artifact_path)
    raise ValueError(f"Unsupported file type in {artifact_path}. Expected .json or .parquet")


@pytest.mark.skipif(
    "MLFLOW_SKINNY" in os.environ,
    reason="Skinny client does not support the np or pandas dependencies",
)
@pytest.mark.parametrize("file_type", ["json", "parquet"])
def test_log_table(file_type):
    import pandas as pd

    table_dict = {
        "inputs": ["What is MLflow?", "What is Databricks?"],
        "outputs": ["MLflow is ...", "Databricks is ..."],
        "toxicity": [0.0, 0.0],
    }
    artifact_file = f"qabot_eval_results.{file_type}"
    TAG_NAME = "mlflow.loggedArtifacts"
    run_id = None

    with pytest.raises(
        MlflowException, match="data must be a pandas.DataFrame or a dictionary"
    ) as e:
        with mlflow.start_run() as run:
            # Log the incorrect data format as a table
            mlflow.log_table(data="incorrect-data-format", artifact_file=artifact_file)
    assert e.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)

    with mlflow.start_run() as run:
        # Log the dictionary as a table
        mlflow.log_table(data=table_dict, artifact_file=artifact_file)
        run_id = run.info.run_id

    run = mlflow.get_run(run_id)
    artifact_path = mlflow.artifacts.download_artifacts(run_id=run_id, artifact_path=artifact_file)
    table_data = read_data(artifact_path)
    assert table_data.shape[0] == 2
    assert table_data.shape[1] == 3

    # Get the current value of the tag
    current_tag_value = json.loads(run.data.tags.get(TAG_NAME, "[]"))
    assert {"path": artifact_file, "type": "table"} in current_tag_value
    assert len(current_tag_value) == 1

    table_df = pd.DataFrame.from_dict(table_dict)
    with mlflow.start_run(run_id=run_id):
        # Log the dataframe as a table
        mlflow.log_table(data=table_df, artifact_file=artifact_file)

    run = mlflow.get_run(run_id)
    artifact_path = mlflow.artifacts.download_artifacts(run_id=run_id, artifact_path=artifact_file)
    table_data = read_data(artifact_path)
    assert table_data.shape[0] == 4
    assert table_data.shape[1] == 3
    # Get the current value of the tag
    current_tag_value = json.loads(run.data.tags.get(TAG_NAME, "[]"))
    assert {"path": artifact_file, "type": "table"} in current_tag_value
    assert len(current_tag_value) == 1

    artifact_file_new = f"qabot_eval_results_new.{file_type}"
    with mlflow.start_run(run_id=run_id):
        # Log the dataframe as a table to new artifact file
        mlflow.log_table(data=table_df, artifact_file=artifact_file_new)

    run = mlflow.get_run(run_id)
    artifact_path = mlflow.artifacts.download_artifacts(
        run_id=run_id, artifact_path=artifact_file_new
    )
    table_data = read_data(artifact_path)
    assert table_data.shape[0] == 2
    assert table_data.shape[1] == 3
    # Get the current value of the tag
    current_tag_value = json.loads(run.data.tags.get(TAG_NAME, "[]"))
    assert {"path": artifact_file_new, "type": "table"} in current_tag_value
    assert len(current_tag_value) == 2


@pytest.mark.skipif(
    "MLFLOW_SKINNY" in os.environ,
    reason="Skinny client does not support the np or pandas dependencies",
)
@pytest.mark.parametrize("file_type", ["json", "parquet"])
def test_log_table_with_subdirectory(file_type):
    import pandas as pd

    table_dict = {
        "inputs": ["What is MLflow?", "What is Databricks?"],
        "outputs": ["MLflow is ...", "Databricks is ..."],
        "toxicity": [0.0, 0.0],
    }
    artifact_file = f"dir/foo.{file_type}"
    TAG_NAME = "mlflow.loggedArtifacts"
    run_id = None

    with mlflow.start_run() as run:
        # Log the dictionary as a table
        mlflow.log_table(data=table_dict, artifact_file=artifact_file)
        run_id = run.info.run_id

    run = mlflow.get_run(run_id)
    artifact_path = mlflow.artifacts.download_artifacts(run_id=run_id, artifact_path=artifact_file)
    table_data = read_data(artifact_path)
    assert table_data.shape[0] == 2
    assert table_data.shape[1] == 3

    # Get the current value of the tag
    current_tag_value = json.loads(run.data.tags.get(TAG_NAME, "[]"))
    assert {"path": artifact_file, "type": "table"} in current_tag_value
    assert len(current_tag_value) == 1

    table_df = pd.DataFrame.from_dict(table_dict)
    with mlflow.start_run(run_id=run_id):
        # Log the dataframe as a table
        mlflow.log_table(data=table_df, artifact_file=artifact_file)

    run = mlflow.get_run(run_id)
    artifact_path = mlflow.artifacts.download_artifacts(run_id=run_id, artifact_path=artifact_file)
    table_data = read_data(artifact_path)
    assert table_data.shape[0] == 4
    assert table_data.shape[1] == 3
    # Get the current value of the tag
    current_tag_value = json.loads(run.data.tags.get(TAG_NAME, "[]"))
    assert {"path": artifact_file, "type": "table"} in current_tag_value
    assert len(current_tag_value) == 1


@pytest.mark.skipif(
    "MLFLOW_SKINNY" in os.environ,
    reason="Skinny client does not support the np or pandas dependencies",
)
@pytest.mark.parametrize("file_type", ["json", "parquet"])
def test_load_table(file_type):
    table_dict = {
        "inputs": ["What is MLflow?", "What is Databricks?"],
        "outputs": ["MLflow is ...", "Databricks is ..."],
        "toxicity": [0.0, 0.0],
    }
    artifact_file = f"qabot_eval_results.{file_type}"
    artifact_file_2 = f"qabot_eval_results_2.{file_type}"
    run_id_2 = None

    with mlflow.start_run() as run:
        # Log the dictionary as a table
        mlflow.log_table(data=table_dict, artifact_file=artifact_file)
        mlflow.log_table(data=table_dict, artifact_file=artifact_file_2)

    with mlflow.start_run() as run:
        # Log the dictionary as a table
        mlflow.log_table(data=table_dict, artifact_file=artifact_file)
        run_id_2 = run.info.run_id

    with mlflow.start_run() as run:
        # Log the dictionary as a table
        mlflow.log_table(data=table_dict, artifact_file=artifact_file)
        run_id_3 = run.info.run_id

    extra_columns = ["run_id", "tags.mlflow.loggedArtifacts"]

    # test 1: load table with extra columns
    output_df = mlflow.load_table(artifact_file=artifact_file, extra_columns=extra_columns)

    assert output_df.shape[0] == 6
    assert output_df.shape[1] == 5
    assert output_df["run_id"].nunique() == 3
    assert output_df["tags.mlflow.loggedArtifacts"].nunique() == 2

    # test 2: load table with extra columns and single run_id
    output_df = mlflow.load_table(
        artifact_file=artifact_file, run_ids=[run_id_2], extra_columns=extra_columns
    )

    assert output_df.shape[0] == 2
    assert output_df.shape[1] == 5
    assert output_df["run_id"].nunique() == 1
    assert output_df["tags.mlflow.loggedArtifacts"].nunique() == 1

    # test 3: load table with extra columns and multiple run_ids
    output_df = mlflow.load_table(
        artifact_file=artifact_file,
        run_ids=[run_id_2, run_id_3],
        extra_columns=extra_columns,
    )

    assert output_df.shape[0] == 4
    assert output_df.shape[1] == 5
    assert output_df["run_id"].nunique() == 2
    assert output_df["tags.mlflow.loggedArtifacts"].nunique() == 1

    # test 4: load a different artifact file with no extra columns and no run_ids specified
    output_df = mlflow.load_table(artifact_file=artifact_file_2)
    import pandas as pd

    pd.testing.assert_frame_equal(output_df, pd.DataFrame(table_dict), check_dtype=False)

    # test 5: load table with no extra columns and no run_ids specified
    output_df = mlflow.load_table(artifact_file=artifact_file)

    assert output_df.shape[0] == 6
    assert output_df.shape[1] == 3

    # test 6: load table with no matching results found. Error case
    with pytest.raises(
        MlflowException, match="No runs found with the corresponding table artifact"
    ):
        mlflow.load_table(artifact_file=f"error_case.{file_type}")

    # test 7: load table with no matching extra_column found. Error case
    with pytest.raises(KeyError, match="error_column"):
        mlflow.load_table(artifact_file=artifact_file, extra_columns=["error_column"])


@pytest.mark.skipif(
    "MLFLOW_SKINNY" in os.environ,
    reason="Skinny client does not support the np or pandas dependencies",
)
@pytest.mark.parametrize("file_type", ["json", "parquet"])
def test_log_table_with_datetime_columns(file_type):
    import pandas as pd

    start_time = str(datetime.now(timezone.utc))
    table_dict = {
        "inputs": ["What is MLflow?", "What is Databricks?"],
        "outputs": ["MLflow is ...", "Databricks is ..."],
        "start_time": [start_time, start_time],
    }
    artifact_file = f"test_time.{file_type}"

    with mlflow.start_run() as run:
        # Log the dictionary as a table
        mlflow.log_table(data=table_dict, artifact_file=artifact_file)
        run_id = run.info.run_id

    artifact_path = mlflow.artifacts.download_artifacts(run_id=run_id, artifact_path=artifact_file)
    if file_type == "parquet":
        table_data = pd.read_parquet(artifact_path)
    else:
        table_data = pd.read_json(artifact_path, orient="split", convert_dates=False)
    assert table_data["start_time"][0] == start_time

    # append the same table to the same artifact file
    mlflow.log_table(data=table_dict, artifact_file=artifact_file, run_id=run_id)
    artifact_path = mlflow.artifacts.download_artifacts(run_id=run_id, artifact_path=artifact_file)
    if file_type == "parquet":
        df = pd.read_parquet(artifact_path)
    else:
        df = pd.read_json(artifact_path, orient="split", convert_dates=False)
    assert df["start_time"][2] == start_time


@pytest.mark.skipif(
    "MLFLOW_SKINNY" in os.environ,
    reason="Skinny client does not support the np or pandas dependencies",
)
@pytest.mark.parametrize("file_type", ["json", "parquet"])
def test_log_table_with_image_columns(file_type):
    import numpy as np
    from PIL import Image

    image = mlflow.Image([[1, 2, 3]])
    table_dict = {
        "inputs": ["What is MLflow?", "What is Databricks?"],
        "outputs": ["MLflow is ...", "Databricks is ..."],
        "image": [image, image],
    }
    artifact_file = f"test_time.{file_type}"

    with mlflow.start_run() as run:
        # Log the dictionary as a table
        mlflow.log_table(data=table_dict, artifact_file=artifact_file)
        run_id = run.info.run_id

    artifact_path = mlflow.artifacts.download_artifacts(run_id=run_id, artifact_path=artifact_file)
    table_data = read_data(artifact_path)
    assert table_data["image"][0]["type"] == "image"
    image_path = mlflow.artifacts.download_artifacts(
        run_id=run_id, artifact_path=table_data["image"][0]["filepath"]
    )
    image2 = Image.open(image_path)
    assert np.abs(image.to_array() - np.array(image2)).sum() == 0

    # append the same table to the same artifact file
    mlflow.log_table(data=table_dict, artifact_file=artifact_file, run_id=run_id)
    artifact_path = mlflow.artifacts.download_artifacts(run_id=run_id, artifact_path=artifact_file)
    df = read_data(artifact_path)
    assert df["image"][2]["type"] == "image"


@pytest.mark.skipif(
    "MLFLOW_SKINNY" in os.environ,
    reason="Skinny client does not support the np or pandas dependencies",
)
@pytest.mark.parametrize("file_type", ["json", "parquet"])
def test_log_table_with_pil_image_columns(file_type):
    import numpy as np
    from PIL import Image

    image = Image.fromarray(np.array([[1.0, 2.0, 3.0]]))
    image = image.convert("RGB")

    table_dict = {
        "inputs": ["What is MLflow?", "What is Databricks?"],
        "outputs": ["MLflow is ...", "Databricks is ..."],
        "image": [image, image],
    }
    artifact_file = f"test_time.{file_type}"

    with mlflow.start_run() as run:
        # Log the dictionary as a table
        mlflow.log_table(data=table_dict, artifact_file=artifact_file)
        run_id = run.info.run_id

    artifact_path = mlflow.artifacts.download_artifacts(run_id=run_id, artifact_path=artifact_file)
    table_data = read_data(artifact_path)
    assert table_data["image"][0]["type"] == "image"
    image_path = mlflow.artifacts.download_artifacts(
        run_id=run_id, artifact_path=table_data["image"][0]["filepath"]
    )
    image2 = Image.open(image_path)
    assert np.abs(np.array(image) - np.array(image2)).sum() == 0

    # append the same table to the same artifact file
    mlflow.log_table(data=table_dict, artifact_file=artifact_file, run_id=run_id)
    artifact_path = mlflow.artifacts.download_artifacts(run_id=run_id, artifact_path=artifact_file)
    df = read_data(artifact_path)
    assert df["image"][2]["type"] == "image"


@pytest.mark.skipif(
    "MLFLOW_SKINNY" in os.environ,
    reason="Skinny client does not support the np or pandas dependencies",
)
@pytest.mark.parametrize("file_type", ["json", "parquet"])
def test_log_table_with_invalid_image_columns(file_type):
    image = mlflow.Image([[1, 2, 3]])
    table_dict = {
        "inputs": ["What is MLflow?", "What is Databricks?"],
        "outputs": ["MLflow is ...", "Databricks is ..."],
        "image": [image, "text"],
    }
    artifact_file = f"test_time.{file_type}"
    with pytest.raises(ValueError, match="Column `image` contains a mix of images and non-images"):
        with mlflow.start_run():
            # Log the dictionary as a table
            mlflow.log_table(data=table_dict, artifact_file=artifact_file)


@pytest.mark.skipif(
    "MLFLOW_SKINNY" in os.environ,
    reason="Skinny client does not support the np or pandas dependencies",
)
@pytest.mark.parametrize("file_type", ["json", "parquet"])
def test_log_table_with_valid_image_columns(file_type):
    class ImageObj:
        def __init__(self):
            self.size = (1, 1)

        def resize(self, size):
            return self

        def save(self, path):
            with open(path, "w+") as f:
                f.write("dummy data")

    image_obj = ImageObj()
    image = mlflow.Image([[1, 2, 3]])

    table_dict = {
        "inputs": ["What is MLflow?", "What is Databricks?"],
        "outputs": ["MLflow is ...", "Databricks is ..."],
        "image": [image, image_obj],
    }
    # No error should be raised
    artifact_file = f"test_time.{file_type}"
    with mlflow.start_run():
        # Log the dictionary as a table
        mlflow.log_table(data=table_dict, artifact_file=artifact_file)


def test_set_async_logging_threadpool_size():
    MLFLOW_ASYNC_LOGGING_THREADPOOL_SIZE.set(6)
    assert MLFLOW_ASYNC_LOGGING_THREADPOOL_SIZE.get() == 6

    with mlflow.start_run():
        mlflow.log_param("key", "val", synchronous=False)

    store = mlflow.tracking._get_store()
    async_queue = store._async_logging_queue
    assert async_queue._batch_logging_worker_threadpool._max_workers == 6
    mlflow.flush_async_logging()
    MLFLOW_ASYNC_LOGGING_THREADPOOL_SIZE.unset()
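

# Hedged companion sketch: the same knob can be exercised through the raw environment
# variable with pytest's monkeypatch fixture, which restores the prior value
# automatically; it uses only the .name/.get accessors already exercised above.
def _sketch_async_threadpool_via_env(monkeypatch):
    monkeypatch.setenv(MLFLOW_ASYNC_LOGGING_THREADPOOL_SIZE.name, "6")
    assert MLFLOW_ASYNC_LOGGING_THREADPOOL_SIZE.get() == 6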