/ tests / genai / datasets / test_fluent.py
test_fluent.py
   1  import json
   2  import os
   3  import sys
   4  import warnings
   5  from unittest import mock
   6  
   7  import pandas as pd
   8  import pytest
   9  
  10  import mlflow
  11  from mlflow.data import Dataset
  12  from mlflow.data.pyfunc_dataset_mixin import PyFuncConvertibleDatasetMixin
  13  from mlflow.entities.dataset_record_source import DatasetRecordSourceType
  14  from mlflow.entities.evaluation_dataset import (
  15      EvaluationDataset as EntityEvaluationDataset,
  16  )
  17  from mlflow.exceptions import MlflowException
  18  from mlflow.genai.datasets import (
  19      EvaluationDataset,
  20      create_dataset,
  21      delete_dataset,
  22      delete_dataset_tag,
  23      get_dataset,
  24      search_datasets,
  25      set_dataset_tags,
  26  )
  27  from mlflow.genai.datasets.evaluation_dataset import (
  28      EvaluationDataset as WrapperEvaluationDataset,
  29  )
  30  from mlflow.store.entities.paged_list import PagedList
  31  from mlflow.store.tracking import SEARCH_EVALUATION_DATASETS_MAX_RESULTS
  32  from mlflow.tracking import MlflowClient
  33  from mlflow.utils.mlflow_tags import MLFLOW_USER
  34  
  35  
  36  @pytest.fixture
  37  def mock_client():
  38      with (
  39          mock.patch("mlflow.tracking.client.MlflowClient") as mock_client_class,
  40          mock.patch("mlflow.genai.datasets.MlflowClient", mock_client_class),
  41      ):
  42          mock_client_instance = mock_client_class.return_value
  43          yield mock_client_instance
  44  
  45  
  46  @pytest.fixture
  47  def mock_databricks_environment():
  48      with mock.patch("mlflow.genai.datasets.is_databricks_uri", return_value=True):
  49          yield
  50  
  51  
  52  @pytest.fixture
  53  def client(db_uri):
  54      original_tracking_uri = mlflow.get_tracking_uri()
  55      mlflow.set_tracking_uri(db_uri)
  56      yield MlflowClient(tracking_uri=db_uri)
  57      mlflow.set_tracking_uri(original_tracking_uri)
  58  
  59  
  60  @pytest.fixture
  61  def experiments(client):
  62      exp1 = client.create_experiment("test_exp_1")
  63      exp2 = client.create_experiment("test_exp_2")
  64      exp3 = client.create_experiment("test_exp_3")
  65      return [exp1, exp2, exp3]
  66  
  67  
  68  @pytest.fixture
  69  def experiment(client):
  70      return client.create_experiment("test_trace_experiment")
  71  
  72  
  73  def test_create_dataset(mock_client):
  74      expected_dataset = EntityEvaluationDataset(
  75          dataset_id="test_id",
  76          name="test_dataset",
  77          digest="abc123",
  78          created_time=123456789,
  79          last_update_time=123456789,
  80          tags={"environment": "production", "version": "1.0"},
  81      )
  82      mock_client.create_dataset.return_value = expected_dataset
  83  
  84      result = create_dataset(
  85          name="test_dataset",
  86          experiment_id=["exp1", "exp2"],
  87          tags={"environment": "production", "version": "1.0"},
  88      )
  89  
  90      assert result == expected_dataset
  91      mock_client.create_dataset.assert_called_once_with(
  92          name="test_dataset",
  93          experiment_id=["exp1", "exp2"],
  94          tags={"environment": "production", "version": "1.0"},
  95      )
  96  
  97  
  98  def test_create_dataset_single_experiment_id(mock_client):
  99      expected_dataset = EntityEvaluationDataset(
 100          dataset_id="test_id",
 101          name="test_dataset",
 102          digest="abc123",
 103          created_time=123456789,
 104          last_update_time=123456789,
 105      )
 106      mock_client.create_dataset.return_value = expected_dataset
 107  
 108      result = create_dataset(
 109          name="test_dataset",
 110          experiment_id="exp1",
 111      )
 112  
 113      assert result == expected_dataset
 114      mock_client.create_dataset.assert_called_once_with(
 115          name="test_dataset",
 116          experiment_id=["exp1"],
 117          tags=None,
 118      )
 119  
 120  
 121  def test_create_dataset_with_empty_tags(mock_client):
 122      expected_dataset = EntityEvaluationDataset(
 123          dataset_id="test_id",
 124          name="test_dataset",
 125          digest="abc123",
 126          created_time=123456789,
 127          last_update_time=123456789,
 128          tags={},
 129      )
 130      mock_client.create_dataset.return_value = expected_dataset
 131  
 132      result = create_dataset(
 133          name="test_dataset",
 134          experiment_id=["exp1"],
 135          tags={},
 136      )
 137  
 138      assert result == expected_dataset
 139      mock_client.create_dataset.assert_called_once_with(
 140          name="test_dataset",
 141          experiment_id=["exp1"],
 142          tags={},
 143      )
 144  
 145  
 146  def test_create_dataset_databricks(mock_databricks_environment):
 147      mock_dataset = mock.Mock()
 148      with mock.patch.dict(
 149          "sys.modules",
 150          {
 151              "databricks.agents.datasets": mock.Mock(
 152                  create_dataset=mock.Mock(return_value=mock_dataset)
 153              )
 154          },
 155      ):
 156          result = create_dataset(
 157              name="catalog.schema.table",
 158              experiment_id=["exp1", "exp2"],
 159          )
 160  
 161          sys.modules["databricks.agents.datasets"].create_dataset.assert_called_once_with(
 162              "catalog.schema.table", ["exp1", "exp2"]
 163          )
 164          assert isinstance(result, EvaluationDataset)
 165  
 166  
 167  def test_get_dataset(mock_client):
 168      expected_dataset = EntityEvaluationDataset(
 169          dataset_id="test_id",
 170          name="test_dataset",
 171          digest="abc123",
 172          created_time=123456789,
 173          last_update_time=123456789,
 174      )
 175      mock_client.get_dataset.return_value = expected_dataset
 176  
 177      result = get_dataset(dataset_id="test_id")
 178  
 179      assert result == expected_dataset
 180      mock_client.get_dataset.assert_called_once_with("test_id")
 181  
 182  
 183  def test_get_dataset_missing_id():
 184      with pytest.raises(ValueError, match="Either 'name' or 'dataset_id' must be provided"):
 185          get_dataset()
 186  
 187  
 188  def test_get_dataset_databricks(mock_databricks_environment):
 189      mock_dataset = mock.Mock()
 190      with mock.patch.dict(
 191          "sys.modules",
 192          {"databricks.agents.datasets": mock.Mock(get_dataset=mock.Mock(return_value=mock_dataset))},
 193      ):
 194          result = get_dataset(name="catalog.schema.table")
 195  
 196          sys.modules["databricks.agents.datasets"].get_dataset.assert_called_once_with(
 197              "catalog.schema.table"
 198          )
 199          assert isinstance(result, EvaluationDataset)
 200  
 201  
 202  def test_get_dataset_databricks_missing_name(mock_databricks_environment):
 203      with pytest.raises(ValueError, match="Parameter 'name' is required"):
 204          get_dataset(dataset_id="test_id")
 205  
 206  
 207  def test_get_dataset_by_name_oss(experiments):
 208      dataset = create_dataset(
 209          name="unique_dataset_name",
 210          experiment_id=experiments[0],
 211          tags={"test": "get_by_name"},
 212      )
 213  
 214      retrieved = get_dataset(name="unique_dataset_name")
 215  
 216      assert retrieved.dataset_id == dataset.dataset_id
 217      assert retrieved.name == "unique_dataset_name"
 218      assert retrieved.tags["test"] == "get_by_name"
 219  
 220  
 221  def test_get_dataset_by_name_not_found(client):
 222      with pytest.raises(MlflowException, match="Dataset with name 'nonexistent_dataset' not found"):
 223          get_dataset(name="nonexistent_dataset")
 224  
 225  
 226  def test_get_dataset_by_name_multiple_matches(experiments):
 227      create_dataset(
 228          name="duplicate_name",
 229          experiment_id=experiments[0],
 230      )
 231      create_dataset(
 232          name="duplicate_name",
 233          experiment_id=experiments[1],
 234      )
 235  
 236      with pytest.raises(MlflowException, match="Multiple datasets found with name 'duplicate_name'"):
 237          get_dataset(name="duplicate_name")
 238  
 239  
 240  def test_get_dataset_both_name_and_id_error(experiments):
 241      dataset = create_dataset(
 242          name="test_dataset_both",
 243          experiment_id=experiments[0],
 244      )
 245  
 246      with pytest.raises(ValueError, match="Cannot specify both 'name' and 'dataset_id'"):
 247          get_dataset(name="test_dataset_both", dataset_id=dataset.dataset_id)
 248  
 249  
 250  def test_get_dataset_neither_name_nor_id_error(client):
 251      with pytest.raises(ValueError, match="Either 'name' or 'dataset_id' must be provided"):
 252          get_dataset()
 253  
 254  
 255  @pytest.mark.parametrize(
 256      "name",
 257      [
 258          "dataset's_with_single_quote",
 259          'dataset"with_double_quote',
 260      ],
 261  )
 262  def test_get_dataset_name_with_quotes(experiments, name):
 263      dataset = create_dataset(name=name, experiment_id=experiments[0])
 264  
 265      retrieved = get_dataset(name=name)
 266  
 267      assert retrieved.dataset_id == dataset.dataset_id
 268      assert retrieved.name == name
 269  
 270  
 271  def test_delete_dataset(mock_client):
 272      delete_dataset(dataset_id="test_id")
 273  
 274      mock_client.delete_dataset.assert_called_once_with("test_id")
 275  
 276  
 277  def test_delete_dataset_missing_id():
 278      with pytest.raises(ValueError, match="Parameter 'dataset_id' is required"):
 279          delete_dataset()
 280  
 281  
 282  def test_delete_dataset_databricks(mock_databricks_environment):
 283      with mock.patch.dict(
 284          "sys.modules",
 285          {"databricks.agents.datasets": mock.Mock(delete_dataset=mock.Mock())},
 286      ):
 287          delete_dataset(name="catalog.schema.table")
 288  
 289          sys.modules["databricks.agents.datasets"].delete_dataset.assert_called_once_with(
 290              "catalog.schema.table"
 291          )
 292  
 293  
 294  def test_search_datasets_with_mock(mock_client):
 295      datasets = [
 296          EntityEvaluationDataset(
 297              dataset_id="id1",
 298              name="dataset1",
 299              digest="digest1",
 300              created_time=123456789,
 301              last_update_time=123456789,
 302          ),
 303          EntityEvaluationDataset(
 304              dataset_id="id2",
 305              name="dataset2",
 306              digest="digest2",
 307              created_time=123456789,
 308              last_update_time=123456789,
 309          ),
 310      ]
 311      mock_client.search_datasets.return_value = PagedList(datasets, None)
 312  
 313      result = search_datasets(
 314          experiment_ids=["exp1", "exp2"],
 315          filter_string="name LIKE 'test%'",
 316          max_results=100,
 317          order_by=["created_time DESC"],
 318      )
 319  
 320      assert len(result) == 2
 321      assert isinstance(result, list)
 322  
 323      mock_client.search_datasets.assert_called_once_with(
 324          experiment_ids=["exp1", "exp2"],
 325          filter_string="name LIKE 'test%'",
 326          max_results=50,
 327          order_by=["created_time DESC"],
 328          page_token=None,
 329      )
 330  
 331  
 332  def test_search_datasets_single_experiment_id(mock_client):
 333      datasets = [
 334          EntityEvaluationDataset(
 335              dataset_id="id1",
 336              name="dataset1",
 337              digest="digest1",
 338              created_time=123456789,
 339              last_update_time=123456789,
 340          )
 341      ]
 342      mock_client.search_datasets.return_value = PagedList(datasets, None)
 343  
 344      # When no max_results is specified, it defaults to None which means get all
 345      # Mock time to have a consistent filter_string
 346      with mock.patch("time.time", return_value=1234567890):
 347          search_datasets(experiment_ids="exp1")
 348  
 349      # The pagination wrapper will use SEARCH_EVALUATION_DATASETS_MAX_RESULTS as the page size
 350      # Now the function adds default filter (last 7 days) and order_by when not specified
 351      seven_days_ago = int((1234567890 - 7 * 24 * 60 * 60) * 1000)
 352      mock_client.search_datasets.assert_called_once_with(
 353          experiment_ids=["exp1"],
 354          filter_string=f"created_time >= {seven_days_ago}",
 355          max_results=SEARCH_EVALUATION_DATASETS_MAX_RESULTS,  # Page size
 356          order_by=["created_time DESC"],
 357          page_token=None,
 358      )
 359  
 360  
 361  def test_search_datasets_pagination_handling(mock_client):
 362      page1_datasets = [
 363          EntityEvaluationDataset(
 364              dataset_id=f"id{i}",
 365              name=f"dataset{i}",
 366              digest=f"digest{i}",
 367              created_time=123456789,
 368              last_update_time=123456789,
 369          )
 370          for i in range(3)
 371      ]
 372  
 373      page2_datasets = [
 374          EntityEvaluationDataset(
 375              dataset_id=f"id{i}",
 376              name=f"dataset{i}",
 377              digest=f"digest{i}",
 378              created_time=123456789,
 379              last_update_time=123456789,
 380          )
 381          for i in range(3, 5)
 382      ]
 383  
 384      mock_client.search_datasets.side_effect = [
 385          PagedList(page1_datasets, "token1"),
 386          PagedList(page2_datasets, None),
 387      ]
 388  
 389      result = search_datasets(experiment_ids=["exp1"], max_results=10)
 390  
 391      assert len(result) == 5
 392      assert isinstance(result, list)
 393  
 394      assert mock_client.search_datasets.call_count == 2
 395  
 396      first_call = mock_client.search_datasets.call_args_list[0]
 397      assert first_call[1]["page_token"] is None
 398  
 399      second_call = mock_client.search_datasets.call_args_list[1]
 400      assert second_call[1]["page_token"] == "token1"
 401  
 402  
 403  def test_search_datasets_single_page(mock_client):
 404      datasets = [
 405          EntityEvaluationDataset(
 406              dataset_id="id1",
 407              name="dataset1",
 408              digest="digest1",
 409              created_time=123456789,
 410              last_update_time=123456789,
 411          )
 412      ]
 413  
 414      mock_client.search_datasets.return_value = PagedList(datasets, None)
 415  
 416      result = search_datasets(max_results=10)
 417  
 418      assert len(result) == 1
 419      assert isinstance(result, list)
 420  
 421      assert mock_client.search_datasets.call_count == 1
 422  
 423  
 424  def test_search_datasets_databricks(mock_databricks_environment, mock_client):
 425      datasets = [
 426          EntityEvaluationDataset(
 427              dataset_id="id1",
 428              name="dataset1",
 429              digest="digest1",
 430              created_time=123456789,
 431              last_update_time=123456789,
 432          ),
 433      ]
 434      mock_client.search_datasets.return_value = PagedList(datasets, None)
 435  
 436      result = search_datasets(experiment_ids=["exp1"])
 437  
 438      assert len(result) == 1
 439      assert isinstance(result, list)
 440  
 441      # Verify that default filter_string and order_by are NOT set for Databricks
 442      # (since these parameters may not be supported by all Databricks backends)
 443      mock_client.search_datasets.assert_called_once()
 444      call_kwargs = mock_client.search_datasets.call_args.kwargs
 445      assert call_kwargs.get("filter_string") is None
 446      assert call_kwargs.get("order_by") is None
 447  
 448  
 449  def test_databricks_import_error():
 450      with (
 451          mock.patch("mlflow.genai.datasets.is_databricks_uri", return_value=True),
 452          mock.patch.dict("sys.modules", {"databricks.agents.datasets": None}),
 453          mock.patch("builtins.__import__", side_effect=ImportError("No module")),
 454      ):
 455          with pytest.raises(ImportError, match="databricks-agents"):
 456              create_dataset(name="test", experiment_id="exp1")
 457  
 458  
 459  def test_databricks_profile_uri_support():
 460      mock_dataset = mock.Mock()
 461      with (
 462          mock.patch(
 463              "mlflow.genai.datasets.get_tracking_uri",
 464              return_value="databricks://profilename",
 465          ),
 466          mock.patch.dict(
 467              "sys.modules",
 468              {
 469                  "databricks.agents.datasets": mock.Mock(
 470                      get_dataset=mock.Mock(return_value=mock_dataset),
 471                      create_dataset=mock.Mock(return_value=mock_dataset),
 472                      delete_dataset=mock.Mock(),
 473                  )
 474              },
 475          ),
 476      ):
 477          result = get_dataset(name="catalog.schema.table")
 478          sys.modules["databricks.agents.datasets"].get_dataset.assert_called_once_with(
 479              "catalog.schema.table"
 480          )
 481          assert isinstance(result, EvaluationDataset)
 482  
 483          result2 = create_dataset(name="catalog.schema.table2", experiment_id=["exp1"])
 484          sys.modules["databricks.agents.datasets"].create_dataset.assert_called_once_with(
 485              "catalog.schema.table2", ["exp1"]
 486          )
 487          assert isinstance(result2, EvaluationDataset)
 488  
 489          delete_dataset(name="catalog.schema.table3")
 490          sys.modules["databricks.agents.datasets"].delete_dataset.assert_called_once_with(
 491              "catalog.schema.table3"
 492          )
 493  
 494  
 495  def test_databricks_profile_env_var_set_from_uri(monkeypatch):
 496      mock_dataset = mock.Mock()
 497      profile_values_during_calls = []
 498  
 499      def mock_get_dataset(name):
 500          profile_values_during_calls.append((
 501              "get_dataset",
 502              os.environ.get("DATABRICKS_CONFIG_PROFILE"),
 503          ))
 504          return mock_dataset
 505  
 506      def mock_create_dataset(name, experiment_ids):
 507          profile_values_during_calls.append((
 508              "create_dataset",
 509              os.environ.get("DATABRICKS_CONFIG_PROFILE"),
 510          ))
 511          return mock_dataset
 512  
 513      def mock_delete_dataset(name):
 514          profile_values_during_calls.append((
 515              "delete_dataset",
 516              os.environ.get("DATABRICKS_CONFIG_PROFILE"),
 517          ))
 518  
 519      mock_agents_module = mock.Mock(
 520          get_dataset=mock_get_dataset,
 521          create_dataset=mock_create_dataset,
 522          delete_dataset=mock_delete_dataset,
 523      )
 524      monkeypatch.setitem(sys.modules, "databricks.agents.datasets", mock_agents_module)
 525      monkeypatch.setattr("mlflow.genai.datasets.get_tracking_uri", lambda: "databricks://myprofile")
 526  
 527      assert "DATABRICKS_CONFIG_PROFILE" not in os.environ
 528  
 529      get_dataset(name="catalog.schema.table")
 530      create_dataset(name="catalog.schema.table", experiment_id="exp1")
 531      delete_dataset(name="catalog.schema.table")
 532  
 533      assert "DATABRICKS_CONFIG_PROFILE" not in os.environ
 534  
 535      assert profile_values_during_calls == [
 536          ("get_dataset", "myprofile"),
 537          ("create_dataset", "myprofile"),
 538          ("delete_dataset", "myprofile"),
 539      ]
 540  
 541  
 542  def test_databricks_profile_env_var_overridden_and_restored(monkeypatch):
 543      mock_dataset = mock.Mock()
 544      profile_during_call = None
 545  
 546      def mock_get_dataset(name):
 547          nonlocal profile_during_call
 548          profile_during_call = os.environ.get("DATABRICKS_CONFIG_PROFILE")
 549          return mock_dataset
 550  
 551      mock_agents_module = mock.Mock(get_dataset=mock_get_dataset)
 552      monkeypatch.setitem(sys.modules, "databricks.agents.datasets", mock_agents_module)
 553      monkeypatch.setattr("mlflow.genai.datasets.get_tracking_uri", lambda: "databricks://myprofile")
 554      monkeypatch.setenv("DATABRICKS_CONFIG_PROFILE", "original_profile")
 555  
 556      assert os.environ.get("DATABRICKS_CONFIG_PROFILE") == "original_profile"
 557  
 558      get_dataset(name="catalog.schema.table")
 559  
 560      assert os.environ.get("DATABRICKS_CONFIG_PROFILE") == "original_profile"
 561      assert profile_during_call == "myprofile"
 562  
 563  
 564  def test_databricks_dataset_merge_records_uses_profile(monkeypatch):
 565      profile_during_merge = None
 566      profile_during_to_df = None
 567  
 568      mock_inner_dataset = mock.Mock()
 569      mock_inner_dataset.digest = "test_digest"
 570      mock_inner_dataset.name = "catalog.schema.table"
 571      mock_inner_dataset.dataset_id = "dataset-123"
 572  
 573      def mock_merge_records(records):
 574          nonlocal profile_during_merge
 575          profile_during_merge = os.environ.get("DATABRICKS_CONFIG_PROFILE")
 576          return mock_inner_dataset
 577  
 578      def mock_to_df():
 579          nonlocal profile_during_to_df
 580          profile_during_to_df = os.environ.get("DATABRICKS_CONFIG_PROFILE")
 581          import pandas as pd
 582  
 583          return pd.DataFrame({"test": [1, 2, 3]})
 584  
 585      mock_inner_dataset.merge_records = mock_merge_records
 586      mock_inner_dataset.to_df = mock_to_df
 587  
 588      def mock_get_dataset(name):
 589          return mock_inner_dataset
 590  
 591      mock_agents_module = mock.Mock(get_dataset=mock_get_dataset)
 592      monkeypatch.setitem(sys.modules, "databricks.agents.datasets", mock_agents_module)
 593      monkeypatch.setattr("mlflow.genai.datasets.get_tracking_uri", lambda: "databricks://myprofile")
 594  
 595      assert "DATABRICKS_CONFIG_PROFILE" not in os.environ
 596  
 597      dataset = get_dataset(name="catalog.schema.table")
 598  
 599      assert "DATABRICKS_CONFIG_PROFILE" not in os.environ
 600  
 601      dataset.merge_records([{"inputs": {"q": "test"}}])
 602      assert profile_during_merge == "myprofile"
 603      assert "DATABRICKS_CONFIG_PROFILE" not in os.environ
 604  
 605      dataset.to_df()
 606      assert profile_during_to_df == "myprofile"
 607      assert "DATABRICKS_CONFIG_PROFILE" not in os.environ
 608  
 609  
 610  def test_create_dataset_with_user_tag(experiments):
 611      dataset = create_dataset(
 612          name="test_user_attribution",
 613          experiment_id=experiments[0],
 614          tags={"environment": "test", MLFLOW_USER: "john_doe"},
 615      )
 616  
 617      assert dataset.name == "test_user_attribution"
 618      assert dataset.tags[MLFLOW_USER] == "john_doe"
 619      assert dataset.created_by == "john_doe"
 620  
 621      dataset2 = create_dataset(
 622          name="test_no_user",
 623          experiment_id=experiments[0],
 624          tags={"environment": "test"},
 625      )
 626  
 627      assert dataset2.name == "test_no_user"
 628      assert isinstance(dataset2.tags[MLFLOW_USER], str)
 629      assert dataset2.created_by == dataset2.tags[MLFLOW_USER]
 630  
 631  
 632  def test_create_and_get_dataset(experiments):
 633      dataset = create_dataset(
 634          name="qa_evaluation_v1",
 635          experiment_id=[experiments[0], experiments[1]],
 636          tags={"source": "manual_curation", "environment": "test"},
 637      )
 638  
 639      assert dataset.name == "qa_evaluation_v1"
 640      assert dataset.tags["source"] == "manual_curation"
 641      assert dataset.tags["environment"] == "test"
 642      assert len(dataset.experiment_ids) == 2
 643      assert dataset.dataset_id is not None
 644  
 645      retrieved = get_dataset(dataset_id=dataset.dataset_id)
 646  
 647      assert retrieved.dataset_id == dataset.dataset_id
 648      assert retrieved.name == dataset.name
 649      assert retrieved.tags == dataset.tags
 650      assert set(retrieved.experiment_ids) == {experiments[0], experiments[1]}
 651  
 652  
 653  def test_create_dataset_minimal_params(client):
 654      dataset = create_dataset(name="minimal_dataset")
 655  
 656      assert dataset.name == "minimal_dataset"
 657      assert "mlflow.user" not in dataset.tags or isinstance(dataset.tags.get("mlflow.user"), str)
 658      assert dataset.experiment_ids == ["0"]
 659  
 660  
 661  def test_active_record_pattern_merge_records(experiments):
 662      dataset = create_dataset(
 663          name="active_record_test",
 664          experiment_id=experiments[0],
 665      )
 666  
 667      records_batch1 = [
 668          {
 669              "inputs": {"question": "What is MLflow?"},
 670              "outputs": {
 671                  "answer": "MLflow is an open source platform for managing the ML lifecycle",
 672                  "key1": "value1",
 673              },
 674              "expectations": {
 675                  "answer": "MLflow is an open source platform",
 676                  "key2": "value2",
 677              },
 678              "tags": {"difficulty": "easy"},
 679          },
 680          {
 681              "inputs": {"question": "What is Python?"},
 682              "outputs": {"answer": "Python is a versatile programming language"},
 683              "expectations": {"answer": "Python is a programming language"},
 684              "tags": {"difficulty": "easy"},
 685          },
 686      ]
 687  
 688      records_batch2 = [
 689          {
 690              "inputs": {"question": "What is MLflow?"},
 691              "outputs": {"answer": "MLflow is a popular ML lifecycle platform"},
 692              "expectations": {"answer": "MLflow is an ML lifecycle platform"},
 693              "tags": {"category": "ml"},
 694          },
 695          {
 696              "inputs": {"question": "What is Docker?"},
 697              "outputs": {"answer": "Docker is a popular containerization platform"},
 698              "expectations": {"answer": "Docker is a containerization platform"},
 699              "tags": {"difficulty": "medium"},
 700          },
 701      ]
 702  
 703      dataset.merge_records(records_batch1)
 704  
 705      df1 = dataset.to_df()
 706      assert len(df1) == 2
 707  
 708      mlflow_record = df1[df1["inputs"].apply(lambda x: x.get("question") == "What is MLflow?")].iloc[
 709          0
 710      ]
 711      assert mlflow_record["expectations"] == {
 712          "answer": "MLflow is an open source platform",
 713          "key2": "value2",
 714      }
 715      assert mlflow_record["outputs"] == {
 716          "answer": "MLflow is an open source platform for managing the ML lifecycle",
 717          "key1": "value1",
 718      }
 719      assert mlflow_record["tags"]["difficulty"] == "easy"
 720      assert "category" not in mlflow_record["tags"]
 721  
 722      dataset.merge_records(records_batch2)
 723  
 724      df2 = dataset.to_df()
 725      assert len(df2) == 3
 726  
 727      mlflow_record_updated = df2[
 728          df2["inputs"].apply(lambda x: x.get("question") == "What is MLflow?")
 729      ].iloc[0]
 730  
 731      assert mlflow_record_updated["expectations"] == {
 732          "answer": "MLflow is an ML lifecycle platform",
 733          "key2": "value2",
 734      }
 735      assert mlflow_record_updated["outputs"] == {
 736          "answer": "MLflow is a popular ML lifecycle platform"
 737      }
 738      assert mlflow_record_updated["tags"]["difficulty"] == "easy"
 739      assert mlflow_record_updated["tags"]["category"] == "ml"
 740  
 741      # Verify that the new Docker record also has outputs
 742      docker_record = df2[df2["inputs"].apply(lambda x: x.get("question") == "What is Docker?")].iloc[
 743          0
 744      ]
 745      assert docker_record["outputs"]["answer"] == "Docker is a popular containerization platform"
 746      assert docker_record["expectations"]["answer"] == "Docker is a containerization platform"
 747      assert docker_record["tags"]["difficulty"] == "medium"
 748  
 749  
 750  def test_dataset_with_dataframe_records(experiments):
 751      dataset = create_dataset(
 752          name="dataframe_test",
 753          experiment_id=experiments[0],
 754          tags={"source": "csv", "file": "test_data.csv"},
 755      )
 756  
 757      df = pd.DataFrame([
 758          {
 759              "inputs": {"text": "The movie was amazing!", "model": "sentiment-v1"},
 760              "expectations": {"sentiment": "positive", "confidence": 0.95},
 761              "tags": {"source": "imdb"},
 762          },
 763          {
 764              "inputs": {"text": "Terrible experience", "model": "sentiment-v1"},
 765              "expectations": {"sentiment": "negative", "confidence": 0.88},
 766              "tags": {"source": "yelp"},
 767          },
 768      ])
 769  
 770      dataset.merge_records(df)
 771  
 772      result_df = dataset.to_df()
 773      assert len(result_df) == 2
 774      assert all(col in result_df.columns for col in ["inputs", "expectations", "tags"])
 775  
 776      # Check that all expected records are present (order-agnostic)
 777      texts = {record["inputs"]["text"] for _, record in result_df.iterrows()}
 778      expected_texts = {"The movie was amazing!", "Terrible experience"}
 779      assert texts == expected_texts
 780  
 781      sentiments = {record["expectations"]["sentiment"] for _, record in result_df.iterrows()}
 782      expected_sentiments = {"positive", "negative"}
 783      assert sentiments == expected_sentiments
 784  
 785  
 786  def test_search_datasets(experiments):
 787      for i in range(5):
 788          create_dataset(
 789              name=f"search_test_{i}",
 790              experiment_id=[experiments[i % len(experiments)]],
 791              tags={"type": "human" if i % 2 == 0 else "trace", "index": str(i)},
 792          )
 793  
 794      all_results = search_datasets()
 795      assert len(all_results) == 5
 796  
 797      exp0_results = search_datasets(experiment_ids=experiments[0])
 798      assert len(exp0_results) == 2
 799  
 800      human_results = search_datasets(filter_string="name LIKE 'search_test_%'")
 801      assert len(human_results) == 5
 802  
 803      limited_results = search_datasets(max_results=2)
 804      assert len(limited_results) == 2
 805  
 806      more_results = search_datasets(max_results=4)
 807      assert len(more_results) == 4
 808  
 809  
 810  def test_delete_dataset(experiments):
 811      dataset = create_dataset(
 812          name="to_be_deleted",
 813          experiment_id=[experiments[0], experiments[1]],
 814          tags={"env": "test", "version": "1.0"},
 815      )
 816      dataset_id = dataset.dataset_id
 817  
 818      dataset.merge_records([{"inputs": {"q": "test"}, "expectations": {"a": "answer"}}])
 819  
 820      retrieved = get_dataset(dataset_id=dataset_id)
 821      assert retrieved is not None
 822      assert len(retrieved.to_df()) == 1
 823  
 824      delete_dataset(dataset_id=dataset_id)
 825  
 826      with pytest.raises(MlflowException, match="Could not find|not found"):
 827          get_dataset(dataset_id=dataset_id)
 828  
 829      search_results = search_datasets(experiment_ids=[experiments[0], experiments[1]])
 830      found_ids = [d.dataset_id for d in search_results]
 831      assert dataset_id not in found_ids
 832  
 833  
 834  def test_dataset_lifecycle_workflow(experiments):
 835      dataset = create_dataset(
 836          name="qa_eval_prod_v1",
 837          experiment_id=[experiments[0], experiments[1]],
 838          tags={"source": "qa_team_annotations", "team": "qa", "env": "prod"},
 839      )
 840  
 841      initial_cases = [
 842          {
 843              "inputs": {"question": "What is the capital of France?"},
 844              "expectations": {"answer": "Paris", "confidence": "high"},
 845              "tags": {"category": "geography", "difficulty": "easy"},
 846          },
 847          {
 848              "inputs": {"question": "Explain quantum computing"},
 849              "expectations": {"answer": "Quantum computing uses quantum mechanics principles"},
 850              "tags": {"category": "science", "difficulty": "hard"},
 851          },
 852      ]
 853      dataset.merge_records(initial_cases)
 854  
 855      dataset_id = dataset.dataset_id
 856      retrieved = get_dataset(dataset_id=dataset_id)
 857      df = retrieved.to_df()
 858      assert len(df) == 2
 859  
 860      additional_cases = [
 861          {
 862              "inputs": {"question": "What is 2+2?"},
 863              "expectations": {"answer": "4", "confidence": "high"},
 864              "tags": {"category": "math", "difficulty": "easy"},
 865          },
 866      ]
 867      retrieved.merge_records(additional_cases)
 868  
 869      found = search_datasets(
 870          experiment_ids=experiments[0],
 871          filter_string="name LIKE 'qa_eval%'",
 872      )
 873      assert len(found) == 1
 874      assert found[0].dataset_id == dataset_id
 875  
 876      final_dataset = get_dataset(dataset_id=dataset_id)
 877      final_df = final_dataset.to_df()
 878      assert len(final_df) == 3
 879  
 880      categories = set()
 881      for _, row in final_df.iterrows():
 882          if row["tags"] and "category" in row["tags"]:
 883              categories.add(row["tags"]["category"])
 884      assert categories == {"geography", "science", "math"}
 885  
 886  
 887  def test_error_handling_filestore_backend(tmp_path):
 888      pytest.skip("FileStore is no longer supported.")
 889      file_uri = f"file://{tmp_path}"
 890      mlflow.set_tracking_uri(file_uri)
 891  
 892      with pytest.raises(MlflowException, match="not supported with FileStore") as exc:
 893          create_dataset(name="test")
 894      assert exc.value.error_code == "FEATURE_DISABLED"
 895  
 896      with pytest.raises(MlflowException, match="not supported with FileStore") as exc:
 897          get_dataset(dataset_id="test_id")
 898      assert exc.value.error_code == "FEATURE_DISABLED"
 899  
 900      with pytest.raises(MlflowException, match="not supported with FileStore") as exc:
 901          search_datasets()
 902      assert exc.value.error_code == "FEATURE_DISABLED"
 903  
 904      with pytest.raises(MlflowException, match="not supported with FileStore") as exc:
 905          delete_dataset(dataset_id="test_id")
 906      assert exc.value.error_code == "FEATURE_DISABLED"
 907  
 908  
 909  def test_single_experiment_id_handling(experiments):
 910      dataset = create_dataset(
 911          name="single_exp_test",
 912          experiment_id=experiments[0],
 913      )
 914  
 915      assert isinstance(dataset.experiment_ids, list)
 916      assert dataset.experiment_ids == [experiments[0]]
 917  
 918      results = search_datasets(experiment_ids=experiments[0])
 919      found_ids = [d.dataset_id for d in results]
 920      assert dataset.dataset_id in found_ids
 921  
 922  
 923  def test_trace_to_evaluation_dataset_integration(experiments):
 924      trace_inputs = [
 925          {"question": "What is MLflow?", "context": "ML platforms"},
 926          {"question": "What is Python?", "context": "programming"},
 927          {"question": "What is MLflow?", "context": "ML platforms"},
 928      ]
 929  
 930      created_trace_ids = []
 931      for i, inputs in enumerate(trace_inputs):
 932          with mlflow.start_run(experiment_id=experiments[i % 2]):
 933              with mlflow.start_span(name=f"qa_trace_{i}") as span:
 934                  span.set_inputs(inputs)
 935                  span.set_outputs({"answer": f"Answer for {inputs['question']}"})
 936                  span.set_attributes({"model": "test-model", "temperature": "0.7"})
 937                  trace_id = span.trace_id
 938                  created_trace_ids.append(trace_id)
 939  
 940                  mlflow.log_expectation(
 941                      trace_id=trace_id,
 942                      name="expected_answer",
 943                      value=f"Detailed answer for {inputs['question']}",
 944                  )
 945                  mlflow.log_expectation(
 946                      trace_id=trace_id,
 947                      name="quality_score",
 948                      value=0.85 + i * 0.05,
 949                  )
 950  
 951      traces = mlflow.search_traces(
 952          locations=[experiments[0], experiments[1]],
 953          max_results=10,
 954          return_type="list",
 955      )
 956      assert len(traces) == 3
 957  
 958      dataset = create_dataset(
 959          name="trace_eval_dataset",
 960          experiment_id=[experiments[0], experiments[1]],
 961          tags={"source": "test_traces", "type": "trace_integration"},
 962      )
 963  
 964      dataset.merge_records(traces)
 965  
 966      df = dataset.to_df()
 967      assert len(df) == 2
 968  
 969      for _, record in df.iterrows():
 970          assert "inputs" in record
 971          assert "question" in record["inputs"]
 972          assert "context" in record["inputs"]
 973          assert record.get("source_type") == "TRACE"
 974          assert record.get("source_id") is not None
 975  
 976      mlflow_records = df[df["inputs"].apply(lambda x: x.get("question") == "What is MLflow?")]
 977      assert len(mlflow_records) == 1
 978  
 979      with mlflow.start_run(experiment_id=experiments[0]):
 980          with mlflow.start_span(name="additional_trace") as span:
 981              span.set_inputs({"question": "What is Docker?", "context": "containers"})
 982              span.set_outputs({"answer": "Docker is a containerization platform"})
 983              span.set_attributes({"model": "test-model"})
 984  
 985      all_traces = mlflow.search_traces(
 986          locations=[experiments[0], experiments[1]], max_results=10, return_type="list"
 987      )
 988      assert len(all_traces) == 4
 989  
 990      new_trace = None
 991      for trace in all_traces:
 992          root_span = trace.data._get_root_span() if hasattr(trace, "data") else None
 993          if root_span and root_span.inputs and root_span.inputs.get("question") == "What is Docker?":
 994              new_trace = trace
 995              break
 996  
 997      assert new_trace is not None
 998  
 999      dataset.merge_records([new_trace])
