test_fluent.py
1 import json 2 import os 3 import sys 4 import warnings 5 from unittest import mock 6 7 import pandas as pd 8 import pytest 9 10 import mlflow 11 from mlflow.data import Dataset 12 from mlflow.data.pyfunc_dataset_mixin import PyFuncConvertibleDatasetMixin 13 from mlflow.entities.dataset_record_source import DatasetRecordSourceType 14 from mlflow.entities.evaluation_dataset import ( 15 EvaluationDataset as EntityEvaluationDataset, 16 ) 17 from mlflow.exceptions import MlflowException 18 from mlflow.genai.datasets import ( 19 EvaluationDataset, 20 create_dataset, 21 delete_dataset, 22 delete_dataset_tag, 23 get_dataset, 24 search_datasets, 25 set_dataset_tags, 26 ) 27 from mlflow.genai.datasets.evaluation_dataset import ( 28 EvaluationDataset as WrapperEvaluationDataset, 29 ) 30 from mlflow.store.entities.paged_list import PagedList 31 from mlflow.store.tracking import SEARCH_EVALUATION_DATASETS_MAX_RESULTS 32 from mlflow.tracking import MlflowClient 33 from mlflow.utils.mlflow_tags import MLFLOW_USER 34 35 36 @pytest.fixture 37 def mock_client(): 38 with ( 39 mock.patch("mlflow.tracking.client.MlflowClient") as mock_client_class, 40 mock.patch("mlflow.genai.datasets.MlflowClient", mock_client_class), 41 ): 42 mock_client_instance = mock_client_class.return_value 43 yield mock_client_instance 44 45 46 @pytest.fixture 47 def mock_databricks_environment(): 48 with mock.patch("mlflow.genai.datasets.is_databricks_uri", return_value=True): 49 yield 50 51 52 @pytest.fixture 53 def client(db_uri): 54 original_tracking_uri = mlflow.get_tracking_uri() 55 mlflow.set_tracking_uri(db_uri) 56 yield MlflowClient(tracking_uri=db_uri) 57 mlflow.set_tracking_uri(original_tracking_uri) 58 59 60 @pytest.fixture 61 def experiments(client): 62 exp1 = client.create_experiment("test_exp_1") 63 exp2 = client.create_experiment("test_exp_2") 64 exp3 = client.create_experiment("test_exp_3") 65 return [exp1, exp2, exp3] 66 67 68 @pytest.fixture 69 def experiment(client): 70 return client.create_experiment("test_trace_experiment") 71 72 73 def test_create_dataset(mock_client): 74 expected_dataset = EntityEvaluationDataset( 75 dataset_id="test_id", 76 name="test_dataset", 77 digest="abc123", 78 created_time=123456789, 79 last_update_time=123456789, 80 tags={"environment": "production", "version": "1.0"}, 81 ) 82 mock_client.create_dataset.return_value = expected_dataset 83 84 result = create_dataset( 85 name="test_dataset", 86 experiment_id=["exp1", "exp2"], 87 tags={"environment": "production", "version": "1.0"}, 88 ) 89 90 assert result == expected_dataset 91 mock_client.create_dataset.assert_called_once_with( 92 name="test_dataset", 93 experiment_id=["exp1", "exp2"], 94 tags={"environment": "production", "version": "1.0"}, 95 ) 96 97 98 def test_create_dataset_single_experiment_id(mock_client): 99 expected_dataset = EntityEvaluationDataset( 100 dataset_id="test_id", 101 name="test_dataset", 102 digest="abc123", 103 created_time=123456789, 104 last_update_time=123456789, 105 ) 106 mock_client.create_dataset.return_value = expected_dataset 107 108 result = create_dataset( 109 name="test_dataset", 110 experiment_id="exp1", 111 ) 112 113 assert result == expected_dataset 114 mock_client.create_dataset.assert_called_once_with( 115 name="test_dataset", 116 experiment_id=["exp1"], 117 tags=None, 118 ) 119 120 121 def test_create_dataset_with_empty_tags(mock_client): 122 expected_dataset = EntityEvaluationDataset( 123 dataset_id="test_id", 124 name="test_dataset", 125 digest="abc123", 126 created_time=123456789, 127 last_update_time=123456789, 128 tags={}, 129 ) 130 mock_client.create_dataset.return_value = expected_dataset 131 132 result = create_dataset( 133 name="test_dataset", 134 experiment_id=["exp1"], 135 tags={}, 136 ) 137 138 assert result == expected_dataset 139 mock_client.create_dataset.assert_called_once_with( 140 name="test_dataset", 141 experiment_id=["exp1"], 142 tags={}, 143 ) 144 145 146 def test_create_dataset_databricks(mock_databricks_environment): 147 mock_dataset = mock.Mock() 148 with mock.patch.dict( 149 "sys.modules", 150 { 151 "databricks.agents.datasets": mock.Mock( 152 create_dataset=mock.Mock(return_value=mock_dataset) 153 ) 154 }, 155 ): 156 result = create_dataset( 157 name="catalog.schema.table", 158 experiment_id=["exp1", "exp2"], 159 ) 160 161 sys.modules["databricks.agents.datasets"].create_dataset.assert_called_once_with( 162 "catalog.schema.table", ["exp1", "exp2"] 163 ) 164 assert isinstance(result, EvaluationDataset) 165 166 167 def test_get_dataset(mock_client): 168 expected_dataset = EntityEvaluationDataset( 169 dataset_id="test_id", 170 name="test_dataset", 171 digest="abc123", 172 created_time=123456789, 173 last_update_time=123456789, 174 ) 175 mock_client.get_dataset.return_value = expected_dataset 176 177 result = get_dataset(dataset_id="test_id") 178 179 assert result == expected_dataset 180 mock_client.get_dataset.assert_called_once_with("test_id") 181 182 183 def test_get_dataset_missing_id(): 184 with pytest.raises(ValueError, match="Either 'name' or 'dataset_id' must be provided"): 185 get_dataset() 186 187 188 def test_get_dataset_databricks(mock_databricks_environment): 189 mock_dataset = mock.Mock() 190 with mock.patch.dict( 191 "sys.modules", 192 {"databricks.agents.datasets": mock.Mock(get_dataset=mock.Mock(return_value=mock_dataset))}, 193 ): 194 result = get_dataset(name="catalog.schema.table") 195 196 sys.modules["databricks.agents.datasets"].get_dataset.assert_called_once_with( 197 "catalog.schema.table" 198 ) 199 assert isinstance(result, EvaluationDataset) 200 201 202 def test_get_dataset_databricks_missing_name(mock_databricks_environment): 203 with pytest.raises(ValueError, match="Parameter 'name' is required"): 204 get_dataset(dataset_id="test_id") 205 206 207 def test_get_dataset_by_name_oss(experiments): 208 dataset = create_dataset( 209 name="unique_dataset_name", 210 experiment_id=experiments[0], 211 tags={"test": "get_by_name"}, 212 ) 213 214 retrieved = get_dataset(name="unique_dataset_name") 215 216 assert retrieved.dataset_id == dataset.dataset_id 217 assert retrieved.name == "unique_dataset_name" 218 assert retrieved.tags["test"] == "get_by_name" 219 220 221 def test_get_dataset_by_name_not_found(client): 222 with pytest.raises(MlflowException, match="Dataset with name 'nonexistent_dataset' not found"): 223 get_dataset(name="nonexistent_dataset") 224 225 226 def test_get_dataset_by_name_multiple_matches(experiments): 227 create_dataset( 228 name="duplicate_name", 229 experiment_id=experiments[0], 230 ) 231 create_dataset( 232 name="duplicate_name", 233 experiment_id=experiments[1], 234 ) 235 236 with pytest.raises(MlflowException, match="Multiple datasets found with name 'duplicate_name'"): 237 get_dataset(name="duplicate_name") 238 239 240 def test_get_dataset_both_name_and_id_error(experiments): 241 dataset = create_dataset( 242 name="test_dataset_both", 243 experiment_id=experiments[0], 244 ) 245 246 with pytest.raises(ValueError, match="Cannot specify both 'name' and 'dataset_id'"): 247 get_dataset(name="test_dataset_both", dataset_id=dataset.dataset_id) 248 249 250 def test_get_dataset_neither_name_nor_id_error(client): 251 with pytest.raises(ValueError, match="Either 'name' or 'dataset_id' must be provided"): 252 get_dataset() 253 254 255 @pytest.mark.parametrize( 256 "name", 257 [ 258 "dataset's_with_single_quote", 259 'dataset"with_double_quote', 260 ], 261 ) 262 def test_get_dataset_name_with_quotes(experiments, name): 263 dataset = create_dataset(name=name, experiment_id=experiments[0]) 264 265 retrieved = get_dataset(name=name) 266 267 assert retrieved.dataset_id == dataset.dataset_id 268 assert retrieved.name == name 269 270 271 def test_delete_dataset(mock_client): 272 delete_dataset(dataset_id="test_id") 273 274 mock_client.delete_dataset.assert_called_once_with("test_id") 275 276 277 def test_delete_dataset_missing_id(): 278 with pytest.raises(ValueError, match="Parameter 'dataset_id' is required"): 279 delete_dataset() 280 281 282 def test_delete_dataset_databricks(mock_databricks_environment): 283 with mock.patch.dict( 284 "sys.modules", 285 {"databricks.agents.datasets": mock.Mock(delete_dataset=mock.Mock())}, 286 ): 287 delete_dataset(name="catalog.schema.table") 288 289 sys.modules["databricks.agents.datasets"].delete_dataset.assert_called_once_with( 290 "catalog.schema.table" 291 ) 292 293 294 def test_search_datasets_with_mock(mock_client): 295 datasets = [ 296 EntityEvaluationDataset( 297 dataset_id="id1", 298 name="dataset1", 299 digest="digest1", 300 created_time=123456789, 301 last_update_time=123456789, 302 ), 303 EntityEvaluationDataset( 304 dataset_id="id2", 305 name="dataset2", 306 digest="digest2", 307 created_time=123456789, 308 last_update_time=123456789, 309 ), 310 ] 311 mock_client.search_datasets.return_value = PagedList(datasets, None) 312 313 result = search_datasets( 314 experiment_ids=["exp1", "exp2"], 315 filter_string="name LIKE 'test%'", 316 max_results=100, 317 order_by=["created_time DESC"], 318 ) 319 320 assert len(result) == 2 321 assert isinstance(result, list) 322 323 mock_client.search_datasets.assert_called_once_with( 324 experiment_ids=["exp1", "exp2"], 325 filter_string="name LIKE 'test%'", 326 max_results=50, 327 order_by=["created_time DESC"], 328 page_token=None, 329 ) 330 331 332 def test_search_datasets_single_experiment_id(mock_client): 333 datasets = [ 334 EntityEvaluationDataset( 335 dataset_id="id1", 336 name="dataset1", 337 digest="digest1", 338 created_time=123456789, 339 last_update_time=123456789, 340 ) 341 ] 342 mock_client.search_datasets.return_value = PagedList(datasets, None) 343 344 # When no max_results is specified, it defaults to None which means get all 345 # Mock time to have a consistent filter_string 346 with mock.patch("time.time", return_value=1234567890): 347 search_datasets(experiment_ids="exp1") 348 349 # The pagination wrapper will use SEARCH_EVALUATION_DATASETS_MAX_RESULTS as the page size 350 # Now the function adds default filter (last 7 days) and order_by when not specified 351 seven_days_ago = int((1234567890 - 7 * 24 * 60 * 60) * 1000) 352 mock_client.search_datasets.assert_called_once_with( 353 experiment_ids=["exp1"], 354 filter_string=f"created_time >= {seven_days_ago}", 355 max_results=SEARCH_EVALUATION_DATASETS_MAX_RESULTS, # Page size 356 order_by=["created_time DESC"], 357 page_token=None, 358 ) 359 360 361 def test_search_datasets_pagination_handling(mock_client): 362 page1_datasets = [ 363 EntityEvaluationDataset( 364 dataset_id=f"id{i}", 365 name=f"dataset{i}", 366 digest=f"digest{i}", 367 created_time=123456789, 368 last_update_time=123456789, 369 ) 370 for i in range(3) 371 ] 372 373 page2_datasets = [ 374 EntityEvaluationDataset( 375 dataset_id=f"id{i}", 376 name=f"dataset{i}", 377 digest=f"digest{i}", 378 created_time=123456789, 379 last_update_time=123456789, 380 ) 381 for i in range(3, 5) 382 ] 383 384 mock_client.search_datasets.side_effect = [ 385 PagedList(page1_datasets, "token1"), 386 PagedList(page2_datasets, None), 387 ] 388 389 result = search_datasets(experiment_ids=["exp1"], max_results=10) 390 391 assert len(result) == 5 392 assert isinstance(result, list) 393 394 assert mock_client.search_datasets.call_count == 2 395 396 first_call = mock_client.search_datasets.call_args_list[0] 397 assert first_call[1]["page_token"] is None 398 399 second_call = mock_client.search_datasets.call_args_list[1] 400 assert second_call[1]["page_token"] == "token1" 401 402 403 def test_search_datasets_single_page(mock_client): 404 datasets = [ 405 EntityEvaluationDataset( 406 dataset_id="id1", 407 name="dataset1", 408 digest="digest1", 409 created_time=123456789, 410 last_update_time=123456789, 411 ) 412 ] 413 414 mock_client.search_datasets.return_value = PagedList(datasets, None) 415 416 result = search_datasets(max_results=10) 417 418 assert len(result) == 1 419 assert isinstance(result, list) 420 421 assert mock_client.search_datasets.call_count == 1 422 423 424 def test_search_datasets_databricks(mock_databricks_environment, mock_client): 425 datasets = [ 426 EntityEvaluationDataset( 427 dataset_id="id1", 428 name="dataset1", 429 digest="digest1", 430 created_time=123456789, 431 last_update_time=123456789, 432 ), 433 ] 434 mock_client.search_datasets.return_value = PagedList(datasets, None) 435 436 result = search_datasets(experiment_ids=["exp1"]) 437 438 assert len(result) == 1 439 assert isinstance(result, list) 440 441 # Verify that default filter_string and order_by are NOT set for Databricks 442 # (since these parameters may not be supported by all Databricks backends) 443 mock_client.search_datasets.assert_called_once() 444 call_kwargs = mock_client.search_datasets.call_args.kwargs 445 assert call_kwargs.get("filter_string") is None 446 assert call_kwargs.get("order_by") is None 447 448 449 def test_databricks_import_error(): 450 with ( 451 mock.patch("mlflow.genai.datasets.is_databricks_uri", return_value=True), 452 mock.patch.dict("sys.modules", {"databricks.agents.datasets": None}), 453 mock.patch("builtins.__import__", side_effect=ImportError("No module")), 454 ): 455 with pytest.raises(ImportError, match="databricks-agents"): 456 create_dataset(name="test", experiment_id="exp1") 457 458 459 def test_databricks_profile_uri_support(): 460 mock_dataset = mock.Mock() 461 with ( 462 mock.patch( 463 "mlflow.genai.datasets.get_tracking_uri", 464 return_value="databricks://profilename", 465 ), 466 mock.patch.dict( 467 "sys.modules", 468 { 469 "databricks.agents.datasets": mock.Mock( 470 get_dataset=mock.Mock(return_value=mock_dataset), 471 create_dataset=mock.Mock(return_value=mock_dataset), 472 delete_dataset=mock.Mock(), 473 ) 474 }, 475 ), 476 ): 477 result = get_dataset(name="catalog.schema.table") 478 sys.modules["databricks.agents.datasets"].get_dataset.assert_called_once_with( 479 "catalog.schema.table" 480 ) 481 assert isinstance(result, EvaluationDataset) 482 483 result2 = create_dataset(name="catalog.schema.table2", experiment_id=["exp1"]) 484 sys.modules["databricks.agents.datasets"].create_dataset.assert_called_once_with( 485 "catalog.schema.table2", ["exp1"] 486 ) 487 assert isinstance(result2, EvaluationDataset) 488 489 delete_dataset(name="catalog.schema.table3") 490 sys.modules["databricks.agents.datasets"].delete_dataset.assert_called_once_with( 491 "catalog.schema.table3" 492 ) 493 494 495 def test_databricks_profile_env_var_set_from_uri(monkeypatch): 496 mock_dataset = mock.Mock() 497 profile_values_during_calls = [] 498 499 def mock_get_dataset(name): 500 profile_values_during_calls.append(( 501 "get_dataset", 502 os.environ.get("DATABRICKS_CONFIG_PROFILE"), 503 )) 504 return mock_dataset 505 506 def mock_create_dataset(name, experiment_ids): 507 profile_values_during_calls.append(( 508 "create_dataset", 509 os.environ.get("DATABRICKS_CONFIG_PROFILE"), 510 )) 511 return mock_dataset 512 513 def mock_delete_dataset(name): 514 profile_values_during_calls.append(( 515 "delete_dataset", 516 os.environ.get("DATABRICKS_CONFIG_PROFILE"), 517 )) 518 519 mock_agents_module = mock.Mock( 520 get_dataset=mock_get_dataset, 521 create_dataset=mock_create_dataset, 522 delete_dataset=mock_delete_dataset, 523 ) 524 monkeypatch.setitem(sys.modules, "databricks.agents.datasets", mock_agents_module) 525 monkeypatch.setattr("mlflow.genai.datasets.get_tracking_uri", lambda: "databricks://myprofile") 526 527 assert "DATABRICKS_CONFIG_PROFILE" not in os.environ 528 529 get_dataset(name="catalog.schema.table") 530 create_dataset(name="catalog.schema.table", experiment_id="exp1") 531 delete_dataset(name="catalog.schema.table") 532 533 assert "DATABRICKS_CONFIG_PROFILE" not in os.environ 534 535 assert profile_values_during_calls == [ 536 ("get_dataset", "myprofile"), 537 ("create_dataset", "myprofile"), 538 ("delete_dataset", "myprofile"), 539 ] 540 541 542 def test_databricks_profile_env_var_overridden_and_restored(monkeypatch): 543 mock_dataset = mock.Mock() 544 profile_during_call = None 545 546 def mock_get_dataset(name): 547 nonlocal profile_during_call 548 profile_during_call = os.environ.get("DATABRICKS_CONFIG_PROFILE") 549 return mock_dataset 550 551 mock_agents_module = mock.Mock(get_dataset=mock_get_dataset) 552 monkeypatch.setitem(sys.modules, "databricks.agents.datasets", mock_agents_module) 553 monkeypatch.setattr("mlflow.genai.datasets.get_tracking_uri", lambda: "databricks://myprofile") 554 monkeypatch.setenv("DATABRICKS_CONFIG_PROFILE", "original_profile") 555 556 assert os.environ.get("DATABRICKS_CONFIG_PROFILE") == "original_profile" 557 558 get_dataset(name="catalog.schema.table") 559 560 assert os.environ.get("DATABRICKS_CONFIG_PROFILE") == "original_profile" 561 assert profile_during_call == "myprofile" 562 563 564 def test_databricks_dataset_merge_records_uses_profile(monkeypatch): 565 profile_during_merge = None 566 profile_during_to_df = None 567 568 mock_inner_dataset = mock.Mock() 569 mock_inner_dataset.digest = "test_digest" 570 mock_inner_dataset.name = "catalog.schema.table" 571 mock_inner_dataset.dataset_id = "dataset-123" 572 573 def mock_merge_records(records): 574 nonlocal profile_during_merge 575 profile_during_merge = os.environ.get("DATABRICKS_CONFIG_PROFILE") 576 return mock_inner_dataset 577 578 def mock_to_df(): 579 nonlocal profile_during_to_df 580 profile_during_to_df = os.environ.get("DATABRICKS_CONFIG_PROFILE") 581 import pandas as pd 582 583 return pd.DataFrame({"test": [1, 2, 3]}) 584 585 mock_inner_dataset.merge_records = mock_merge_records 586 mock_inner_dataset.to_df = mock_to_df 587 588 def mock_get_dataset(name): 589 return mock_inner_dataset 590 591 mock_agents_module = mock.Mock(get_dataset=mock_get_dataset) 592 monkeypatch.setitem(sys.modules, "databricks.agents.datasets", mock_agents_module) 593 monkeypatch.setattr("mlflow.genai.datasets.get_tracking_uri", lambda: "databricks://myprofile") 594 595 assert "DATABRICKS_CONFIG_PROFILE" not in os.environ 596 597 dataset = get_dataset(name="catalog.schema.table") 598 599 assert "DATABRICKS_CONFIG_PROFILE" not in os.environ 600 601 dataset.merge_records([{"inputs": {"q": "test"}}]) 602 assert profile_during_merge == "myprofile" 603 assert "DATABRICKS_CONFIG_PROFILE" not in os.environ 604 605 dataset.to_df() 606 assert profile_during_to_df == "myprofile" 607 assert "DATABRICKS_CONFIG_PROFILE" not in os.environ 608 609 610 def test_create_dataset_with_user_tag(experiments): 611 dataset = create_dataset( 612 name="test_user_attribution", 613 experiment_id=experiments[0], 614 tags={"environment": "test", MLFLOW_USER: "john_doe"}, 615 ) 616 617 assert dataset.name == "test_user_attribution" 618 assert dataset.tags[MLFLOW_USER] == "john_doe" 619 assert dataset.created_by == "john_doe" 620 621 dataset2 = create_dataset( 622 name="test_no_user", 623 experiment_id=experiments[0], 624 tags={"environment": "test"}, 625 ) 626 627 assert dataset2.name == "test_no_user" 628 assert isinstance(dataset2.tags[MLFLOW_USER], str) 629 assert dataset2.created_by == dataset2.tags[MLFLOW_USER] 630 631 632 def test_create_and_get_dataset(experiments): 633 dataset = create_dataset( 634 name="qa_evaluation_v1", 635 experiment_id=[experiments[0], experiments[1]], 636 tags={"source": "manual_curation", "environment": "test"}, 637 ) 638 639 assert dataset.name == "qa_evaluation_v1" 640 assert dataset.tags["source"] == "manual_curation" 641 assert dataset.tags["environment"] == "test" 642 assert len(dataset.experiment_ids) == 2 643 assert dataset.dataset_id is not None 644 645 retrieved = get_dataset(dataset_id=dataset.dataset_id) 646 647 assert retrieved.dataset_id == dataset.dataset_id 648 assert retrieved.name == dataset.name 649 assert retrieved.tags == dataset.tags 650 assert set(retrieved.experiment_ids) == {experiments[0], experiments[1]} 651 652 653 def test_create_dataset_minimal_params(client): 654 dataset = create_dataset(name="minimal_dataset") 655 656 assert dataset.name == "minimal_dataset" 657 assert "mlflow.user" not in dataset.tags or isinstance(dataset.tags.get("mlflow.user"), str) 658 assert dataset.experiment_ids == ["0"] 659 660 661 def test_active_record_pattern_merge_records(experiments): 662 dataset = create_dataset( 663 name="active_record_test", 664 experiment_id=experiments[0], 665 ) 666 667 records_batch1 = [ 668 { 669 "inputs": {"question": "What is MLflow?"}, 670 "outputs": { 671 "answer": "MLflow is an open source platform for managing the ML lifecycle", 672 "key1": "value1", 673 }, 674 "expectations": { 675 "answer": "MLflow is an open source platform", 676 "key2": "value2", 677 }, 678 "tags": {"difficulty": "easy"}, 679 }, 680 { 681 "inputs": {"question": "What is Python?"}, 682 "outputs": {"answer": "Python is a versatile programming language"}, 683 "expectations": {"answer": "Python is a programming language"}, 684 "tags": {"difficulty": "easy"}, 685 }, 686 ] 687 688 records_batch2 = [ 689 { 690 "inputs": {"question": "What is MLflow?"}, 691 "outputs": {"answer": "MLflow is a popular ML lifecycle platform"}, 692 "expectations": {"answer": "MLflow is an ML lifecycle platform"}, 693 "tags": {"category": "ml"}, 694 }, 695 { 696 "inputs": {"question": "What is Docker?"}, 697 "outputs": {"answer": "Docker is a popular containerization platform"}, 698 "expectations": {"answer": "Docker is a containerization platform"}, 699 "tags": {"difficulty": "medium"}, 700 }, 701 ] 702 703 dataset.merge_records(records_batch1) 704 705 df1 = dataset.to_df() 706 assert len(df1) == 2 707 708 mlflow_record = df1[df1["inputs"].apply(lambda x: x.get("question") == "What is MLflow?")].iloc[ 709 0 710 ] 711 assert mlflow_record["expectations"] == { 712 "answer": "MLflow is an open source platform", 713 "key2": "value2", 714 } 715 assert mlflow_record["outputs"] == { 716 "answer": "MLflow is an open source platform for managing the ML lifecycle", 717 "key1": "value1", 718 } 719 assert mlflow_record["tags"]["difficulty"] == "easy" 720 assert "category" not in mlflow_record["tags"] 721 722 dataset.merge_records(records_batch2) 723 724 df2 = dataset.to_df() 725 assert len(df2) == 3 726 727 mlflow_record_updated = df2[ 728 df2["inputs"].apply(lambda x: x.get("question") == "What is MLflow?") 729 ].iloc[0] 730 731 assert mlflow_record_updated["expectations"] == { 732 "answer": "MLflow is an ML lifecycle platform", 733 "key2": "value2", 734 } 735 assert mlflow_record_updated["outputs"] == { 736 "answer": "MLflow is a popular ML lifecycle platform" 737 } 738 assert mlflow_record_updated["tags"]["difficulty"] == "easy" 739 assert mlflow_record_updated["tags"]["category"] == "ml" 740 741 # Verify that the new Docker record also has outputs 742 docker_record = df2[df2["inputs"].apply(lambda x: x.get("question") == "What is Docker?")].iloc[ 743 0 744 ] 745 assert docker_record["outputs"]["answer"] == "Docker is a popular containerization platform" 746 assert docker_record["expectations"]["answer"] == "Docker is a containerization platform" 747 assert docker_record["tags"]["difficulty"] == "medium" 748 749 750 def test_dataset_with_dataframe_records(experiments): 751 dataset = create_dataset( 752 name="dataframe_test", 753 experiment_id=experiments[0], 754 tags={"source": "csv", "file": "test_data.csv"}, 755 ) 756 757 df = pd.DataFrame([ 758 { 759 "inputs": {"text": "The movie was amazing!", "model": "sentiment-v1"}, 760 "expectations": {"sentiment": "positive", "confidence": 0.95}, 761 "tags": {"source": "imdb"}, 762 }, 763 { 764 "inputs": {"text": "Terrible experience", "model": "sentiment-v1"}, 765 "expectations": {"sentiment": "negative", "confidence": 0.88}, 766 "tags": {"source": "yelp"}, 767 }, 768 ]) 769 770 dataset.merge_records(df) 771 772 result_df = dataset.to_df() 773 assert len(result_df) == 2 774 assert all(col in result_df.columns for col in ["inputs", "expectations", "tags"]) 775 776 # Check that all expected records are present (order-agnostic) 777 texts = {record["inputs"]["text"] for _, record in result_df.iterrows()} 778 expected_texts = {"The movie was amazing!", "Terrible experience"} 779 assert texts == expected_texts 780 781 sentiments = {record["expectations"]["sentiment"] for _, record in result_df.iterrows()} 782 expected_sentiments = {"positive", "negative"} 783 assert sentiments == expected_sentiments 784 785 786 def test_search_datasets(experiments): 787 for i in range(5): 788 create_dataset( 789 name=f"search_test_{i}", 790 experiment_id=[experiments[i % len(experiments)]], 791 tags={"type": "human" if i % 2 == 0 else "trace", "index": str(i)}, 792 ) 793 794 all_results = search_datasets() 795 assert len(all_results) == 5 796 797 exp0_results = search_datasets(experiment_ids=experiments[0]) 798 assert len(exp0_results) == 2 799 800 human_results = search_datasets(filter_string="name LIKE 'search_test_%'") 801 assert len(human_results) == 5 802 803 limited_results = search_datasets(max_results=2) 804 assert len(limited_results) == 2 805 806 more_results = search_datasets(max_results=4) 807 assert len(more_results) == 4 808 809 810 def test_delete_dataset(experiments): 811 dataset = create_dataset( 812 name="to_be_deleted", 813 experiment_id=[experiments[0], experiments[1]], 814 tags={"env": "test", "version": "1.0"}, 815 ) 816 dataset_id = dataset.dataset_id 817 818 dataset.merge_records([{"inputs": {"q": "test"}, "expectations": {"a": "answer"}}]) 819 820 retrieved = get_dataset(dataset_id=dataset_id) 821 assert retrieved is not None 822 assert len(retrieved.to_df()) == 1 823 824 delete_dataset(dataset_id=dataset_id) 825 826 with pytest.raises(MlflowException, match="Could not find|not found"): 827 get_dataset(dataset_id=dataset_id) 828 829 search_results = search_datasets(experiment_ids=[experiments[0], experiments[1]]) 830 found_ids = [d.dataset_id for d in search_results] 831 assert dataset_id not in found_ids 832 833 834 def test_dataset_lifecycle_workflow(experiments): 835 dataset = create_dataset( 836 name="qa_eval_prod_v1", 837 experiment_id=[experiments[0], experiments[1]], 838 tags={"source": "qa_team_annotations", "team": "qa", "env": "prod"}, 839 ) 840 841 initial_cases = [ 842 { 843 "inputs": {"question": "What is the capital of France?"}, 844 "expectations": {"answer": "Paris", "confidence": "high"}, 845 "tags": {"category": "geography", "difficulty": "easy"}, 846 }, 847 { 848 "inputs": {"question": "Explain quantum computing"}, 849 "expectations": {"answer": "Quantum computing uses quantum mechanics principles"}, 850 "tags": {"category": "science", "difficulty": "hard"}, 851 }, 852 ] 853 dataset.merge_records(initial_cases) 854 855 dataset_id = dataset.dataset_id 856 retrieved = get_dataset(dataset_id=dataset_id) 857 df = retrieved.to_df() 858 assert len(df) == 2 859 860 additional_cases = [ 861 { 862 "inputs": {"question": "What is 2+2?"}, 863 "expectations": {"answer": "4", "confidence": "high"}, 864 "tags": {"category": "math", "difficulty": "easy"}, 865 }, 866 ] 867 retrieved.merge_records(additional_cases) 868 869 found = search_datasets( 870 experiment_ids=experiments[0], 871 filter_string="name LIKE 'qa_eval%'", 872 ) 873 assert len(found) == 1 874 assert found[0].dataset_id == dataset_id 875 876 final_dataset = get_dataset(dataset_id=dataset_id) 877 final_df = final_dataset.to_df() 878 assert len(final_df) == 3 879 880 categories = set() 881 for _, row in final_df.iterrows(): 882 if row["tags"] and "category" in row["tags"]: 883 categories.add(row["tags"]["category"]) 884 assert categories == {"geography", "science", "math"} 885 886 887 def test_error_handling_filestore_backend(tmp_path): 888 pytest.skip("FileStore is no longer supported.") 889 file_uri = f"file://{tmp_path}" 890 mlflow.set_tracking_uri(file_uri) 891 892 with pytest.raises(MlflowException, match="not supported with FileStore") as exc: 893 create_dataset(name="test") 894 assert exc.value.error_code == "FEATURE_DISABLED" 895 896 with pytest.raises(MlflowException, match="not supported with FileStore") as exc: 897 get_dataset(dataset_id="test_id") 898 assert exc.value.error_code == "FEATURE_DISABLED" 899 900 with pytest.raises(MlflowException, match="not supported with FileStore") as exc: 901 search_datasets() 902 assert exc.value.error_code == "FEATURE_DISABLED" 903 904 with pytest.raises(MlflowException, match="not supported with FileStore") as exc: 905 delete_dataset(dataset_id="test_id") 906 assert exc.value.error_code == "FEATURE_DISABLED" 907 908 909 def test_single_experiment_id_handling(experiments): 910 dataset = create_dataset( 911 name="single_exp_test", 912 experiment_id=experiments[0], 913 ) 914 915 assert isinstance(dataset.experiment_ids, list) 916 assert dataset.experiment_ids == [experiments[0]] 917 918 results = search_datasets(experiment_ids=experiments[0]) 919 found_ids = [d.dataset_id for d in results] 920 assert dataset.dataset_id in found_ids 921 922 923 def test_trace_to_evaluation_dataset_integration(experiments): 924 trace_inputs = [ 925 {"question": "What is MLflow?", "context": "ML platforms"}, 926 {"question": "What is Python?", "context": "programming"}, 927 {"question": "What is MLflow?", "context": "ML platforms"}, 928 ] 929 930 created_trace_ids = [] 931 for i, inputs in enumerate(trace_inputs): 932 with mlflow.start_run(experiment_id=experiments[i % 2]): 933 with mlflow.start_span(name=f"qa_trace_{i}") as span: 934 span.set_inputs(inputs) 935 span.set_outputs({"answer": f"Answer for {inputs['question']}"}) 936 span.set_attributes({"model": "test-model", "temperature": "0.7"}) 937 trace_id = span.trace_id 938 created_trace_ids.append(trace_id) 939 940 mlflow.log_expectation( 941 trace_id=trace_id, 942 name="expected_answer", 943 value=f"Detailed answer for {inputs['question']}", 944 ) 945 mlflow.log_expectation( 946 trace_id=trace_id, 947 name="quality_score", 948 value=0.85 + i * 0.05, 949 ) 950 951 traces = mlflow.search_traces( 952 locations=[experiments[0], experiments[1]], 953 max_results=10, 954 return_type="list", 955 ) 956 assert len(traces) == 3 957 958 dataset = create_dataset( 959 name="trace_eval_dataset", 960 experiment_id=[experiments[0], experiments[1]], 961 tags={"source": "test_traces", "type": "trace_integration"}, 962 ) 963 964 dataset.merge_records(traces) 965 966 df = dataset.to_df() 967 assert len(df) == 2 968 969 for _, record in df.iterrows(): 970 assert "inputs" in record 971 assert "question" in record["inputs"] 972 assert "context" in record["inputs"] 973 assert record.get("source_type") == "TRACE" 974 assert record.get("source_id") is not None 975 976 mlflow_records = df[df["inputs"].apply(lambda x: x.get("question") == "What is MLflow?")] 977 assert len(mlflow_records) == 1 978 979 with mlflow.start_run(experiment_id=experiments[0]): 980 with mlflow.start_span(name="additional_trace") as span: 981 span.set_inputs({"question": "What is Docker?", "context": "containers"}) 982 span.set_outputs({"answer": "Docker is a containerization platform"}) 983 span.set_attributes({"model": "test-model"}) 984 985 all_traces = mlflow.search_traces( 986 locations=[experiments[0], experiments[1]], max_results=10, return_type="list" 987 ) 988 assert len(all_traces) == 4 989 990 new_trace = None 991 for trace in all_traces: 992 root_span = trace.data._get_root_span() if hasattr(trace, "data") else None 993 if root_span and root_span.inputs and root_span.inputs.get("question") == "What is Docker?": 994 new_trace = trace 995 break 996 997 assert new_trace is not None 998 999 dataset.merge_records([new_trace]) 1000 1001 final_df = dataset.to_df() 1002 assert len(final_df) == 3 1003 1004 retrieved = get_dataset(dataset_id=dataset.dataset_id) 1005 retrieved_df = retrieved.to_df() 1006 assert len(retrieved_df) == 3 1007 1008 delete_dataset(dataset_id=dataset.dataset_id) 1009 1010 with pytest.raises(MlflowException, match="Could not find|not found"): 1011 get_dataset(dataset_id=dataset.dataset_id) 1012 1013 search_results = search_datasets( 1014 experiment_ids=[experiments[0], experiments[1]], max_results=100 1015 ) 1016 found_dataset_ids = [d.dataset_id for d in search_results] 1017 assert dataset.dataset_id not in found_dataset_ids 1018 1019 all_datasets = search_datasets(max_results=100) 1020 all_dataset_ids = [d.dataset_id for d in all_datasets] 1021 assert dataset.dataset_id not in all_dataset_ids 1022 1023 1024 def test_search_traces_dataframe_to_dataset_integration(experiments): 1025 for i in range(3): 1026 with mlflow.start_run(experiment_id=experiments[0]): 1027 with mlflow.start_span(name=f"test_span_{i}") as span: 1028 span.set_inputs({"question": f"Question {i}?", "temperature": 0.7}) 1029 span.set_outputs({"answer": f"Answer {i}"}) 1030 1031 mlflow.log_expectation( 1032 trace_id=span.trace_id, 1033 name="expected_answer", 1034 value=f"Expected answer {i}", 1035 ) 1036 mlflow.log_expectation( 1037 trace_id=span.trace_id, 1038 name="min_score", 1039 value=0.8, 1040 ) 1041 1042 traces_df = mlflow.search_traces( 1043 locations=[experiments[0]], 1044 ) 1045 1046 assert "trace" in traces_df.columns 1047 assert "assessments" in traces_df.columns 1048 assert len(traces_df) == 3 1049 1050 dataset = create_dataset( 1051 name="traces_dataframe_dataset", 1052 experiment_id=experiments[0], 1053 tags={"source": "search_traces", "format": "dataframe"}, 1054 ) 1055 1056 dataset.merge_records(traces_df) 1057 1058 result_df = dataset.to_df() 1059 assert len(result_df) == 3 1060 1061 for idx, row in result_df.iterrows(): 1062 assert "inputs" in row 1063 assert "expectations" in row 1064 assert "source_type" in row 1065 assert row["source_type"] == "TRACE" 1066 1067 assert "question" in row["inputs"] 1068 question_text = row["inputs"]["question"] 1069 assert question_text.startswith("Question ") 1070 assert question_text.endswith("?") 1071 question_num = int(question_text.replace("Question ", "").replace("?", "")) 1072 assert 0 <= question_num <= 2 1073 1074 assert "expected_answer" in row["expectations"] 1075 assert f"Expected answer {question_num}" == row["expectations"]["expected_answer"] 1076 assert "min_score" in row["expectations"] 1077 assert row["expectations"]["min_score"] == 0.8 1078 1079 1080 def test_trace_to_dataset_with_assessments(client, experiment): 1081 trace_data = [ 1082 { 1083 "inputs": {"question": "What is MLflow?", "context": "ML platforms"}, 1084 "outputs": {"answer": "MLflow is an open source platform for ML lifecycle"}, 1085 "expectations": { 1086 "correctness": True, 1087 "completeness": 0.8, 1088 }, 1089 }, 1090 { 1091 "inputs": { 1092 "question": "What is Python?", 1093 "context": "programming languages", 1094 }, 1095 "outputs": {"answer": "Python is a high-level programming language"}, 1096 "expectations": { 1097 "correctness": True, 1098 }, 1099 }, 1100 { 1101 "inputs": {"question": "What is Docker?", "context": "containerization"}, 1102 "outputs": {"answer": "Docker is a container platform"}, 1103 "expectations": {}, 1104 }, 1105 ] 1106 1107 created_traces = [] 1108 for i, data in enumerate(trace_data): 1109 with mlflow.start_run(experiment_id=experiment): 1110 with mlflow.start_span(name=f"qa_trace_{i}") as span: 1111 span.set_inputs(data["inputs"]) 1112 span.set_outputs(data["outputs"]) 1113 span.set_attributes({"model": "test-model", "temperature": 0.7}) 1114 trace_id = span.trace_id 1115 1116 for name, value in data["expectations"].items(): 1117 mlflow.log_expectation( 1118 trace_id=trace_id, 1119 name=name, 1120 value=value, 1121 span_id=span.span_id, 1122 ) 1123 1124 trace = client.get_trace(trace_id) 1125 created_traces.append(trace) 1126 1127 dataset = create_dataset( 1128 name="trace_assessment_dataset", 1129 experiment_id=[experiment], 1130 tags={"source": "trace_integration_test", "version": "1.0"}, 1131 ) 1132 1133 dataset.merge_records(created_traces) 1134 1135 df = dataset.to_df() 1136 assert len(df) == 3 1137 1138 mlflow_record = df[df["inputs"].apply(lambda x: x.get("question") == "What is MLflow?")].iloc[0] 1139 assert mlflow_record["inputs"]["question"] == "What is MLflow?" 1140 assert mlflow_record["inputs"]["context"] == "ML platforms" 1141 1142 assert "expectations" in mlflow_record 1143 assert mlflow_record["expectations"]["correctness"] is True 1144 assert mlflow_record["expectations"]["completeness"] == 0.8 1145 1146 assert mlflow_record["source_type"] == "TRACE" 1147 assert mlflow_record["source_id"] is not None 1148 1149 python_record = df[df["inputs"].apply(lambda x: x.get("question") == "What is Python?")].iloc[0] 1150 assert python_record["expectations"]["correctness"] is True 1151 assert len(python_record["expectations"]) == 1 1152 1153 docker_record = df[df["inputs"].apply(lambda x: x.get("question") == "What is Docker?")].iloc[0] 1154 assert docker_record["expectations"] is None or docker_record["expectations"] == {} 1155 1156 retrieved = get_dataset(dataset_id=dataset.dataset_id) 1157 assert retrieved.tags["source"] == "trace_integration_test" 1158 assert retrieved.tags["version"] == "1.0" 1159 assert set(retrieved.experiment_ids) == {experiment} 1160 1161 1162 def test_trace_deduplication_with_assessments(client, experiment): 1163 trace_ids = [] 1164 for i in range(3): 1165 with mlflow.start_run(experiment_id=experiment): 1166 with mlflow.start_span(name=f"duplicate_trace_{i}") as span: 1167 span.set_inputs({"question": "What is AI?", "model": "gpt-4"}) 1168 span.set_outputs({"answer": f"AI is artificial intelligence (version {i})"}) 1169 trace_id = span.trace_id 1170 trace_ids.append(trace_id) 1171 1172 mlflow.log_expectation( 1173 trace_id=trace_id, 1174 name="quality", 1175 value=0.5 + i * 0.2, 1176 span_id=span.span_id, 1177 ) 1178 1179 traces = [client.get_trace(tid) for tid in trace_ids] 1180 1181 dataset = create_dataset( 1182 name="dedup_test", 1183 experiment_id=experiment, 1184 tags={"test": "deduplication"}, 1185 ) 1186 dataset.merge_records(traces) 1187 1188 df = dataset.to_df() 1189 assert len(df) == 1 1190 1191 record = df.iloc[0] 1192 assert record["inputs"]["question"] == "What is AI?" 1193 assert record["expectations"]["quality"] == 0.9 1194 assert record["source_id"] in trace_ids 1195 1196 1197 def test_mixed_record_types_with_traces(client, experiment): 1198 with mlflow.start_run(experiment_id=experiment): 1199 with mlflow.start_span(name="mixed_test_trace") as span: 1200 span.set_inputs({"question": "What is ML?", "context": "machine learning"}) 1201 span.set_outputs({"answer": "ML stands for Machine Learning"}) 1202 trace_id = span.trace_id 1203 1204 mlflow.log_expectation( 1205 trace_id=trace_id, 1206 name="accuracy", 1207 value=0.95, 1208 span_id=span.span_id, 1209 ) 1210 1211 trace = client.get_trace(trace_id) 1212 1213 dataset = create_dataset( 1214 name="mixed_records_test", 1215 experiment_id=experiment, 1216 tags={"type": "mixed", "test": "true"}, 1217 ) 1218 1219 manual_records = [ 1220 { 1221 "inputs": {"question": "What is AI?"}, 1222 "expectations": {"correctness": True}, 1223 "tags": {"source": "manual"}, 1224 }, 1225 { 1226 "inputs": {"question": "What is Python?"}, 1227 "expectations": {"correctness": True}, 1228 "tags": {"source": "manual"}, 1229 }, 1230 ] 1231 dataset.merge_records(manual_records) 1232 1233 df1 = dataset.to_df() 1234 assert len(df1) == 2 1235 1236 dataset.merge_records([trace]) 1237 1238 df2 = dataset.to_df() 1239 assert len(df2) == 3 1240 1241 ml_record = df2[df2["inputs"].apply(lambda x: x.get("question") == "What is ML?")].iloc[0] 1242 assert ml_record["expectations"]["accuracy"] == 0.95 1243 assert ml_record["source_type"] == "TRACE" 1244 1245 manual_questions = {"What is AI?", "What is Python?"} 1246 manual_records_df = df2[df2["inputs"].apply(lambda x: x.get("question") in manual_questions)] 1247 assert len(manual_records_df) == 2 1248 1249 for _, record in manual_records_df.iterrows(): 1250 assert record.get("source_type") != "TRACE" 1251 1252 1253 def test_trace_without_root_span_inputs(client, experiment): 1254 with mlflow.start_run(experiment_id=experiment): 1255 with mlflow.start_span(name="no_inputs_trace") as span: 1256 span.set_outputs({"result": "some output"}) 1257 trace_id = span.trace_id 1258 1259 trace = client.get_trace(trace_id) 1260 1261 dataset = create_dataset( 1262 name="no_inputs_test", 1263 experiment_id=experiment, 1264 ) 1265 1266 dataset.merge_records([trace]) 1267 1268 df = dataset.to_df() 1269 assert len(df) == 1 1270 assert df.iloc[0]["inputs"] == {} 1271 assert df.iloc[0]["expectations"] is None or df.iloc[0]["expectations"] == {} 1272 1273 1274 def test_error_handling_invalid_trace_types(client, experiment): 1275 dataset = create_dataset( 1276 name="error_test", 1277 experiment_id=experiment, 1278 ) 1279 1280 with mlflow.start_run(experiment_id=experiment): 1281 with mlflow.start_span(name="valid_trace") as span: 1282 span.set_inputs({"q": "test"}) 1283 trace_id = span.trace_id 1284 1285 valid_trace = client.get_trace(trace_id) 1286 1287 with pytest.raises(MlflowException, match="Mixed types in trace list"): 1288 dataset.merge_records([valid_trace, {"inputs": {"q": "dict record"}}]) 1289 1290 with pytest.raises(MlflowException, match="Mixed types in trace list"): 1291 dataset.merge_records([valid_trace, "not a trace"]) 1292 1293 1294 def test_trace_integration_end_to_end(client, experiment): 1295 traces_to_create = [ 1296 { 1297 "name": "successful_qa", 1298 "inputs": {"question": "What is the capital of France?", "language": "en"}, 1299 "outputs": {"answer": "Paris", "confidence": 0.99}, 1300 "expectations": {"correctness": True, "confidence_threshold": 0.8}, 1301 }, 1302 { 1303 "name": "incorrect_qa", 1304 "inputs": {"question": "What is 2+2?", "language": "en"}, 1305 "outputs": {"answer": "5", "confidence": 0.5}, 1306 "expectations": {"correctness": False}, 1307 }, 1308 { 1309 "name": "multilingual_qa", 1310 "inputs": {"question": "¿Cómo estás?", "language": "es"}, 1311 "outputs": {"answer": "I'm doing well, thank you!", "confidence": 0.9}, 1312 "expectations": {"language_match": False, "politeness": True}, 1313 }, 1314 ] 1315 1316 created_trace_ids = [] 1317 for trace_config in traces_to_create: 1318 with mlflow.start_run(experiment_id=experiment): 1319 with mlflow.start_span(name=trace_config["name"]) as span: 1320 span.set_inputs(trace_config["inputs"]) 1321 span.set_outputs(trace_config["outputs"]) 1322 span.set_attributes({ 1323 "model": "test-llm-v1", 1324 "temperature": 0.7, 1325 "max_tokens": 100, 1326 }) 1327 trace_id = span.trace_id 1328 created_trace_ids.append(trace_id) 1329 1330 for exp_name, exp_value in trace_config["expectations"].items(): 1331 mlflow.log_expectation( 1332 trace_id=trace_id, 1333 name=exp_name, 1334 value=exp_value, 1335 span_id=span.span_id, 1336 metadata={"trace_name": trace_config["name"]}, 1337 ) 1338 1339 dataset = create_dataset( 1340 name="comprehensive_trace_test", 1341 experiment_id=[experiment], 1342 tags={ 1343 "test_type": "end_to_end", 1344 "model": "test-llm-v1", 1345 "language": "multilingual", 1346 }, 1347 ) 1348 1349 traces = [client.get_trace(tid) for tid in created_trace_ids] 1350 dataset.merge_records(traces) 1351 1352 df = dataset.to_df() 1353 assert len(df) == 3 1354 1355 french_record = df[df["inputs"].apply(lambda x: "France" in str(x.get("question", "")))].iloc[0] 1356 assert french_record["expectations"]["correctness"] is True 1357 assert french_record["expectations"]["confidence_threshold"] == 0.8 1358 1359 math_record = df[df["inputs"].apply(lambda x: "2+2" in str(x.get("question", "")))].iloc[0] 1360 assert math_record["expectations"]["correctness"] is False 1361 1362 spanish_record = df[df["inputs"].apply(lambda x: x.get("language") == "es")].iloc[0] 1363 assert spanish_record["expectations"]["language_match"] is False 1364 assert spanish_record["expectations"]["politeness"] is True 1365 1366 retrieved_dataset = get_dataset(dataset_id=dataset.dataset_id) 1367 retrieved_df = retrieved_dataset.to_df() 1368 assert len(retrieved_df) == 3 1369 assert retrieved_dataset.tags["model"] == "test-llm-v1" 1370 1371 additional_records = [ 1372 { 1373 "inputs": {"question": "What is Python?", "language": "en"}, 1374 "expectations": {"technical_accuracy": True}, 1375 "tags": {"source": "manual_addition"}, 1376 } 1377 ] 1378 retrieved_dataset.merge_records(additional_records) 1379 1380 final_df = retrieved_dataset.to_df() 1381 assert len(final_df) == 4 1382 1383 trace_records = final_df[final_df["source_type"] == "TRACE"] 1384 assert len(trace_records) == 3 1385 1386 manual_records = final_df[final_df["source_type"] != "TRACE"] 1387 assert len(manual_records) == 1 1388 1389 1390 def test_dataset_pagination_transparency_large_records(experiments): 1391 dataset = create_dataset( 1392 name="test_pagination_transparency", 1393 experiment_id=experiments[0], 1394 tags={"test": "large_dataset"}, 1395 ) 1396 1397 large_records = [ 1398 { 1399 "inputs": {"question": f"Question {i}", "index": i}, 1400 "expectations": {"answer": f"Answer {i}", "score": i * 0.01}, 1401 } 1402 for i in range(150) 1403 ] 1404 1405 dataset.merge_records(large_records) 1406 1407 all_records = dataset._mlflow_dataset.records 1408 assert len(all_records) == 150 1409 1410 record_indices = {record.inputs["index"] for record in all_records} 1411 expected_indices = set(range(150)) 1412 assert record_indices == expected_indices 1413 1414 record_scores = {record.expectations["score"] for record in all_records} 1415 expected_scores = {i * 0.01 for i in range(150)} 1416 assert record_scores == expected_scores 1417 1418 df = dataset.to_df() 1419 assert len(df) == 150 1420 1421 df_indices = {row["index"] for row in df["inputs"]} 1422 assert df_indices == expected_indices 1423 1424 assert not hasattr(dataset, "page_token") 1425 assert not hasattr(dataset, "next_page_token") 1426 assert not hasattr(dataset, "max_results") 1427 1428 second_access = dataset._mlflow_dataset.records 1429 assert second_access is all_records 1430 1431 dataset._mlflow_dataset._records = None 1432 refreshed_records = dataset._mlflow_dataset.records 1433 assert len(refreshed_records) == 150 1434 1435 1436 def test_dataset_internal_pagination_with_mock(experiments): 1437 from mlflow.tracking._tracking_service.utils import _get_store 1438 1439 dataset = create_dataset( 1440 name="test_internal_pagination", 1441 experiment_id=experiments[0], 1442 tags={"test": "pagination_mock"}, 1443 ) 1444 1445 records = [ 1446 {"inputs": {"question": f"Q{i}", "id": i}, "expectations": {"answer": f"A{i}"}} 1447 for i in range(75) 1448 ] 1449 1450 dataset.merge_records(records) 1451 1452 dataset._mlflow_dataset._records = None 1453 1454 store = _get_store() 1455 with mock.patch.object( 1456 store, "_load_dataset_records", wraps=store._load_dataset_records 1457 ) as mock_load: 1458 accessed_records = dataset._mlflow_dataset.records 1459 1460 mock_load.assert_called_once_with(dataset.dataset_id, max_results=None) 1461 assert len(accessed_records) == 75 1462 1463 dataset._mlflow_dataset._records = None 1464 1465 with mock.patch.object( 1466 store, "_load_dataset_records", wraps=store._load_dataset_records 1467 ) as mock_load: 1468 df = dataset.to_df() 1469 1470 mock_load.assert_called_once_with(dataset.dataset_id, max_results=None) 1471 assert len(df) == 75 1472 1473 1474 def test_dataset_experiment_associations(experiments): 1475 from mlflow.genai.datasets import ( 1476 add_dataset_to_experiments, 1477 remove_dataset_from_experiments, 1478 ) 1479 1480 dataset = create_dataset( 1481 name="test_associations", 1482 experiment_id=experiments[0], 1483 tags={"test": "associations"}, 1484 ) 1485 1486 initial_exp_ids = dataset.experiment_ids 1487 assert experiments[0] in initial_exp_ids 1488 1489 updated = add_dataset_to_experiments( 1490 dataset_id=dataset.dataset_id, experiment_ids=[experiments[1], experiments[2]] 1491 ) 1492 assert experiments[0] in updated.experiment_ids 1493 assert experiments[1] in updated.experiment_ids 1494 assert experiments[2] in updated.experiment_ids 1495 assert len(updated.experiment_ids) == 3 1496 1497 result = add_dataset_to_experiments( 1498 dataset_id=dataset.dataset_id, experiment_ids=[experiments[1], experiments[2]] 1499 ) 1500 assert len(result.experiment_ids) == 3 1501 assert all(exp in result.experiment_ids for exp in experiments) 1502 1503 removed = remove_dataset_from_experiments( 1504 dataset_id=dataset.dataset_id, experiment_ids=[experiments[1], experiments[2]] 1505 ) 1506 assert experiments[1] not in removed.experiment_ids 1507 assert experiments[2] not in removed.experiment_ids 1508 assert experiments[0] in removed.experiment_ids 1509 assert len(removed.experiment_ids) == 1 1510 1511 with mock.patch("mlflow.store.tracking.sqlalchemy_store._logger.warning") as mock_warning: 1512 idempotent = remove_dataset_from_experiments( 1513 dataset_id=dataset.dataset_id, 1514 experiment_ids=[experiments[1], experiments[2]], 1515 ) 1516 assert mock_warning.call_count == 2 1517 assert "was not associated" in mock_warning.call_args_list[0][0][0] 1518 1519 assert len(idempotent.experiment_ids) == 1 1520 1521 1522 def test_dataset_associations_filestore_blocking(tmp_path): 1523 pytest.skip("FileStore is no longer supported.") 1524 from mlflow.genai.datasets import ( 1525 add_dataset_to_experiments, 1526 remove_dataset_from_experiments, 1527 ) 1528 1529 mlflow.set_tracking_uri(tmp_path.as_uri()) 1530 1531 with pytest.raises(NotImplementedError, match="not supported with FileStore"): 1532 add_dataset_to_experiments(dataset_id="d-test123", experiment_ids=["1", "2"]) 1533 1534 with pytest.raises(NotImplementedError, match="not supported with FileStore"): 1535 remove_dataset_from_experiments(dataset_id="d-test123", experiment_ids=["1"]) 1536 1537 1538 def test_evaluation_dataset_tags_crud_workflow(experiments): 1539 dataset = create_dataset( 1540 name="test_tags_crud", 1541 experiment_id=experiments[0], 1542 ) 1543 initial_tags = dataset.tags.copy() 1544 1545 set_dataset_tags( 1546 dataset_id=dataset.dataset_id, 1547 tags={ 1548 "team": "ml-platform", 1549 "project": "evaluation", 1550 "priority": "high", 1551 }, 1552 ) 1553 1554 dataset = get_dataset(dataset_id=dataset.dataset_id) 1555 expected_tags = initial_tags.copy() 1556 expected_tags.update({ 1557 "team": "ml-platform", 1558 "project": "evaluation", 1559 "priority": "high", 1560 }) 1561 assert dataset.tags == expected_tags 1562 1563 set_dataset_tags( 1564 dataset_id=dataset.dataset_id, 1565 tags={ 1566 "priority": "medium", 1567 "status": "active", 1568 }, 1569 ) 1570 1571 dataset = get_dataset(dataset_id=dataset.dataset_id) 1572 expected_tags = initial_tags.copy() 1573 expected_tags.update({ 1574 "team": "ml-platform", 1575 "project": "evaluation", 1576 "priority": "medium", 1577 "status": "active", 1578 }) 1579 assert dataset.tags == expected_tags 1580 1581 delete_dataset_tag( 1582 dataset_id=dataset.dataset_id, 1583 key="priority", 1584 ) 1585 1586 dataset = get_dataset(dataset_id=dataset.dataset_id) 1587 expected_tags = initial_tags.copy() 1588 expected_tags.update({ 1589 "team": "ml-platform", 1590 "project": "evaluation", 1591 "status": "active", 1592 }) 1593 assert dataset.tags == expected_tags 1594 1595 delete_dataset(dataset_id=dataset.dataset_id) 1596 1597 with pytest.raises(MlflowException, match="Could not find|not found"): 1598 get_dataset(dataset_id=dataset.dataset_id) 1599 1600 with pytest.raises(MlflowException, match="Could not find|not found"): 1601 set_dataset_tags( 1602 dataset_id=dataset.dataset_id, 1603 tags={"should": "fail"}, 1604 ) 1605 1606 delete_dataset_tag(dataset_id=dataset.dataset_id, key="status") 1607 1608 1609 def test_set_dataset_tags_databricks(mock_databricks_environment): 1610 with pytest.raises(NotImplementedError, match="tag operations are not available"): 1611 set_dataset_tags(dataset_id="test", tags={"key": "value"}) 1612 1613 1614 def test_delete_dataset_tag_databricks(mock_databricks_environment): 1615 with pytest.raises(NotImplementedError, match="tag operations are not available"): 1616 delete_dataset_tag(dataset_id="test", key="key") 1617 1618 1619 def test_dataset_schema_evolution_and_log_input(experiments): 1620 dataset = create_dataset( 1621 name="schema_evolution_test", 1622 experiment_id=[experiments[0]], 1623 tags={"test": "schema_evolution", "mlflow.user": "test_user"}, 1624 ) 1625 1626 stage1_records = [ 1627 { 1628 "inputs": {"prompt": "What is MLflow?"}, 1629 "expectations": {"response": "MLflow is a platform"}, 1630 } 1631 ] 1632 dataset.merge_records(stage1_records) 1633 1634 ds1 = get_dataset(dataset_id=dataset.dataset_id) 1635 schema1 = json.loads(ds1.schema) 1636 assert schema1 is not None 1637 assert "prompt" in schema1["inputs"] 1638 assert schema1["inputs"]["prompt"] == "string" 1639 assert len(schema1["inputs"]) == 1 1640 assert len(schema1["expectations"]) == 1 1641 1642 stage2_records = [ 1643 { 1644 "inputs": { 1645 "prompt": "Explain Python", 1646 "temperature": 0.7, 1647 "max_length": 500, 1648 "top_p": 0.95, 1649 }, 1650 "expectations": { 1651 "response": "Python is a programming language", 1652 "quality_score": 0.85, 1653 "token_count": 127, 1654 }, 1655 } 1656 ] 1657 dataset.merge_records(stage2_records) 1658 1659 ds2 = get_dataset(dataset_id=dataset.dataset_id) 1660 schema2 = json.loads(ds2.schema) 1661 assert "temperature" in schema2["inputs"] 1662 assert schema2["inputs"]["temperature"] == "float" 1663 assert "max_length" in schema2["inputs"] 1664 assert schema2["inputs"]["max_length"] == "integer" 1665 assert len(schema2["inputs"]) == 4 1666 assert len(schema2["expectations"]) == 3 1667 1668 stage3_records = [ 1669 { 1670 "inputs": { 1671 "prompt": "Complex query", 1672 "streaming": True, 1673 "stop_sequences": ["\n\n", "END"], 1674 "config": {"model": "gpt-4", "version": "1.0"}, 1675 }, 1676 "expectations": { 1677 "response": "Complex response", 1678 "is_valid": True, 1679 "citations": ["source1", "source2"], 1680 "metadata": {"confidence": 0.9}, 1681 }, 1682 } 1683 ] 1684 dataset.merge_records(stage3_records) 1685 1686 ds3 = get_dataset(dataset_id=dataset.dataset_id) 1687 schema3 = json.loads(ds3.schema) 1688 1689 assert schema3["inputs"]["streaming"] == "boolean" 1690 assert schema3["inputs"]["stop_sequences"] == "array" 1691 assert schema3["inputs"]["config"] == "object" 1692 assert schema3["expectations"]["is_valid"] == "boolean" 1693 assert schema3["expectations"]["citations"] == "array" 1694 assert schema3["expectations"]["metadata"] == "object" 1695 1696 assert "prompt" in schema3["inputs"] 1697 assert "temperature" in schema3["inputs"] 1698 assert "quality_score" in schema3["expectations"] 1699 1700 with mlflow.start_run(experiment_id=experiments[0]) as run: 1701 mlflow.log_input(dataset, context="evaluation") 1702 1703 mlflow.log_metrics({"accuracy": 0.92, "f1_score": 0.89}) 1704 1705 run_data = mlflow.get_run(run.info.run_id) 1706 assert run_data.inputs is not None 1707 assert run_data.inputs.dataset_inputs is not None 1708 assert len(run_data.inputs.dataset_inputs) > 0 1709 1710 dataset_input = run_data.inputs.dataset_inputs[0] 1711 assert dataset_input.dataset.name == "schema_evolution_test" 1712 assert dataset_input.dataset.source_type == "mlflow_evaluation_dataset" 1713 1714 tag_dict = {tag.key: tag.value for tag in dataset_input.tags} 1715 assert "mlflow.data.context" in tag_dict 1716 assert tag_dict["mlflow.data.context"] == "evaluation" 1717 1718 final_dataset = get_dataset(dataset_id=dataset.dataset_id) 1719 final_schema = json.loads(final_dataset.schema) 1720 1721 assert "inputs" in final_schema 1722 assert "expectations" in final_schema 1723 assert "version" in final_schema 1724 assert final_schema["version"] == "1.0" 1725 1726 profile = json.loads(final_dataset.profile) 1727 assert profile is not None 1728 assert profile["num_records"] == 3 1729 1730 consistency_records = [ 1731 { 1732 "inputs": {"prompt": "Another test", "temperature": 0.5, "max_length": 200}, 1733 "expectations": {"response": "Another response", "quality_score": 0.75}, 1734 } 1735 ] 1736 dataset.merge_records(consistency_records) 1737 1738 consistent_dataset = get_dataset(dataset_id=dataset.dataset_id) 1739 consistent_schema = json.loads(consistent_dataset.schema) 1740 1741 assert set(consistent_schema["inputs"].keys()) == set(final_schema["inputs"].keys()) 1742 assert set(consistent_schema["expectations"].keys()) == set(final_schema["expectations"].keys()) 1743 1744 consistent_profile = json.loads(consistent_dataset.profile) 1745 assert consistent_profile["num_records"] == 4 1746 delete_dataset_tag(dataset_id="test", key="key") 1747 1748 1749 def test_deprecated_parameter_substitution(experiment): 1750 with warnings.catch_warnings(record=True) as w: 1751 warnings.simplefilter("always") 1752 1753 dataset = create_dataset( 1754 uc_table_name="test_dataset_deprecated", 1755 experiment_id=experiment, 1756 tags={"test": "deprecated_parameter"}, 1757 ) 1758 1759 assert len(w) == 1 1760 assert issubclass(w[0].category, FutureWarning) 1761 assert "uc_table_name" in str(w[0].message) 1762 assert "deprecated" in str(w[0].message).lower() 1763 assert "name" in str(w[0].message) 1764 1765 assert dataset.name == "test_dataset_deprecated" 1766 assert dataset.tags["test"] == "deprecated_parameter" 1767 1768 with pytest.raises(ValueError, match="Cannot specify both.*uc_table_name.*and.*name"): 1769 create_dataset( 1770 uc_table_name="old_name", 1771 name="new_name", 1772 experiment_id=experiment, 1773 ) 1774 1775 with warnings.catch_warnings(record=True) as w: 1776 warnings.simplefilter("always") 1777 1778 with pytest.raises(ValueError, match="name.*only supported in Databricks"): 1779 delete_dataset(uc_table_name="test_dataset_deprecated") 1780 1781 assert len(w) == 1 1782 assert issubclass(w[0].category, FutureWarning) 1783 assert "uc_table_name" in str(w[0].message) 1784 1785 delete_dataset(dataset_id=dataset.dataset_id) 1786 1787 1788 def test_create_dataset_uses_active_experiment_when_not_specified(client): 1789 exp_id = mlflow.create_experiment("test_active_experiment") 1790 mlflow.set_experiment(experiment_id=exp_id) 1791 1792 dataset = create_dataset(name="test_with_active_exp") 1793 1794 assert dataset.experiment_ids == [exp_id] 1795 1796 from mlflow.tracking import fluent 1797 1798 fluent._active_experiment_id = None 1799 1800 1801 def test_create_dataset_with_no_active_experiment(client): 1802 from mlflow.tracking import fluent 1803 1804 fluent._active_experiment_id = None 1805 1806 dataset = create_dataset(name="test_no_active_exp") 1807 1808 assert dataset.experiment_ids == ["0"] 1809 1810 1811 def test_create_dataset_explicit_overrides_active_experiment(client): 1812 active_exp = mlflow.create_experiment("active_exp") 1813 explicit_exp = mlflow.create_experiment("explicit_exp") 1814 1815 mlflow.set_experiment(experiment_id=active_exp) 1816 1817 dataset = create_dataset(name="test_explicit_override", experiment_id=explicit_exp) 1818 1819 assert dataset.experiment_ids == [explicit_exp] 1820 1821 from mlflow.tracking import fluent 1822 1823 fluent._active_experiment_id = None 1824 1825 1826 def test_create_dataset_none_uses_active_experiment(client): 1827 exp_id = mlflow.create_experiment("test_none_experiment") 1828 mlflow.set_experiment(experiment_id=exp_id) 1829 1830 dataset = create_dataset(name="test_none_exp", experiment_id=None) 1831 1832 assert dataset.experiment_ids == [exp_id] 1833 1834 from mlflow.tracking import fluent 1835 1836 fluent._active_experiment_id = None 1837 1838 1839 def test_source_type_inference(): 1840 exp = mlflow.create_experiment("test_source_inference") 1841 dataset = create_dataset( 1842 name="test_source_inference", 1843 experiment_id=exp, 1844 tags={"test": "source_inference"}, 1845 ) 1846 1847 human_records = [ 1848 { 1849 "inputs": {"question": "What is MLflow?"}, 1850 "expectations": {"answer": "MLflow is an ML platform", "quality": 0.9}, 1851 }, 1852 { 1853 "inputs": {"question": "How to track experiments?"}, 1854 "expectations": {"answer": "Use mlflow.start_run()", "quality": 0.85}, 1855 }, 1856 ] 1857 dataset.merge_records(human_records) 1858 1859 df = dataset.to_df() 1860 human_sources = df[df["source_type"] == DatasetRecordSourceType.HUMAN.value] 1861 assert len(human_sources) == 2 1862 1863 code_records = [{"inputs": {"question": f"Generated question {i}"}} for i in range(3)] 1864 dataset.merge_records(code_records) 1865 1866 df = dataset.to_df() 1867 code_sources = df[df["source_type"] == DatasetRecordSourceType.CODE.value] 1868 assert len(code_sources) == 3 1869 1870 explicit_records = [ 1871 { 1872 "inputs": {"question": "Document-based question"}, 1873 "expectations": {"answer": "From document"}, 1874 "source": { 1875 "source_type": DatasetRecordSourceType.DOCUMENT.value, 1876 "source_data": {"source_id": "doc123", "page": 5}, 1877 }, 1878 } 1879 ] 1880 dataset.merge_records(explicit_records) 1881 1882 df = dataset.to_df() 1883 doc_sources = df[df["source_type"] == DatasetRecordSourceType.DOCUMENT.value] 1884 assert len(doc_sources) == 1 1885 assert doc_sources.iloc[0]["source_id"] == "doc123" 1886 1887 empty_exp_records = [{"inputs": {"question": "Has empty expectations"}, "expectations": {}}] 1888 dataset.merge_records(empty_exp_records) 1889 1890 df = dataset.to_df() 1891 last_record = df.iloc[-1] 1892 assert last_record["source_type"] not in [ 1893 DatasetRecordSourceType.HUMAN.value, 1894 DatasetRecordSourceType.CODE.value, 1895 ] 1896 1897 explicit_trace = [ 1898 { 1899 "inputs": {"question": "From trace"}, 1900 "source": { 1901 "source_type": DatasetRecordSourceType.TRACE.value, 1902 "source_data": {"trace_id": "trace123"}, 1903 }, 1904 } 1905 ] 1906 dataset.merge_records(explicit_trace) 1907 1908 df = dataset.to_df() 1909 trace_sources = df[df["source_type"] == DatasetRecordSourceType.TRACE.value] 1910 assert len(trace_sources) == 1, f"Expected 1 TRACE source, got {len(trace_sources)}" 1911 assert trace_sources.iloc[0]["source_id"] == "trace123" 1912 1913 source_counts = df["source_type"].value_counts() 1914 assert source_counts.get(DatasetRecordSourceType.HUMAN.value, 0) == 2 1915 assert source_counts.get(DatasetRecordSourceType.CODE.value, 0) == 3 1916 assert source_counts.get(DatasetRecordSourceType.DOCUMENT.value, 0) == 1 1917 assert source_counts.get(DatasetRecordSourceType.TRACE.value, 0) == 1 1918 1919 delete_dataset(dataset_id=dataset.dataset_id) 1920 1921 1922 def test_trace_source_type_detection(): 1923 exp = mlflow.create_experiment("test_trace_source_detection") 1924 1925 trace_ids = [] 1926 for i in range(3): 1927 with mlflow.start_run(experiment_id=exp): 1928 with mlflow.start_span(name=f"test_span_{i}") as span: 1929 span.set_inputs({"question": f"Question {i}", "context": f"Context {i}"}) 1930 span.set_outputs({"answer": f"Answer {i}"}) 1931 trace_ids.append(span.trace_id) 1932 1933 if i < 2: 1934 mlflow.log_expectation( 1935 trace_id=span.trace_id, 1936 name="quality", 1937 value=0.8 + i * 0.05, 1938 span_id=span.span_id, 1939 ) 1940 1941 dataset = create_dataset( 1942 name="test_trace_sources", 1943 experiment_id=exp, 1944 tags={"test": "trace_source_detection"}, 1945 ) 1946 1947 client = mlflow.MlflowClient() 1948 traces = [client.get_trace(tid) for tid in trace_ids] 1949 dataset.merge_records(traces) 1950 1951 df = dataset.to_df() 1952 trace_sources = df[df["source_type"] == DatasetRecordSourceType.TRACE.value] 1953 assert len(trace_sources) == 3 1954 1955 for trace_id in trace_ids: 1956 matching_records = df[df["source_id"] == trace_id] 1957 assert len(matching_records) == 1 1958 1959 dataset2 = create_dataset( 1960 name="test_trace_sources_df", 1961 experiment_id=exp, 1962 tags={"test": "trace_source_df"}, 1963 ) 1964 1965 traces_df = mlflow.search_traces(locations=[exp]) 1966 assert not traces_df.empty 1967 dataset2.merge_records(traces_df) 1968 1969 df2 = dataset2.to_df() 1970 trace_sources2 = df2[df2["source_type"] == DatasetRecordSourceType.TRACE.value] 1971 assert len(trace_sources2) == len(traces_df) 1972 1973 dataset3 = create_dataset( 1974 name="test_trace_sources_list", 1975 experiment_id=exp, 1976 tags={"test": "trace_source_list"}, 1977 ) 1978 1979 traces_list = mlflow.search_traces(locations=[exp], return_type="list") 1980 assert len(traces_list) > 0 1981 dataset3.merge_records(traces_list) 1982 1983 df3 = dataset3.to_df() 1984 trace_sources3 = df3[df3["source_type"] == DatasetRecordSourceType.TRACE.value] 1985 assert len(trace_sources3) == len(traces_list) 1986 1987 df_with_expectations = df[df["expectations"].apply(lambda x: bool(x) and len(x) > 0)] 1988 assert len(df_with_expectations) == 2 1989 1990 delete_dataset(dataset_id=dataset.dataset_id) 1991 delete_dataset(dataset_id=dataset2.dataset_id) 1992 delete_dataset(dataset_id=dataset3.dataset_id) 1993 1994 1995 def test_create_dataset_empty_list_stays_empty(client): 1996 exp_id = mlflow.create_experiment("test_empty_list") 1997 mlflow.set_experiment(experiment_id=exp_id) 1998 1999 dataset = create_dataset(name="test_empty_list", experiment_id=[]) 2000 2001 assert dataset.experiment_ids == [] 2002 2003 from mlflow.tracking import fluent 2004 2005 fluent._active_experiment_id = None 2006 2007 2008 def test_search_datasets_filter_string_edge_cases(client): 2009 exp_id = mlflow.create_experiment("test_filter_edge_cases") 2010 2011 dataset = create_dataset(name="test_dataset", experiment_id=exp_id, tags={"test": "value"}) 2012 2013 with mock.patch("mlflow.tracking.client.MlflowClient.search_datasets") as mock_search: 2014 mock_search.return_value = mock.MagicMock(token=None, items=[dataset]) 2015 2016 search_datasets(experiment_ids=exp_id, filter_string=None) 2017 call_args = mock_search.call_args 2018 filter_arg = call_args.kwargs.get("filter_string") 2019 assert "created_time >=" in filter_arg 2020 2021 mock_search.reset_mock() 2022 2023 search_datasets(experiment_ids=exp_id, filter_string=[]) 2024 call_args = mock_search.call_args 2025 filter_arg = call_args.kwargs.get("filter_string") 2026 assert "created_time >=" in filter_arg 2027 2028 mock_search.reset_mock() 2029 2030 search_datasets(experiment_ids=exp_id, filter_string="") 2031 call_args = mock_search.call_args 2032 filter_arg = call_args.kwargs.get("filter_string") 2033 assert "created_time >=" in filter_arg 2034 2035 mock_search.reset_mock() 2036 2037 search_datasets(experiment_ids=exp_id, filter_string='name = "test"') 2038 call_args = mock_search.call_args 2039 filter_arg = call_args.kwargs.get("filter_string") 2040 assert filter_arg == 'name = "test"' 2041 2042 2043 def test_wrapper_type_is_actually_returned_not_entity(experiments): 2044 dataset = create_dataset( 2045 name="test_wrapper", 2046 experiment_id=experiments[0], 2047 tags={"test": "wrapper_check"}, 2048 ) 2049 2050 assert isinstance(dataset, WrapperEvaluationDataset) 2051 assert not isinstance(dataset, EntityEvaluationDataset) 2052 assert hasattr(dataset, "_mlflow_dataset") 2053 assert dataset._mlflow_dataset is not None 2054 assert isinstance(dataset._mlflow_dataset, EntityEvaluationDataset) 2055 2056 2057 def test_wrapper_delegates_all_properties_correctly(experiments): 2058 dataset = create_dataset( 2059 name="test_delegation", 2060 experiment_id=experiments[0], 2061 tags={"env": "test", "version": "1.0"}, 2062 ) 2063 2064 assert dataset.name == "test_delegation" 2065 assert dataset.dataset_id.startswith("d-") 2066 assert dataset.tags["env"] == "test" 2067 assert dataset.tags["version"] == "1.0" 2068 assert experiments[0] in dataset.experiment_ids 2069 assert dataset.created_time > 0 2070 assert dataset.last_update_time > 0 2071 assert dataset.digest is not None 2072 assert hasattr(dataset, "source") 2073 assert dataset.source._get_source_type() == "mlflow_evaluation_dataset" 2074 2075 2076 def test_get_and_search_return_wrapper_not_entity(experiments): 2077 created = create_dataset( 2078 name="test_get_wrapper", 2079 experiment_id=experiments[0], 2080 tags={"test": "get"}, 2081 ) 2082 2083 retrieved = get_dataset(dataset_id=created.dataset_id) 2084 assert isinstance(retrieved, WrapperEvaluationDataset) 2085 assert not isinstance(retrieved, EntityEvaluationDataset) 2086 assert retrieved.dataset_id == created.dataset_id 2087 assert retrieved.name == created.name 2088 2089 results = search_datasets( 2090 experiment_ids=experiments[0], 2091 filter_string="name = 'test_get_wrapper'", 2092 ) 2093 assert len(results) == 1 2094 assert isinstance(results[0], WrapperEvaluationDataset) 2095 assert not isinstance(results[0], EntityEvaluationDataset) 2096 2097 2098 def test_wrapper_vs_direct_client_usage(experiments): 2099 client = MlflowClient() 2100 2101 entity_dataset = client.create_dataset( 2102 name="test_client_direct", 2103 experiment_id=experiments[0], 2104 tags={"direct": "client"}, 2105 ) 2106 assert isinstance(entity_dataset, EntityEvaluationDataset) 2107 assert not isinstance(entity_dataset, WrapperEvaluationDataset) 2108 2109 wrapped_dataset = create_dataset( 2110 name="test_wrapped", 2111 experiment_id=experiments[0], 2112 tags={"wrapped": "fluent"}, 2113 ) 2114 assert isinstance(wrapped_dataset, WrapperEvaluationDataset) 2115 assert not isinstance(wrapped_dataset, EntityEvaluationDataset) 2116 assert wrapped_dataset._mlflow_dataset is not None 2117 2118 wrapped_from_entity = WrapperEvaluationDataset(entity_dataset) 2119 assert wrapped_from_entity == entity_dataset 2120 2121 2122 def test_wrapper_works_with_mlflow_log_input_integration(experiments): 2123 dataset = create_dataset( 2124 name="test_log_input", 2125 experiment_id=experiments[0], 2126 ) 2127 2128 records = [ 2129 { 2130 "inputs": {"question": "Test question"}, 2131 "expectations": {"answer": "Test answer"}, 2132 } 2133 ] 2134 dataset.merge_records(records) 2135 2136 with mlflow.start_run(experiment_id=experiments[0]) as run: 2137 mlflow.log_input(dataset, context="evaluation") 2138 2139 run_data = mlflow.get_run(run.info.run_id) 2140 assert len(run_data.inputs.dataset_inputs) == 1 2141 dataset_input = run_data.inputs.dataset_inputs[0] 2142 assert dataset_input.dataset.name == "test_log_input" 2143 assert dataset_input.dataset.digest == dataset.digest 2144 2145 2146 def test_wrapper_isinstance_checks_for_dataset_interfaces(experiments): 2147 dataset = create_dataset( 2148 name="test_isinstance", 2149 experiment_id=experiments[0], 2150 ) 2151 2152 assert isinstance(dataset, Dataset) 2153 assert isinstance(dataset, PyFuncConvertibleDatasetMixin) 2154 assert isinstance(dataset, WrapperEvaluationDataset) 2155 assert not isinstance(dataset, EntityEvaluationDataset) 2156 assert isinstance(dataset, (WrapperEvaluationDataset, EntityEvaluationDataset)) 2157 2158 2159 @pytest.mark.parametrize( 2160 "records", 2161 [ 2162 [ 2163 {"inputs": {"persona": "Student", "goal": "Find articles"}}, 2164 { 2165 "inputs": { 2166 "persona": "Researcher", 2167 "goal": "Review", 2168 "context": {"dept": "CS"}, 2169 } 2170 }, 2171 {"inputs": {"goal": "Single goal"}, "expectations": {"output": "expected"}}, 2172 ], 2173 [ 2174 {"inputs": {"goal": "Learn ML", "simulation_guidelines": "Be concise"}}, 2175 { 2176 "inputs": { 2177 "persona": "Engineer", 2178 "goal": "Debug", 2179 "simulation_guidelines": "Focus on logs", 2180 } 2181 }, 2182 { 2183 "inputs": { 2184 "persona": "Student", 2185 "goal": "Study", 2186 "context": {"course": "CS101"}, 2187 "simulation_guidelines": "Ask clarifying questions", 2188 } 2189 }, 2190 ], 2191 ], 2192 ) 2193 def test_multiturn_valid_formats(experiments, records): 2194 dataset = create_dataset(name="multiturn_test", experiment_id=experiments[0]) 2195 dataset.merge_records(records) 2196 df = dataset.to_df() 2197 2198 assert len(df) == 3 2199 for _, row in df.iterrows(): 2200 assert any( 2201 key in row["inputs"] for key in ["persona", "goal", "context", "simulation_guidelines"] 2202 ) 2203 2204 2205 @pytest.mark.parametrize( 2206 ("records", "error_pattern"), 2207 [ 2208 # Top-level session fields 2209 ( 2210 [{"persona": "Student", "goal": "Find articles", "custom_field": "value"}], 2211 "Each record must have an 'inputs' field", 2212 ), 2213 # Mixed fields in inputs 2214 ( 2215 [ 2216 { 2217 "inputs": { 2218 "persona": "Student", 2219 "goal": "Find", 2220 "custom_field": "value", 2221 } 2222 } 2223 ], 2224 "Invalid input schema.*cannot mix session fields", 2225 ), 2226 # Inconsistent batch schema 2227 ( 2228 [ 2229 {"inputs": {"persona": "Student", "goal": "Find articles"}}, 2230 {"inputs": {"question": "What is MLflow?"}}, 2231 ], 2232 "must use the same granularity.*Found", 2233 ), 2234 # Empty inputs in batch with session records 2235 ( 2236 [ 2237 {"inputs": {"goal": "Find articles"}}, 2238 {"inputs": {}}, 2239 ], 2240 "Empty inputs are not allowed for session records.*'goal' field is required", 2241 ), 2242 ], 2243 ) 2244 def test_multiturn_validation_errors(experiments, records, error_pattern): 2245 dataset = create_dataset(name="multiturn_error_test", experiment_id=experiments[0]) 2246 with pytest.raises(MlflowException, match=error_pattern): 2247 dataset.merge_records(records) 2248 2249 2250 @pytest.mark.parametrize( 2251 ("existing_records", "new_records"), 2252 [ 2253 # Multiturn then custom 2254 ( 2255 [{"inputs": {"persona": "Student", "goal": "Find articles"}}], 2256 [{"inputs": {"question": "What is MLflow?", "model": "gpt-4"}}], 2257 ), 2258 # Custom then multiturn 2259 ( 2260 [{"inputs": {"question": "What is MLflow?", "model": "gpt-4"}}], 2261 [{"inputs": {"persona": "Student", "goal": "Find articles"}}], 2262 ), 2263 ], 2264 ) 2265 def test_multiturn_schema_compatibility(experiments, existing_records, new_records): 2266 dataset = create_dataset(name="multiturn_compat_test", experiment_id=experiments[0]) 2267 dataset.merge_records(existing_records) 2268 2269 with pytest.raises(MlflowException, match="Cannot mix granularities"): 2270 dataset.merge_records(new_records) 2271 2272 2273 def test_multiturn_with_expectations_and_tags(experiments): 2274 dataset = create_dataset(name="multiturn_full_test", experiment_id=experiments[0]) 2275 records = [ 2276 { 2277 "inputs": { 2278 "persona": "Graduate Student", 2279 "goal": "Find peer-reviewed articles on machine learning", 2280 "context": {"user_id": "U0001", "department": "CS"}, 2281 "simulation_guidelines": "Be thorough and cite sources", 2282 }, 2283 "expectations": {"expected_output": "relevant articles", "quality": "high"}, 2284 "tags": {"difficulty": "medium"}, 2285 }, 2286 { 2287 "inputs": { 2288 "persona": "Librarian", 2289 "goal": "Help with inter-library loan", 2290 }, 2291 "expectations": {"expected_output": "loan information"}, 2292 }, 2293 ] 2294 2295 dataset.merge_records(records) 2296 2297 df = dataset.to_df() 2298 assert len(df) == 2 2299 2300 grad_record = df[df["inputs"].apply(lambda x: x.get("persona") == "Graduate Student")].iloc[0] 2301 assert grad_record["expectations"]["expected_output"] == "relevant articles" 2302 assert grad_record["expectations"]["quality"] == "high" 2303 assert grad_record["tags"]["difficulty"] == "medium" 2304 assert grad_record["inputs"]["context"] == {"user_id": "U0001", "department": "CS"} 2305 assert grad_record["inputs"]["simulation_guidelines"] == "Be thorough and cite sources"