test_utils.py
1 import json 2 import sys 3 from typing import Any, Literal 4 from unittest.mock import MagicMock, Mock, patch 5 6 import pandas as pd 7 import pytest 8 9 import mlflow 10 from mlflow.entities.assessment_source import AssessmentSource 11 from mlflow.entities.span import SpanType 12 from mlflow.entities.trace import Trace 13 from mlflow.exceptions import MlflowException 14 from mlflow.genai import scorer 15 from mlflow.genai.datasets import EvaluationDataset, create_dataset 16 from mlflow.genai.evaluation.utils import ( 17 _convert_scorer_to_legacy_metric, 18 _convert_to_eval_set, 19 _deserialize_trace_column_if_needed, 20 validate_tags, 21 ) 22 from mlflow.genai.scorers.builtin_scorers import RelevanceToQuery 23 from mlflow.utils.spark_utils import is_spark_connect_mode 24 25 from tests.genai.conftest import databricks_only 26 27 28 @pytest.fixture(scope="module") 29 def spark(): 30 # databricks-agents installs databricks-connect 31 if is_spark_connect_mode(): 32 pytest.skip("Local Spark Session is not supported when databricks-connect is installed.") 33 34 from pyspark.sql import SparkSession 35 36 with SparkSession.builder.getOrCreate() as spark: 37 yield spark 38 39 40 def count_rows(data: Any) -> int: 41 try: 42 from mlflow.utils.spark_utils import get_spark_dataframe_type 43 44 if isinstance(data, get_spark_dataframe_type()): 45 return data.count() 46 except Exception: 47 pass 48 49 if isinstance(data, EvaluationDataset): 50 data = data.to_df() 51 52 return len(data) 53 54 55 @pytest.fixture 56 def sample_dict_data_single(): 57 return [ 58 { 59 "inputs": {"question": "What is Spark?"}, 60 "outputs": "actual response for first question", 61 "expectations": {"expected_response": "expected response for first question"}, 62 "tags": {"sample_tag": "value"}, 63 }, 64 ] 65 66 67 @pytest.fixture 68 def sample_dict_data_multiple(): 69 return [ 70 { 71 "inputs": {"question": "What is Spark?"}, 72 "outputs": "actual response for first question", 73 "expectations": {"expected_response": "expected response for first question"}, 74 "tags": {"category": "spark"}, 75 }, 76 { 77 "inputs": {"question": "How can you minimize data shuffling in Spark?"}, 78 "outputs": "actual response for second question", 79 "expectations": {"expected_response": "expected response for second question"}, 80 "tags": {"category": "spark", "topic": "optimization"}, 81 }, 82 # Some records might not have expectations or tags 83 { 84 "inputs": {"question": "What is MLflow?"}, 85 "outputs": "actual response for third question", 86 "expectations": {}, 87 "tags": {}, 88 }, 89 ] 90 91 92 @pytest.fixture 93 def sample_dict_data_multiple_with_custom_expectations(): 94 return [ 95 { 96 "inputs": {"question": "What is Spark?"}, 97 "outputs": "actual response for first question", 98 "expectations": { 99 "expected_response": "expected response for first question", 100 "my_custom_expectation": "custom expectation for the first question", 101 }, 102 }, 103 { 104 "inputs": {"question": "How can you minimize data shuffling in Spark?"}, 105 "outputs": "actual response for second question", 106 "expectations": { 107 "expected_response": "expected response for second question", 108 "my_custom_expectation": "custom expectation for the second question", 109 }, 110 }, 111 # Some records might not have all expectations 112 { 113 "inputs": {"question": "What is MLflow?"}, 114 "outputs": "actual response for third question", 115 "expectations": { 116 "my_custom_expectation": "custom expectation for the third question", 117 }, 118 }, 119 ] 120 121 122 @pytest.fixture 123 def sample_pd_data(sample_dict_data_multiple): 124 """Returns a pandas DataFrame with sample data""" 125 return pd.DataFrame(sample_dict_data_multiple) 126 127 128 @pytest.fixture 129 def sample_spark_data(sample_pd_data, spark): 130 """Convert pandas DataFrame to PySpark DataFrame""" 131 return spark.createDataFrame(sample_pd_data) 132 133 134 @pytest.fixture 135 def sample_spark_data_with_string_columns(sample_pd_data, spark): 136 # Cast inputs and expectations columns to string 137 df = sample_pd_data.copy() 138 df["inputs"] = df["inputs"].apply(json.dumps) 139 df["expectations"] = df["expectations"].apply(json.dumps) 140 return spark.createDataFrame(df) 141 142 143 @pytest.fixture 144 def sample_evaluation_dataset(sample_dict_data_single): 145 dataset = create_dataset("test") 146 dataset.merge_records(sample_dict_data_single) 147 return dataset 148 149 150 _ALL_DATA_FIXTURES = [ 151 "sample_dict_data_single", 152 "sample_dict_data_multiple", 153 "sample_dict_data_multiple_with_custom_expectations", 154 "sample_pd_data", 155 "sample_spark_data", 156 "sample_spark_data_with_string_columns", 157 "sample_evaluation_dataset", 158 ] 159 160 161 class TestModel: 162 @mlflow.trace(span_type=SpanType.AGENT) 163 def predict(self, question: str) -> str: 164 response = self.call_llm(messages=[{"role": "user", "content": question}]) 165 return response["choices"][0]["message"]["content"] 166 167 @mlflow.trace(span_type=SpanType.LLM) 168 def call_llm(self, messages: list[dict[str, Any]]) -> dict[str, Any]: 169 return {"choices": [{"message": {"role": "assistant", "content": "I don't know"}}]} 170 171 172 def get_test_traces(type=Literal["pandas", "list"]): 173 model = TestModel() 174 175 model.predict("What is MLflow?") 176 trace_id = mlflow.get_last_active_trace_id() 177 178 # Add assessments. Since log_assessment API is not supported in OSS MLflow yet, we 179 # need to add it to the trace info manually. 180 source = AssessmentSource(source_id="test", source_type="HUMAN") 181 # 1. Expectation with reserved name "expected_response" 182 mlflow.log_expectation( 183 trace_id=trace_id, 184 name="expected_response", 185 value="expected response for first question", 186 source=source, 187 ) 188 # 2. Expectation with reserved name "expected_facts" 189 mlflow.log_expectation( 190 trace_id=trace_id, 191 name="expected_facts", 192 value=["fact1", "fact2"], 193 source=source, 194 ) 195 # 3. Expectation with reserved name "guidelines" 196 mlflow.log_expectation( 197 trace_id=trace_id, 198 name="guidelines", 199 value=["Be polite", "Be kind"], 200 source=source, 201 ) 202 # 4. Expectation with custom name "my_custom_expectation" 203 mlflow.log_expectation( 204 trace_id=trace_id, 205 name="my_custom_expectation", 206 value="custom expectation for the first question", 207 source=source, 208 ) 209 # 5. Non-expectation assessment 210 mlflow.log_feedback( 211 trace_id=trace_id, 212 name="feedback", 213 value="some feedback", 214 source=source, 215 ) 216 traces = mlflow.search_traces(return_type=type, order_by=["timestamp_ms ASC"]) 217 return [{"trace": trace} for trace in traces] if type == "list" else traces 218 219 220 @pytest.mark.parametrize("input_type", ["list", "pandas"]) 221 def test_convert_to_legacy_eval_traces(input_type): 222 sample_data = get_test_traces(type=input_type) 223 data = _convert_to_eval_set(sample_data) 224 225 assert "trace" in data.columns 226 227 # "inputs" column should be derived from the trace 228 assert "inputs" in data.columns 229 assert list(data["inputs"]) == [{"question": "What is MLflow?"}] 230 assert data["expectations"][0] == { 231 "expected_response": "expected response for first question", 232 "expected_facts": ["fact1", "fact2"], 233 "guidelines": ["Be polite", "Be kind"], 234 "my_custom_expectation": "custom expectation for the first question", 235 } 236 # Assessment with type "Feedback" should not be present in the transformed data 237 assert "feedback" not in data.columns 238 239 240 @pytest.mark.parametrize("data_fixture", _ALL_DATA_FIXTURES) 241 def test_convert_to_eval_set_has_no_errors(data_fixture, request): 242 sample_data = request.getfixturevalue(data_fixture) 243 244 transformed_data = _convert_to_eval_set(sample_data) 245 246 assert "inputs" in transformed_data.columns 247 assert "outputs" in transformed_data.columns 248 assert "expectations" in transformed_data.columns 249 250 251 def test_convert_to_eval_set_without_request_and_response(): 252 for _ in range(3): 253 with mlflow.start_span(): 254 pass 255 256 trace_df = mlflow.search_traces() 257 trace_df = trace_df[["trace"]] 258 transformed_data = _convert_to_eval_set(trace_df) 259 260 assert "inputs" in transformed_data.columns 261 assert "outputs" in transformed_data.columns 262 assert transformed_data["inputs"].isna().all() 263 264 265 def test_convert_to_eval_set_with_missing_root_span(): 266 # Create traces 267 for _ in range(2): 268 with mlflow.start_span(): 269 pass 270 271 trace_df = mlflow.search_traces() 272 trace_df = trace_df[["trace"]] 273 274 # Deserialize the trace from JSON string to Trace object 275 trace_df["trace"] = trace_df["trace"].apply( 276 lambda t: Trace.from_json(t) if isinstance(t, str) else t 277 ) 278 279 # Mock _get_root_span to return None for the first trace to simulate missing root span 280 with patch.object(trace_df["trace"].iloc[0].data, "_get_root_span", return_value=None): 281 transformed_data = _convert_to_eval_set(trace_df) 282 283 # Verify inputs and outputs columns exist 284 assert "inputs" in transformed_data.columns 285 assert "outputs" in transformed_data.columns 286 287 # Verify first trace has None for inputs/outputs (missing root span) 288 assert transformed_data["inputs"].iloc[0] is None 289 assert transformed_data["outputs"].iloc[0] is None 290 291 # Verify second trace has None for inputs/outputs (normal empty span behavior) 292 assert transformed_data["inputs"].iloc[1] is None 293 assert transformed_data["outputs"].iloc[1] is None 294 295 296 def test_convert_to_legacy_eval_raise_for_invalid_json_columns(spark): 297 # Data with invalid `inputs` column 298 df = spark.createDataFrame([ 299 {"inputs": "invalid json", "expectations": '{"expected_response": "expected"}'}, 300 {"inputs": "invalid json", "expectations": '{"expected_response": "expected"}'}, 301 ]) 302 with pytest.raises(MlflowException, match="Failed to parse `inputs` column."): 303 _convert_to_eval_set(df) 304 305 # Data with invalid `expectations` column 306 df = spark.createDataFrame([ 307 { 308 "inputs": '{"question": "What is the capital of France?"}', 309 "expectations": "invalid expectations", 310 }, 311 { 312 "inputs": '{"question": "What is the capital of Germany?"}', 313 "expectations": "invalid expectations", 314 }, 315 ]) 316 with pytest.raises(MlflowException, match="Failed to parse `expectations` column."): 317 _convert_to_eval_set(df) 318 319 320 def _trace_test_cases(): 321 data = { 322 "info": { 323 "trace_id": "test-trace-id", 324 "trace_location": { 325 "type": "MLFLOW_EXPERIMENT", 326 "mlflow_experiment": {"experiment_id": "0"}, 327 }, 328 "request_time": "2024-01-21T12:00:00Z", 329 "state": "OK", 330 "trace_metadata": {}, 331 "tags": {}, 332 "assessments": [], 333 }, 334 "data": {"spans": []}, 335 } 336 return [ 337 pytest.param(data, dict, id="dict"), 338 pytest.param(json.dumps(data), str, id="string"), 339 pytest.param(Trace.from_dict(data), Trace, id="trace_object"), 340 ] 341 342 343 @pytest.mark.parametrize(("trace_value", "expected_input_type"), _trace_test_cases()) 344 def test_deserialize_trace_column(trace_value, expected_input_type): 345 df = pd.DataFrame([{"trace": trace_value, "inputs": {"question": "test"}}]) 346 assert isinstance(df["trace"].iloc[0], expected_input_type) 347 348 result = _deserialize_trace_column_if_needed(df) 349 assert isinstance(result["trace"].iloc[0], Trace) 350 assert result["trace"].iloc[0].info.trace_id == "test-trace-id" 351 352 353 def test_deserialize_trace_column_with_none(): 354 df = pd.DataFrame([{"trace": None, "inputs": {"question": "test"}}]) 355 356 result = _deserialize_trace_column_if_needed(df) 357 assert result["trace"].iloc[0] is None 358 359 360 @pytest.mark.parametrize("data_fixture", _ALL_DATA_FIXTURES) 361 def test_scorer_receives_correct_data(data_fixture, request): 362 sample_data = request.getfixturevalue(data_fixture) 363 364 received_args = [] 365 366 @scorer 367 def dummy_scorer(inputs, outputs, expectations): 368 received_args.append(( 369 inputs["question"], 370 outputs, 371 expectations.get("expected_response"), 372 expectations.get("my_custom_expectation"), 373 )) 374 return 0 375 376 mlflow.genai.evaluate( 377 data=sample_data, 378 scorers=[dummy_scorer], 379 ) 380 381 all_inputs, all_outputs, all_expectations, all_custom_expectations = zip(*received_args) 382 row_count = count_rows(sample_data) 383 expected_inputs = [ 384 "What is Spark?", 385 "How can you minimize data shuffling in Spark?", 386 "What is MLflow?", 387 ][:row_count] 388 expected_outputs = [ 389 "actual response for first question", 390 "actual response for second question", 391 "actual response for third question", 392 ][:row_count] 393 expected_expectations = [ 394 "expected response for first question", 395 "expected response for second question", 396 None, 397 ][:row_count] 398 399 assert set(all_inputs) == set(expected_inputs) 400 assert set(all_outputs) == set(expected_outputs) 401 assert set(all_expectations) == set(expected_expectations) 402 403 if data_fixture == "sample_dict_data_multiple_with_custom_expectations": 404 expected_custom_expectations = [ 405 "custom expectation for the first question", 406 "custom expectation for the second question", 407 "custom expectation for the third question", 408 ] 409 assert set(all_custom_expectations) == set(expected_custom_expectations) 410 411 412 def test_input_is_required_if_trace_is_not_provided(): 413 with patch("mlflow.genai.evaluation.harness.run") as mock_evaluate: 414 with pytest.raises(MlflowException, match="inputs.*required"): 415 mlflow.genai.evaluate( 416 data=pd.DataFrame({"outputs": ["Paris"]}), 417 scorers=[RelevanceToQuery()], 418 ) 419 420 mock_evaluate.assert_not_called() 421 422 mlflow.genai.evaluate( 423 data=pd.DataFrame({ 424 "inputs": [{"question": "What is the capital of France?"}], 425 "outputs": ["Paris"], 426 }), 427 scorers=[RelevanceToQuery()], 428 ) 429 mock_evaluate.assert_called_once() 430 431 432 def test_input_is_optional_if_trace_is_provided(): 433 with mlflow.start_span() as span: 434 span.set_inputs({"question": "What is the capital of France?"}) 435 span.set_outputs("Paris") 436 437 trace = mlflow.get_trace(span.trace_id) 438 439 with patch("mlflow.genai.evaluation.harness.run") as mock_evaluate: 440 mlflow.genai.evaluate( 441 data=pd.DataFrame({"trace": [trace]}), 442 scorers=[RelevanceToQuery()], 443 ) 444 445 mock_evaluate.assert_called_once() 446 447 448 @pytest.mark.parametrize("input_type", ["list", "pandas"]) 449 def test_scorer_receives_correct_data_with_trace_data(input_type, monkeypatch: pytest.MonkeyPatch): 450 sample_data = get_test_traces(type=input_type) 451 received_args = [] 452 453 @scorer 454 def dummy_scorer(inputs, outputs, expectations, trace): 455 received_args.append((inputs, outputs, expectations, trace)) 456 return 0 457 458 # Disable logging traces to MLflow to avoid calling mlflow APIs which need to be mocked 459 monkeypatch.setenv("AGENT_EVAL_LOG_TRACES_TO_MLFLOW_ENABLED", "false") 460 mlflow.genai.evaluate( 461 data=sample_data, 462 scorers=[dummy_scorer], 463 ) 464 465 inputs, outputs, expectations, trace = received_args[0] 466 assert inputs == {"question": "What is MLflow?"} 467 assert outputs == "I don't know" 468 assert expectations == { 469 "expected_response": "expected response for first question", 470 "expected_facts": ["fact1", "fact2"], 471 "guidelines": ["Be polite", "Be kind"], 472 "my_custom_expectation": "custom expectation for the first question", 473 } 474 assert isinstance(trace, Trace) 475 476 477 @pytest.mark.parametrize("data_fixture", _ALL_DATA_FIXTURES) 478 def test_predict_fn_receives_correct_data(data_fixture, request): 479 sample_data = request.getfixturevalue(data_fixture) 480 481 received_args = [] 482 483 def predict_fn(question: str): 484 received_args.append(question) 485 return question 486 487 @scorer 488 def dummy_scorer(inputs, outputs): 489 return 0 490 491 mlflow.genai.evaluate( 492 predict_fn=predict_fn, 493 data=sample_data, 494 scorers=[dummy_scorer], 495 ) 496 497 received_args.pop(0) # Remove the one-time prediction to check if a model is traced 498 row_count = count_rows(sample_data) 499 assert len(received_args) == row_count 500 expected_contents = [ 501 "What is Spark?", 502 "How can you minimize data shuffling in Spark?", 503 "What is MLflow?", 504 ][:row_count] 505 # Using set because eval harness runs predict_fn in parallel 506 assert set(received_args) == set(expected_contents) 507 508 509 def test_convert_scorer_to_legacy_metric_aggregations_attribute(monkeypatch): 510 mock_metric_instance = MagicMock() 511 512 # NB: Mocking the behavior of databricks-agents, which does not have the aggregations 513 # argument for the evaluation interface for a metric. 514 def mock_metric_decorator(**kwargs): 515 if "aggregations" in kwargs: 516 raise TypeError("metric() got an unexpected keyword argument 'aggregations'") 517 assert set(kwargs.keys()) <= {"eval_fn", "name"} 518 return mock_metric_instance 519 520 mock_evals = Mock(metric=mock_metric_decorator) 521 mock_evals.judges = Mock() # Add the judges submodule to prevent AttributeError 522 523 monkeypatch.setitem(sys.modules, "databricks.agents.evals", mock_evals) 524 monkeypatch.setitem(sys.modules, "databricks.agents.evals.judges", mock_evals.judges) 525 526 mock_scorer = Mock() 527 mock_scorer.name = "test_scorer" 528 mock_scorer.aggregations = ["mean", "max", "p90"] 529 mock_scorer.run = Mock(return_value={"score": 1.0}) 530 531 result = _convert_scorer_to_legacy_metric(mock_scorer) 532 533 assert result.aggregations == ["mean", "max", "p90"] 534 535 536 @databricks_only 537 def test_convert_scorer_to_legacy_metric(): 538 # Test with a built-in scorer 539 builtin_scorer = RelevanceToQuery() 540 legacy_metric = _convert_scorer_to_legacy_metric(builtin_scorer) 541 542 # Verify the metric has the _is_builtin_scorer attribute set to True 543 assert hasattr(legacy_metric, "_is_builtin_scorer") 544 assert legacy_metric._is_builtin_scorer is True 545 assert legacy_metric.name == builtin_scorer.name 546 547 # Test with a custom scorer 548 @scorer(name="custom_scorer", aggregations=["mean", "max"]) 549 def custom_scorer_func(inputs, outputs=None, expectations=None, **kwargs): 550 return {"score": 1.0} 551 552 custom_scorer_instance = custom_scorer_func 553 legacy_metric_custom = _convert_scorer_to_legacy_metric(custom_scorer_instance) 554 555 # Verify the metric has the _is_builtin_scorer attribute set to False 556 assert hasattr(legacy_metric_custom, "_is_builtin_scorer") 557 assert legacy_metric_custom._is_builtin_scorer is False 558 assert legacy_metric_custom.name == custom_scorer_instance.name 559 assert legacy_metric_custom.aggregations == custom_scorer_instance.aggregations 560 561 562 @pytest.mark.parametrize( 563 "aggregations", 564 [ 565 ["mean", "max", "mean", "median", "variance", "p90"], 566 [sum, max], 567 ], 568 ) 569 @databricks_only 570 def test_scorer_pass_through_aggregations(aggregations): 571 @scorer(name="custom_scorer", aggregations=aggregations) 572 def custom_scorer_func(outputs): 573 return {"score": 1.0} 574 575 legacy_metric_custom = _convert_scorer_to_legacy_metric(custom_scorer_func) 576 assert legacy_metric_custom.name == "custom_scorer" 577 assert legacy_metric_custom.aggregations == aggregations 578 579 builtin_scorer = RelevanceToQuery(aggregations=aggregations) 580 legacy_metric_builtin = _convert_scorer_to_legacy_metric(builtin_scorer) 581 assert legacy_metric_builtin.name == "relevance_to_query" 582 assert legacy_metric_builtin.aggregations == builtin_scorer.aggregations 583 584 585 @pytest.mark.parametrize( 586 "tags", 587 [ 588 None, 589 {}, 590 {"key": "value"}, 591 {"env": "test", "model": "v1.0"}, 592 {"key": 123}, # Values can be any type 593 {"key1": "value1", "key2": None}, # Values can be any type 594 ], 595 ) 596 def test_validate_tags_valid(tags): 597 validate_tags(tags) 598 599 600 @pytest.mark.parametrize( 601 ("tags", "expected_error"), 602 [ 603 ("invalid", "Tags must be a dictionary, got str"), 604 (123, "Tags must be a dictionary, got int"), 605 ([1, 2, 3], "Tags must be a dictionary, got list"), 606 ({123: "value"}, "Invalid tags:\n - Key 123 has type int; expected str."), 607 ( 608 {"key1": "value1", 123: "value2"}, 609 "Invalid tags:\n - Key 123 has type int; expected str.", 610 ), 611 ( 612 {123: "value1", 456: "value2"}, 613 ( 614 "Invalid tags:\n - Key 123 has type int; expected str." 615 "\n - Key 456 has type int; expected str." 616 ), 617 ), 618 ], 619 ) 620 def test_validate_tags_invalid(tags, expected_error): 621 with pytest.raises(MlflowException, match=expected_error): 622 validate_tags(tags)