1000  
1001      final_df = dataset.to_df()
1002      assert len(final_df) == 3
1003  
1004      retrieved = get_dataset(dataset_id=dataset.dataset_id)
1005      retrieved_df = retrieved.to_df()
1006      assert len(retrieved_df) == 3
1007  
1008      delete_dataset(dataset_id=dataset.dataset_id)
1009  
1010      with pytest.raises(MlflowException, match="Could not find|not found"):
1011          get_dataset(dataset_id=dataset.dataset_id)
1012  
1013      search_results = search_datasets(
1014          experiment_ids=[experiments[0], experiments[1]], max_results=100
1015      )
1016      found_dataset_ids = [d.dataset_id for d in search_results]
1017      assert dataset.dataset_id not in found_dataset_ids
1018  
1019      all_datasets = search_datasets(max_results=100)
1020      all_dataset_ids = [d.dataset_id for d in all_datasets]
1021      assert dataset.dataset_id not in all_dataset_ids
1022  
1023  
def test_search_traces_dataframe_to_dataset_integration(experiments):
    """The DataFrame returned by ``mlflow.search_traces`` can be merged directly into a dataset."""
    for i in range(3):
        with mlflow.start_run(experiment_id=experiments[0]):
            with mlflow.start_span(name=f"test_span_{i}") as span:
                span.set_inputs({"question": f"Question {i}?", "temperature": 0.7})
                span.set_outputs({"answer": f"Answer {i}"})

                mlflow.log_expectation(
                    trace_id=span.trace_id,
                    name="expected_answer",
                    value=f"Expected answer {i}",
                )
                mlflow.log_expectation(
                    trace_id=span.trace_id,
                    name="min_score",
                    value=0.8,
                )

    traces_df = mlflow.search_traces(
        locations=[experiments[0]],
    )

    assert "trace" in traces_df.columns
    assert "assessments" in traces_df.columns
    assert len(traces_df) == 3

    dataset = create_dataset(
        name="traces_dataframe_dataset",
        experiment_id=experiments[0],
        tags={"source": "search_traces", "format": "dataframe"},
    )

    dataset.merge_records(traces_df)

    result_df = dataset.to_df()
    assert len(result_df) == 3

    seen_question_nums = set()
    for _, row in result_df.iterrows():
        assert "inputs" in row
        assert "expectations" in row
        assert "source_type" in row
        assert row["source_type"] == "TRACE"

        assert "question" in row["inputs"]
        question_text = row["inputs"]["question"]
        assert question_text.startswith("Question ")
        assert question_text.endswith("?")
        question_num = int(question_text.replace("Question ", "").replace("?", ""))
        assert 0 <= question_num <= 2
        seen_question_nums.add(question_num)

        assert "expected_answer" in row["expectations"]
        assert f"Expected answer {question_num}" == row["expectations"]["expected_answer"]
        assert "min_score" in row["expectations"]
        assert row["expectations"]["min_score"] == 0.8

    # Each logged trace must map to its own record: none duplicated, none dropped.
    assert seen_question_nums == {0, 1, 2}
1078  
1079  
def test_trace_to_dataset_with_assessments(client, experiment):
    """Expectations logged on trace spans surface as record expectations after merging."""
    qa_specs = [
        {
            "inputs": {"question": "What is MLflow?", "context": "ML platforms"},
            "outputs": {"answer": "MLflow is an open source platform for ML lifecycle"},
            "expectations": {
                "correctness": True,
                "completeness": 0.8,
            },
        },
        {
            "inputs": {
                "question": "What is Python?",
                "context": "programming languages",
            },
            "outputs": {"answer": "Python is a high-level programming language"},
            "expectations": {
                "correctness": True,
            },
        },
        {
            "inputs": {"question": "What is Docker?", "context": "containerization"},
            "outputs": {"answer": "Docker is a container platform"},
            "expectations": {},
        },
    ]

    fetched_traces = []
    for idx, spec in enumerate(qa_specs):
        with mlflow.start_run(experiment_id=experiment):
            with mlflow.start_span(name=f"qa_trace_{idx}") as span:
                span.set_inputs(spec["inputs"])
                span.set_outputs(spec["outputs"])
                span.set_attributes({"model": "test-model", "temperature": 0.7})
                tid = span.trace_id

                for exp_name, exp_value in spec["expectations"].items():
                    mlflow.log_expectation(
                        trace_id=tid,
                        name=exp_name,
                        value=exp_value,
                        span_id=span.span_id,
                    )

        # Fetch the persisted trace once its run has closed.
        fetched_traces.append(client.get_trace(tid))

    dataset = create_dataset(
        name="trace_assessment_dataset",
        experiment_id=[experiment],
        tags={"source": "trace_integration_test", "version": "1.0"},
    )
    dataset.merge_records(fetched_traces)

    df = dataset.to_df()
    assert len(df) == 3

    def row_for(question):
        # First record whose inputs carry the given question.
        mask = df["inputs"].apply(lambda inputs: inputs.get("question") == question)
        return df[mask].iloc[0]

    mlflow_row = row_for("What is MLflow?")
    assert mlflow_row["inputs"]["question"] == "What is MLflow?"
    assert mlflow_row["inputs"]["context"] == "ML platforms"

    assert "expectations" in mlflow_row
    assert mlflow_row["expectations"]["correctness"] is True
    assert mlflow_row["expectations"]["completeness"] == 0.8

    assert mlflow_row["source_type"] == "TRACE"
    assert mlflow_row["source_id"] is not None

    python_row = row_for("What is Python?")
    assert python_row["expectations"]["correctness"] is True
    assert len(python_row["expectations"]) == 1

    docker_row = row_for("What is Docker?")
    assert docker_row["expectations"] is None or docker_row["expectations"] == {}

    reloaded = get_dataset(dataset_id=dataset.dataset_id)
    assert reloaded.tags["source"] == "trace_integration_test"
    assert reloaded.tags["version"] == "1.0"
    assert set(reloaded.experiment_ids) == {experiment}
1160  
1161  
def test_trace_deduplication_with_assessments(client, experiment):
    """Traces with identical inputs collapse into one record on merge.

    The three traces log quality 0.5, 0.7 and 0.9 in merge order. The surviving
    record is expected to carry the value from the last merged trace (0.9).
    """
    trace_ids = []
    for i in range(3):
        with mlflow.start_run(experiment_id=experiment):
            with mlflow.start_span(name=f"duplicate_trace_{i}") as span:
                span.set_inputs({"question": "What is AI?", "model": "gpt-4"})
                span.set_outputs({"answer": f"AI is artificial intelligence (version {i})"})
                trace_id = span.trace_id
                trace_ids.append(trace_id)

                mlflow.log_expectation(
                    trace_id=trace_id,
                    name="quality",
                    value=0.5 + i * 0.2,
                    span_id=span.span_id,
                )

    traces = [client.get_trace(tid) for tid in trace_ids]

    dataset = create_dataset(
        name="dedup_test",
        experiment_id=experiment,
        tags={"test": "deduplication"},
    )
    dataset.merge_records(traces)

    df = dataset.to_df()
    assert len(df) == 1

    record = df.iloc[0]
    assert record["inputs"]["question"] == "What is AI?"
    # The logged value is computed with float arithmetic, so compare approximately.
    assert record["expectations"]["quality"] == pytest.approx(0.9)
    assert record["source_id"] in trace_ids
1195  
1196  
def test_mixed_record_types_with_traces(client, experiment):
    """Trace-derived records and hand-written dict records can coexist in one dataset."""
    with mlflow.start_run(experiment_id=experiment):
        with mlflow.start_span(name="mixed_test_trace") as span:
            span.set_inputs({"question": "What is ML?", "context": "machine learning"})
            span.set_outputs({"answer": "ML stands for Machine Learning"})
            ml_trace_id = span.trace_id

            mlflow.log_expectation(
                trace_id=ml_trace_id,
                name="accuracy",
                value=0.95,
                span_id=span.span_id,
            )

    ml_trace = client.get_trace(ml_trace_id)

    dataset = create_dataset(
        name="mixed_records_test",
        experiment_id=experiment,
        tags={"type": "mixed", "test": "true"},
    )

    dataset.merge_records(
        [
            {
                "inputs": {"question": "What is AI?"},
                "expectations": {"correctness": True},
                "tags": {"source": "manual"},
            },
            {
                "inputs": {"question": "What is Python?"},
                "expectations": {"correctness": True},
                "tags": {"source": "manual"},
            },
        ]
    )
    assert len(dataset.to_df()) == 2

    # Merging a Trace after the dict records should add one more record.
    dataset.merge_records([ml_trace])

    combined = dataset.to_df()
    assert len(combined) == 3

    questions = combined["inputs"].apply(lambda inputs: inputs.get("question"))

    ml_row = combined[questions == "What is ML?"].iloc[0]
    assert ml_row["expectations"]["accuracy"] == 0.95
    assert ml_row["source_type"] == "TRACE"

    manual_questions = {"What is AI?", "What is Python?"}
    manual_rows = combined[questions.isin(manual_questions)]
    assert len(manual_rows) == 2

    for _, manual_row in manual_rows.iterrows():
        assert manual_row.get("source_type") != "TRACE"
1251  
1252  
def test_trace_without_root_span_inputs(client, experiment):
    """A trace whose root span never had inputs set yields a record with empty inputs."""
    with mlflow.start_run(experiment_id=experiment):
        with mlflow.start_span(name="no_inputs_trace") as span:
            # Only outputs are recorded; inputs are deliberately left unset.
            span.set_outputs({"result": "some output"})
            bare_trace_id = span.trace_id

    bare_trace = client.get_trace(bare_trace_id)

    dataset = create_dataset(
        name="no_inputs_test",
        experiment_id=experiment,
    )
    dataset.merge_records([bare_trace])

    records = dataset.to_df()
    assert len(records) == 1

    only_row = records.iloc[0]
    assert only_row["inputs"] == {}
    assert only_row["expectations"] is None or only_row["expectations"] == {}
1272  
1273  
def test_error_handling_invalid_trace_types(client, experiment):
    """merge_records rejects lists that mix Trace objects with non-Trace values."""
    dataset = create_dataset(
        name="error_test",
        experiment_id=experiment,
    )

    with mlflow.start_run(experiment_id=experiment):
        with mlflow.start_span(name="valid_trace") as span:
            span.set_inputs({"q": "test"})
            tid = span.trace_id

    good_trace = client.get_trace(tid)

    # Pair the valid trace with a dict record, then with an arbitrary string.
    for bad_companion in ({"inputs": {"q": "dict record"}}, "not a trace"):
        with pytest.raises(MlflowException, match="Mixed types in trace list"):
            dataset.merge_records([good_trace, bad_companion])
1292  
1293  
def test_trace_integration_end_to_end(client, experiment):
    """Traces with varied expectation payloads round-trip into a dataset alongside manual rows."""
    scenarios = [
        {
            "name": "successful_qa",
            "inputs": {"question": "What is the capital of France?", "language": "en"},
            "outputs": {"answer": "Paris", "confidence": 0.99},
            "expectations": {"correctness": True, "confidence_threshold": 0.8},
        },
        {
            "name": "incorrect_qa",
            "inputs": {"question": "What is 2+2?", "language": "en"},
            "outputs": {"answer": "5", "confidence": 0.5},
            "expectations": {"correctness": False},
        },
        {
            "name": "multilingual_qa",
            "inputs": {"question": "¿Cómo estás?", "language": "es"},
            "outputs": {"answer": "I'm doing well, thank you!", "confidence": 0.9},
            "expectations": {"language_match": False, "politeness": True},
        },
    ]

    logged_ids = []
    for scenario in scenarios:
        with mlflow.start_run(experiment_id=experiment):
            with mlflow.start_span(name=scenario["name"]) as span:
                span.set_inputs(scenario["inputs"])
                span.set_outputs(scenario["outputs"])
                span.set_attributes({
                    "model": "test-llm-v1",
                    "temperature": 0.7,
                    "max_tokens": 100,
                })
                current_id = span.trace_id
                logged_ids.append(current_id)

                # Each expectation records which scenario produced it.
                for key, expected in scenario["expectations"].items():
                    mlflow.log_expectation(
                        trace_id=current_id,
                        name=key,
                        value=expected,
                        span_id=span.span_id,
                        metadata={"trace_name": scenario["name"]},
                    )

    dataset = create_dataset(
        name="comprehensive_trace_test",
        experiment_id=[experiment],
        tags={
            "test_type": "end_to_end",
            "model": "test-llm-v1",
            "language": "multilingual",
        },
    )

    dataset.merge_records([client.get_trace(tid) for tid in logged_ids])

    df = dataset.to_df()
    assert len(df) == 3

    def first_matching(predicate):
        # First record whose inputs dict satisfies ``predicate``.
        return df[df["inputs"].apply(predicate)].iloc[0]

    french = first_matching(lambda inp: "France" in str(inp.get("question", "")))
    assert french["expectations"]["correctness"] is True
    assert french["expectations"]["confidence_threshold"] == 0.8

    arithmetic = first_matching(lambda inp: "2+2" in str(inp.get("question", "")))
    assert arithmetic["expectations"]["correctness"] is False

    spanish = first_matching(lambda inp: inp.get("language") == "es")
    assert spanish["expectations"]["language_match"] is False
    assert spanish["expectations"]["politeness"] is True

    reloaded = get_dataset(dataset_id=dataset.dataset_id)
    assert len(reloaded.to_df()) == 3
    assert reloaded.tags["model"] == "test-llm-v1"

    reloaded.merge_records([
        {
            "inputs": {"question": "What is Python?", "language": "en"},
            "expectations": {"technical_accuracy": True},
            "tags": {"source": "manual_addition"},
        }
    ])

    final_df = reloaded.to_df()
    assert len(final_df) == 4

    is_trace = final_df["source_type"] == "TRACE"
    assert len(final_df[is_trace]) == 3
    assert len(final_df[~is_trace]) == 1
1388  
1389  
def test_dataset_pagination_transparency_large_records(experiments):
    """Record access hides paging: all records are loaded at once and cached on the entity."""
    total = 150
    dataset = create_dataset(
        name="test_pagination_transparency",
        experiment_id=experiments[0],
        tags={"test": "large_dataset"},
    )

    dataset.merge_records([
        {
            "inputs": {"question": f"Question {k}", "index": k},
            "expectations": {"answer": f"Answer {k}", "score": k * 0.01},
        }
        for k in range(total)
    ])

    first_load = dataset._mlflow_dataset.records
    assert len(first_load) == total

    wanted_indices = set(range(total))
    assert {rec.inputs["index"] for rec in first_load} == wanted_indices
    assert {rec.expectations["score"] for rec in first_load} == {k * 0.01 for k in range(total)}

    df = dataset.to_df()
    assert len(df) == total
    assert {inputs["index"] for inputs in df["inputs"]} == wanted_indices

    # No paging state leaks onto the public wrapper.
    for attr_name in ("page_token", "next_page_token", "max_results"):
        assert not hasattr(dataset, attr_name)

    # A second access returns the very same cached list object...
    assert dataset._mlflow_dataset.records is first_load

    # ...and clearing the cache forces a full reload.
    dataset._mlflow_dataset._records = None
    assert len(dataset._mlflow_dataset.records) == total
1434  
1435  
def test_dataset_internal_pagination_with_mock(experiments):
    """Both ``records`` and ``to_df`` load everything with a single unbounded store call."""
    from mlflow.tracking._tracking_service.utils import _get_store

    dataset = create_dataset(
        name="test_internal_pagination",
        experiment_id=experiments[0],
        tags={"test": "pagination_mock"},
    )

    dataset.merge_records([
        {"inputs": {"question": f"Q{n}", "id": n}, "expectations": {"answer": f"A{n}"}}
        for n in range(75)
    ])

    store = _get_store()
    loaders = (lambda: dataset._mlflow_dataset.records, dataset.to_df)

    for load in loaders:
        # Drop the cached records so the next access must go through the store.
        dataset._mlflow_dataset._records = None
        with mock.patch.object(
            store, "_load_dataset_records", wraps=store._load_dataset_records
        ) as spy:
            loaded = load()

            spy.assert_called_once_with(dataset.dataset_id, max_results=None)
            assert len(loaded) == 75
1472  
1473  
def test_dataset_experiment_associations(experiments):
    """Add and remove experiment links; both operations tolerate repeats."""
    from mlflow.genai.datasets import (
        add_dataset_to_experiments,
        remove_dataset_from_experiments,
    )

    dataset = create_dataset(
        name="test_associations",
        experiment_id=experiments[0],
        tags={"test": "associations"},
    )
    assert experiments[0] in dataset.experiment_ids

    extra = [experiments[1], experiments[2]]

    updated = add_dataset_to_experiments(dataset_id=dataset.dataset_id, experiment_ids=extra)
    for exp_id in (experiments[0], experiments[1], experiments[2]):
        assert exp_id in updated.experiment_ids
    assert len(updated.experiment_ids) == 3

    # Re-adding the same experiments must not create duplicate links.
    repeated = add_dataset_to_experiments(dataset_id=dataset.dataset_id, experiment_ids=extra)
    assert len(repeated.experiment_ids) == 3
    assert set(experiments) <= set(repeated.experiment_ids)

    removed = remove_dataset_from_experiments(
        dataset_id=dataset.dataset_id, experiment_ids=[experiments[1], experiments[2]]
    )
    for gone in (experiments[1], experiments[2]):
        assert gone not in removed.experiment_ids
    assert experiments[0] in removed.experiment_ids
    assert len(removed.experiment_ids) == 1

    # Removing links that no longer exist logs one warning per experiment instead of failing.
    with mock.patch("mlflow.store.tracking.sqlalchemy_store._logger.warning") as warn_spy:
        idempotent = remove_dataset_from_experiments(
            dataset_id=dataset.dataset_id,
            experiment_ids=[experiments[1], experiments[2]],
        )
        assert warn_spy.call_count == 2
        first_message = warn_spy.call_args_list[0][0][0]
        assert "was not associated" in first_message

    assert len(idempotent.experiment_ids) == 1
1520  
1521  
@pytest.mark.skip(reason="FileStore is no longer supported.")
def test_dataset_associations_filestore_blocking(tmp_path):
    """Experiment-association APIs raise NotImplementedError on a FileStore backend.

    Skipped unconditionally via the marker, so the body is not unreachable code
    after an in-body ``pytest.skip()`` call.
    """
    from mlflow.genai.datasets import (
        add_dataset_to_experiments,
        remove_dataset_from_experiments,
    )

    # NOTE(review): this changes the global tracking URI without restoring it;
    # restore it in a try/finally (or a fixture) if this test is ever re-enabled.
    mlflow.set_tracking_uri(tmp_path.as_uri())

    with pytest.raises(NotImplementedError, match="not supported with FileStore"):
        add_dataset_to_experiments(dataset_id="d-test123", experiment_ids=["1", "2"])

    with pytest.raises(NotImplementedError, match="not supported with FileStore"):
        remove_dataset_from_experiments(dataset_id="d-test123", experiment_ids=["1"])
1536  
1537  
def test_evaluation_dataset_tags_crud_workflow(experiments):
    """Set, overwrite and delete dataset tags, then check tag calls after the dataset is deleted."""
    created = create_dataset(
        name="test_tags_crud",
        experiment_id=experiments[0],
    )
    baseline_tags = created.tags.copy()
    ds_id = created.dataset_id

    def assert_stored_tags(added):
        # Persisted tags must equal the creation-time tags overlaid with ``added``.
        assert get_dataset(dataset_id=ds_id).tags == {**baseline_tags, **added}

    set_dataset_tags(
        dataset_id=ds_id,
        tags={"team": "ml-platform", "project": "evaluation", "priority": "high"},
    )
    assert_stored_tags({"team": "ml-platform", "project": "evaluation", "priority": "high"})

    # Setting an existing key overwrites it; new keys are added next to the old ones.
    set_dataset_tags(
        dataset_id=ds_id,
        tags={"priority": "medium", "status": "active"},
    )
    assert_stored_tags({
        "team": "ml-platform",
        "project": "evaluation",
        "priority": "medium",
        "status": "active",
    })

    delete_dataset_tag(dataset_id=ds_id, key="priority")
    assert_stored_tags({"team": "ml-platform", "project": "evaluation", "status": "active"})

    delete_dataset(dataset_id=ds_id)

    with pytest.raises(MlflowException, match="Could not find|not found"):
        get_dataset(dataset_id=ds_id)

    with pytest.raises(MlflowException, match="Could not find|not found"):
        set_dataset_tags(dataset_id=ds_id, tags={"should": "fail"})

    # Deleting a tag on an already-deleted dataset is expected not to raise.
    delete_dataset_tag(dataset_id=ds_id, key="status")
1607  
1608  
def test_set_dataset_tags_databricks(mock_databricks_environment):
    """set_dataset_tags is rejected when the tracking URI is treated as Databricks."""
    unsupported = "tag operations are not available"
    with pytest.raises(NotImplementedError, match=unsupported):
        set_dataset_tags(tags={"key": "value"}, dataset_id="test")
1612  
1613  
def test_delete_dataset_tag_databricks(mock_databricks_environment):
    """delete_dataset_tag is rejected when the tracking URI is treated as Databricks."""
    unsupported = "tag operations are not available"
    with pytest.raises(NotImplementedError, match=unsupported):
        delete_dataset_tag(key="key", dataset_id="test")
1617  
1618  
def test_dataset_schema_evolution_and_log_input(experiments):
    """The stored schema grows as richer records are merged, and the dataset can be logged.

    Records are merged in three stages with progressively more input/expectation
    fields. After each stage the stored JSON schema is re-read and checked for the
    newly inferred field types. The dataset is then passed to ``mlflow.log_input``
    and the run's dataset inputs are verified. A final merge that adds no new
    fields must leave the schema keys unchanged while the profile's record count
    goes from 3 to 4. The dataset created here is deleted at the end.
    """
    dataset = create_dataset(
        name="schema_evolution_test",
        experiment_id=[experiments[0]],
        tags={"test": "schema_evolution", "mlflow.user": "test_user"},
    )

    # Stage 1: a single string input and a single expectation.
    stage1_records = [
        {
            "inputs": {"prompt": "What is MLflow?"},
            "expectations": {"response": "MLflow is a platform"},
        }
    ]
    dataset.merge_records(stage1_records)

    ds1 = get_dataset(dataset_id=dataset.dataset_id)
    schema1 = json.loads(ds1.schema)
    assert schema1 is not None
    assert "prompt" in schema1["inputs"]
    assert schema1["inputs"]["prompt"] == "string"
    assert len(schema1["inputs"]) == 1
    assert len(schema1["expectations"]) == 1

    # Stage 2: add numeric fields; existing fields must be retained alongside them.
    stage2_records = [
        {
            "inputs": {
                "prompt": "Explain Python",
                "temperature": 0.7,
                "max_length": 500,
                "top_p": 0.95,
            },
            "expectations": {
                "response": "Python is a programming language",
                "quality_score": 0.85,
                "token_count": 127,
            },
        }
    ]
    dataset.merge_records(stage2_records)

    ds2 = get_dataset(dataset_id=dataset.dataset_id)
    schema2 = json.loads(ds2.schema)
    assert "temperature" in schema2["inputs"]
    assert schema2["inputs"]["temperature"] == "float"
    assert "max_length" in schema2["inputs"]
    assert schema2["inputs"]["max_length"] == "integer"
    assert len(schema2["inputs"]) == 4
    assert len(schema2["expectations"]) == 3

    # Stage 3: add boolean, list and dict valued fields.
    stage3_records = [
        {
            "inputs": {
                "prompt": "Complex query",
                "streaming": True,
                "stop_sequences": ["\n\n", "END"],
                "config": {"model": "gpt-4", "version": "1.0"},
            },
            "expectations": {
                "response": "Complex response",
                "is_valid": True,
                "citations": ["source1", "source2"],
                "metadata": {"confidence": 0.9},
            },
        }
    ]
    dataset.merge_records(stage3_records)

    ds3 = get_dataset(dataset_id=dataset.dataset_id)
    schema3 = json.loads(ds3.schema)

    assert schema3["inputs"]["streaming"] == "boolean"
    assert schema3["inputs"]["stop_sequences"] == "array"
    assert schema3["inputs"]["config"] == "object"
    assert schema3["expectations"]["is_valid"] == "boolean"
    assert schema3["expectations"]["citations"] == "array"
    assert schema3["expectations"]["metadata"] == "object"

    # Fields introduced in earlier stages must still be present.
    assert "prompt" in schema3["inputs"]
    assert "temperature" in schema3["inputs"]
    assert "quality_score" in schema3["expectations"]

    with mlflow.start_run(experiment_id=experiments[0]) as run:
        mlflow.log_input(dataset, context="evaluation")

        mlflow.log_metrics({"accuracy": 0.92, "f1_score": 0.89})

    run_data = mlflow.get_run(run.info.run_id)
    assert run_data.inputs is not None
    assert run_data.inputs.dataset_inputs is not None
    assert len(run_data.inputs.dataset_inputs) > 0

    dataset_input = run_data.inputs.dataset_inputs[0]
    assert dataset_input.dataset.name == "schema_evolution_test"
    assert dataset_input.dataset.source_type == "mlflow_evaluation_dataset"

    tag_dict = {tag.key: tag.value for tag in dataset_input.tags}
    assert "mlflow.data.context" in tag_dict
    assert tag_dict["mlflow.data.context"] == "evaluation"

    final_dataset = get_dataset(dataset_id=dataset.dataset_id)
    final_schema = json.loads(final_dataset.schema)

    assert "inputs" in final_schema
    assert "expectations" in final_schema
    assert "version" in final_schema
    assert final_schema["version"] == "1.0"

    profile = json.loads(final_dataset.profile)
    assert profile is not None
    assert profile["num_records"] == 3

    # Merging a record that only uses already-known fields must not change the
    # schema's key sets, but it must increase the profiled record count.
    consistency_records = [
        {
            "inputs": {"prompt": "Another test", "temperature": 0.5, "max_length": 200},
            "expectations": {"response": "Another response", "quality_score": 0.75},
        }
    ]
    dataset.merge_records(consistency_records)

    consistent_dataset = get_dataset(dataset_id=dataset.dataset_id)
    consistent_schema = json.loads(consistent_dataset.schema)

    assert set(consistent_schema["inputs"].keys()) == set(final_schema["inputs"].keys())
    assert set(consistent_schema["expectations"].keys()) == set(final_schema["expectations"].keys())

    consistent_profile = json.loads(consistent_dataset.profile)
    assert consistent_profile["num_records"] == 4

    # Clean up the dataset created by this test.
    delete_dataset(dataset_id=dataset.dataset_id)
1747  
1748  
def test_deprecated_parameter_substitution(experiment):
    """``uc_table_name`` still works as an alias for ``name`` but emits a FutureWarning."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")

        dataset = create_dataset(
            uc_table_name="test_dataset_deprecated",
            experiment_id=experiment,
            tags={"test": "deprecated_parameter"},
        )

        assert len(caught) == 1
        first = caught[0]
        text = str(first.message)
        assert issubclass(first.category, FutureWarning)
        assert "uc_table_name" in text
        assert "deprecated" in text.lower()
        assert "name" in text

        assert dataset.name == "test_dataset_deprecated"
        assert dataset.tags["test"] == "deprecated_parameter"

    # Passing both the deprecated and the replacement argument is rejected.
    with pytest.raises(ValueError, match="Cannot specify both.*uc_table_name.*and.*name"):
        create_dataset(
            uc_table_name="old_name",
            name="new_name",
            experiment_id=experiment,
        )

    # delete_dataset also warns on the deprecated argument before rejecting it.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")

        with pytest.raises(ValueError, match="name.*only supported in Databricks"):
            delete_dataset(uc_table_name="test_dataset_deprecated")

        assert len(caught) == 1
        only = caught[0]
        assert issubclass(only.category, FutureWarning)
        assert "uc_table_name" in str(only.message)

    delete_dataset(dataset_id=dataset.dataset_id)
1786  
1787  
def test_create_dataset_uses_active_experiment_when_not_specified(client):
    """Omitting ``experiment_id`` links the new dataset to the active experiment."""
    from mlflow.tracking import fluent

    exp_id = mlflow.create_experiment("test_active_experiment")
    try:
        mlflow.set_experiment(experiment_id=exp_id)

        dataset = create_dataset(name="test_with_active_exp")

        assert dataset.experiment_ids == [exp_id]
    finally:
        # Reset the module-level active experiment even when an assertion fails,
        # so this test's state cannot leak into later tests.
        fluent._active_experiment_id = None
1799  
1800  
def test_create_dataset_with_no_active_experiment(client):
    """With no active experiment set, the dataset is linked to experiment "0"."""
    from mlflow.tracking import fluent as tracking_fluent

    # Start from a clean slate: clear any active experiment left by other tests.
    tracking_fluent._active_experiment_id = None

    created = create_dataset(name="test_no_active_exp")
    assert created.experiment_ids == ["0"]
1809  
1810  
def test_create_dataset_explicit_overrides_active_experiment(client):
    """An explicit ``experiment_id`` takes precedence over the active experiment."""
    from mlflow.tracking import fluent

    active_exp = mlflow.create_experiment("active_exp")
    explicit_exp = mlflow.create_experiment("explicit_exp")

    try:
        mlflow.set_experiment(experiment_id=active_exp)

        dataset = create_dataset(name="test_explicit_override", experiment_id=explicit_exp)

        assert dataset.experiment_ids == [explicit_exp]
    finally:
        # Reset the module-level active experiment even when an assertion fails,
        # so this test's state cannot leak into later tests.
        fluent._active_experiment_id = None
1824  
1825  
def test_create_dataset_none_uses_active_experiment(client):
    """Passing ``experiment_id=None`` behaves like omitting it: the active experiment is used."""
    from mlflow.tracking import fluent

    exp_id = mlflow.create_experiment("test_none_experiment")
    try:
        mlflow.set_experiment(experiment_id=exp_id)

        dataset = create_dataset(name="test_none_exp", experiment_id=None)

        assert dataset.experiment_ids == [exp_id]
    finally:
        # Reset the module-level active experiment even when an assertion fails,
        # so this test's state cannot leak into later tests.
        fluent._active_experiment_id = None
1837  
1838  
def test_source_type_inference():
    """merge_records infers each record's source type unless one is given explicitly."""
    experiment_id = mlflow.create_experiment("test_source_inference")
    dataset = create_dataset(
        name="test_source_inference",
        experiment_id=experiment_id,
        tags={"test": "source_inference"},
    )

    def rows_with_source(frame, source_type):
        # Subset of ``frame`` whose source_type column equals the enum's value.
        return frame[frame["source_type"] == source_type.value]

    # Records with non-empty expectations are asserted to be classified as HUMAN.
    dataset.merge_records([
        {
            "inputs": {"question": "What is MLflow?"},
            "expectations": {"answer": "MLflow is an ML platform", "quality": 0.9},
        },
        {
            "inputs": {"question": "How to track experiments?"},
            "expectations": {"answer": "Use mlflow.start_run()", "quality": 0.85},
        },
    ])
    assert len(rows_with_source(dataset.to_df(), DatasetRecordSourceType.HUMAN)) == 2

    # Inputs-only records are asserted to be classified as CODE.
    dataset.merge_records([{"inputs": {"question": f"Generated question {i}"}} for i in range(3)])
    assert len(rows_with_source(dataset.to_df(), DatasetRecordSourceType.CODE)) == 3

    # An explicit DOCUMENT source is kept even though expectations are present.
    dataset.merge_records([
        {
            "inputs": {"question": "Document-based question"},
            "expectations": {"answer": "From document"},
            "source": {
                "source_type": DatasetRecordSourceType.DOCUMENT.value,
                "source_data": {"source_id": "doc123", "page": 5},
            },
        }
    ])
    documents = rows_with_source(dataset.to_df(), DatasetRecordSourceType.DOCUMENT)
    assert len(documents) == 1
    assert documents.iloc[0]["source_id"] == "doc123"

    # An empty expectations dict must be classified as neither HUMAN nor CODE.
    dataset.merge_records([{"inputs": {"question": "Has empty expectations"}, "expectations": {}}])
    newest = dataset.to_df().iloc[-1]
    assert newest["source_type"] not in [
        DatasetRecordSourceType.HUMAN.value,
        DatasetRecordSourceType.CODE.value,
    ]

    # An explicit TRACE source with a trace_id surfaces that id as source_id.
    dataset.merge_records([
        {
            "inputs": {"question": "From trace"},
            "source": {
                "source_type": DatasetRecordSourceType.TRACE.value,
                "source_data": {"trace_id": "trace123"},
            },
        }
    ])
    df = dataset.to_df()
    traces = rows_with_source(df, DatasetRecordSourceType.TRACE)
    assert len(traces) == 1, f"Expected 1 TRACE source, got {len(traces)}"
    assert traces.iloc[0]["source_id"] == "trace123"

    counts = df["source_type"].value_counts()
    expected_counts = (
        (DatasetRecordSourceType.HUMAN, 2),
        (DatasetRecordSourceType.CODE, 3),
        (DatasetRecordSourceType.DOCUMENT, 1),
        (DatasetRecordSourceType.TRACE, 1),
    )
    for source_type, expected in expected_counts:
        assert counts.get(source_type.value, 0) == expected

    delete_dataset(dataset_id=dataset.dataset_id)
1920  
1921  
def test_trace_source_type_detection():
    """Traces merged as objects, as a search DataFrame, or as a list are all tagged TRACE."""
    experiment_id = mlflow.create_experiment("test_trace_source_detection")

    trace_ids = []
    for idx in range(3):
        with mlflow.start_run(experiment_id=experiment_id):
            with mlflow.start_span(name=f"test_span_{idx}") as span:
                span.set_inputs({"question": f"Question {idx}", "context": f"Context {idx}"})
                span.set_outputs({"answer": f"Answer {idx}"})
                trace_ids.append(span.trace_id)

                # Only the first two traces get an expectation attached.
                if idx < 2:
                    mlflow.log_expectation(
                        trace_id=span.trace_id,
                        name="quality",
                        value=0.8 + idx * 0.05,
                        span_id=span.span_id,
                    )

    def trace_rows(frame):
        # Rows of ``frame`` whose source_type is TRACE.
        return frame[frame["source_type"] == DatasetRecordSourceType.TRACE.value]

    # 1) Trace objects fetched one by one through the client.
    from_objects = create_dataset(
        name="test_trace_sources",
        experiment_id=experiment_id,
        tags={"test": "trace_source_detection"},
    )
    tracking_client = mlflow.MlflowClient()
    from_objects.merge_records([tracking_client.get_trace(tid) for tid in trace_ids])

    df = from_objects.to_df()
    assert len(trace_rows(df)) == 3
    for tid in trace_ids:
        assert len(df[df["source_id"] == tid]) == 1

    # 2) The DataFrame returned by search_traces.
    from_frame = create_dataset(
        name="test_trace_sources_df",
        experiment_id=experiment_id,
        tags={"test": "trace_source_df"},
    )
    traces_df = mlflow.search_traces(locations=[experiment_id])
    assert not traces_df.empty
    from_frame.merge_records(traces_df)
    assert len(trace_rows(from_frame.to_df())) == len(traces_df)

    # 3) The list returned by search_traces(return_type="list").
    from_list = create_dataset(
        name="test_trace_sources_list",
        experiment_id=experiment_id,
        tags={"test": "trace_source_list"},
    )
    traces_list = mlflow.search_traces(locations=[experiment_id], return_type="list")
    assert len(traces_list) > 0
    from_list.merge_records(traces_list)
    assert len(trace_rows(from_list.to_df())) == len(traces_list)

    # Expectations logged on the first two traces carry over into the records.
    has_expectations = df["expectations"].apply(lambda x: bool(x) and len(x) > 0)
    assert len(df[has_expectations]) == 2

    for created in (from_objects, from_frame, from_list):
        delete_dataset(dataset_id=created.dataset_id)
1993  
1994  
def test_create_dataset_empty_list_stays_empty(client):
    """An explicit empty ``experiment_id`` list is respected, not replaced by the active one."""
    from mlflow.tracking import fluent

    exp_id = mlflow.create_experiment("test_empty_list")
    try:
        mlflow.set_experiment(experiment_id=exp_id)

        dataset = create_dataset(name="test_empty_list", experiment_id=[])

        assert dataset.experiment_ids == []
    finally:
        # Reset the module-level active experiment even when an assertion fails,
        # so this test's state cannot leak into later tests.
        fluent._active_experiment_id = None
2006  
2007  
def test_search_datasets_filter_string_edge_cases(client):
    """Empty-ish filters get a default ``created_time`` filter; real filters pass through unchanged."""
    exp_id = mlflow.create_experiment("test_filter_edge_cases")

    dataset = create_dataset(name="test_dataset", experiment_id=exp_id, tags={"test": "value"})

    with mock.patch("mlflow.tracking.client.MlflowClient.search_datasets") as mock_search:
        mock_search.return_value = mock.MagicMock(token=None, items=[dataset])

        def forwarded_filter(filter_string):
            # Run the fluent search and report the filter_string the client received.
            mock_search.reset_mock()
            search_datasets(experiment_ids=exp_id, filter_string=filter_string)
            return mock_search.call_args.kwargs.get("filter_string")

        for empty_filter in (None, [], ""):
            assert "created_time >=" in forwarded_filter(empty_filter)

        assert forwarded_filter('name = "test"') == 'name = "test"'
2041  
2042  
def test_wrapper_type_is_actually_returned_not_entity(experiments):
    """create_dataset returns the genai wrapper, which holds the entity in ``_mlflow_dataset``."""
    created = create_dataset(
        name="test_wrapper",
        experiment_id=experiments[0],
        tags={"test": "wrapper_check"},
    )

    assert isinstance(created, WrapperEvaluationDataset)
    assert not isinstance(created, EntityEvaluationDataset)

    assert hasattr(created, "_mlflow_dataset")
    inner = created._mlflow_dataset
    assert inner is not None
    assert isinstance(inner, EntityEvaluationDataset)
2055  
2056  
def test_wrapper_delegates_all_properties_correctly(experiments):
    """Entity properties are readable directly on the wrapper."""
    experiment_id = experiments[0]
    dataset = create_dataset(
        name="test_delegation",
        experiment_id=experiment_id,
        tags={"env": "test", "version": "1.0"},
    )

    assert dataset.name == "test_delegation"
    assert dataset.dataset_id.startswith("d-")
    for key, value in (("env", "test"), ("version", "1.0")):
        assert dataset.tags[key] == value
    assert experiment_id in dataset.experiment_ids
    assert dataset.created_time > 0
    assert dataset.last_update_time > 0
    assert dataset.digest is not None

    assert hasattr(dataset, "source")
    assert dataset.source._get_source_type() == "mlflow_evaluation_dataset"
2073      assert dataset.source._get_source_type() == "mlflow_evaluation_dataset"
2074  
2075  
def test_get_and_search_return_wrapper_not_entity(experiments):
    """get_dataset and search_datasets both yield wrapper objects, not raw entities."""
    original = create_dataset(
        name="test_get_wrapper",
        experiment_id=experiments[0],
        tags={"test": "get"},
    )

    fetched = get_dataset(dataset_id=original.dataset_id)
    assert isinstance(fetched, WrapperEvaluationDataset)
    assert not isinstance(fetched, EntityEvaluationDataset)
    assert fetched.dataset_id == original.dataset_id
    assert fetched.name == original.name

    matches = search_datasets(
        experiment_ids=experiments[0],
        filter_string="name = 'test_get_wrapper'",
    )
    assert len(matches) == 1
    first_match = matches[0]
    assert isinstance(first_match, WrapperEvaluationDataset)
    assert not isinstance(first_match, EntityEvaluationDataset)
2095      assert not isinstance(results[0], EntityEvaluationDataset)
2096  
2097  
def test_wrapper_vs_direct_client_usage(experiments):
    """MlflowClient.create_dataset returns the entity; the fluent create_dataset returns the wrapper."""
    tracking_client = MlflowClient()
    experiment_id = experiments[0]

    raw = tracking_client.create_dataset(
        name="test_client_direct",
        experiment_id=experiment_id,
        tags={"direct": "client"},
    )
    assert isinstance(raw, EntityEvaluationDataset)
    assert not isinstance(raw, WrapperEvaluationDataset)

    fluent_dataset = create_dataset(
        name="test_wrapped",
        experiment_id=experiment_id,
        tags={"wrapped": "fluent"},
    )
    assert isinstance(fluent_dataset, WrapperEvaluationDataset)
    assert not isinstance(fluent_dataset, EntityEvaluationDataset)
    assert fluent_dataset._mlflow_dataset is not None

    # Wrapping an entity by hand must compare equal to the entity it wraps.
    assert WrapperEvaluationDataset(raw) == raw
2120  
2121  
def test_wrapper_works_with_mlflow_log_input_integration(experiments):
    """A wrapper dataset can be passed directly to ``mlflow.log_input``."""
    experiment_id = experiments[0]
    dataset = create_dataset(name="test_log_input", experiment_id=experiment_id)
    dataset.merge_records([
        {
            "inputs": {"question": "Test question"},
            "expectations": {"answer": "Test answer"},
        }
    ])

    with mlflow.start_run(experiment_id=experiment_id) as run:
        mlflow.log_input(dataset, context="evaluation")

    logged_inputs = mlflow.get_run(run.info.run_id).inputs.dataset_inputs
    assert len(logged_inputs) == 1
    logged = logged_inputs[0].dataset
    assert logged.name == "test_log_input"
    assert logged.digest == dataset.digest
2144  
2145  
def test_wrapper_isinstance_checks_for_dataset_interfaces(experiments):
    """The wrapper satisfies the generic Dataset and PyFunc-convertible interfaces."""
    dataset = create_dataset(name="test_isinstance", experiment_id=experiments[0])

    for interface in (Dataset, PyFuncConvertibleDatasetMixin, WrapperEvaluationDataset):
        assert isinstance(dataset, interface)
    assert not isinstance(dataset, EntityEvaluationDataset)
    assert isinstance(dataset, (WrapperEvaluationDataset, EntityEvaluationDataset))
2157  
2158  
@pytest.mark.parametrize(
    "records",
    [
        [
            {"inputs": {"persona": "Student", "goal": "Find articles"}},
            {
                "inputs": {
                    "persona": "Researcher",
                    "goal": "Review",
                    "context": {"dept": "CS"},
                }
            },
            {"inputs": {"goal": "Single goal"}, "expectations": {"output": "expected"}},
        ],
        [
            {"inputs": {"goal": "Learn ML", "simulation_guidelines": "Be concise"}},
            {
                "inputs": {
                    "persona": "Engineer",
                    "goal": "Debug",
                    "simulation_guidelines": "Focus on logs",
                }
            },
            {
                "inputs": {
                    "persona": "Student",
                    "goal": "Study",
                    "context": {"course": "CS101"},
                    "simulation_guidelines": "Ask clarifying questions",
                }
            },
        ],
    ],
)
def test_multiturn_valid_formats(experiments, records):
    """Well-formed session-style batches merge cleanly and keep their session fields."""
    session_keys = ("persona", "goal", "context", "simulation_guidelines")

    dataset = create_dataset(name="multiturn_test", experiment_id=experiments[0])
    dataset.merge_records(records)
    frame = dataset.to_df()

    assert len(frame) == 3
    for inputs in frame["inputs"]:
        assert any(key in inputs for key in session_keys)
2203  
2204  
@pytest.mark.parametrize(
    ("records", "error_pattern"),
    [
        # Top-level session fields
        (
            [{"persona": "Student", "goal": "Find articles", "custom_field": "value"}],
            "Each record must have an 'inputs' field",
        ),
        # Mixed fields in inputs
        (
            [
                {
                    "inputs": {
                        "persona": "Student",
                        "goal": "Find",
                        "custom_field": "value",
                    }
                }
            ],
            "Invalid input schema.*cannot mix session fields",
        ),
        # Inconsistent batch schema
        (
            [
                {"inputs": {"persona": "Student", "goal": "Find articles"}},
                {"inputs": {"question": "What is MLflow?"}},
            ],
            "must use the same granularity.*Found",
        ),
        # Empty inputs in batch with session records
        (
            [
                {"inputs": {"goal": "Find articles"}},
                {"inputs": {}},
            ],
            "Empty inputs are not allowed for session records.*'goal' field is required",
        ),
    ],
)
def test_multiturn_validation_errors(experiments, records, error_pattern):
    """Malformed session-style batches are rejected with a matching MlflowException."""
    target = create_dataset(experiment_id=experiments[0], name="multiturn_error_test")
    with pytest.raises(MlflowException, match=error_pattern):
        target.merge_records(records)
2248  
2249  
@pytest.mark.parametrize(
    ("existing_records", "new_records"),
    [
        # Multiturn then custom
        (
            [{"inputs": {"persona": "Student", "goal": "Find articles"}}],
            [{"inputs": {"question": "What is MLflow?", "model": "gpt-4"}}],
        ),
        # Custom then multiturn
        (
            [{"inputs": {"question": "What is MLflow?", "model": "gpt-4"}}],
            [{"inputs": {"persona": "Student", "goal": "Find articles"}}],
        ),
    ],
)
def test_multiturn_schema_compatibility(experiments, existing_records, new_records):
    """Session-style and custom inputs cannot coexist in one dataset, in either order."""
    target = create_dataset(experiment_id=experiments[0], name="multiturn_compat_test")
    target.merge_records(existing_records)

    # The second merge switches granularity and must be rejected.
    with pytest.raises(MlflowException, match="Cannot mix granularities"):
        target.merge_records(new_records)
2271  
2272  
2273  def test_multiturn_with_expectations_and_tags(experiments):
2274      dataset = create_dataset(name="multiturn_full_test", experiment_id=experiments[0])
2275      records = [
2276          {
2277              "inputs": {
2278                  "persona": "Graduate Student",
2279                  "goal": "Find peer-reviewed articles on machine learning",
2280                  "context": {"user_id": "U0001", "department": "CS"},
2281                  "simulation_guidelines": "Be thorough and cite sources",
2282              },
2283              "expectations": {"expected_output": "relevant articles", "quality": "high"},
2284              "tags": {"difficulty": "medium"},
2285          },
2286          {
2287              "inputs": {
2288                  "persona": "Librarian",
2289                  "goal": "Help with inter-library loan",
2290              },
2291              "expectations": {"expected_output": "loan information"},
2292          },
2293      ]
2294  
2295      dataset.merge_records(records)
2296  
2297      df = dataset.to_df()
2298      assert len(df) == 2
2299  
2300      grad_record = df[df["inputs"].apply(lambda x: x.get("persona") == "Graduate Student")].iloc[0]
2301      assert grad_record["expectations"]["expected_output"] == "relevant articles"
2302      assert grad_record["expectations"]["quality"] == "high"
2303      assert grad_record["tags"]["difficulty"] == "medium"
2304      assert grad_record["inputs"]["context"] == {"user_id": "U0001", "department": "CS"}
2305      assert grad_record["inputs"]["simulation_guidelines"] == "Be thorough and cite sources"