test_assessment.py
1 import os 2 3 import pytest 4 5 import mlflow 6 from mlflow.entities.assessment import ( 7 AssessmentError, 8 AssessmentSource, 9 Expectation, 10 Feedback, 11 IssueReference, 12 ) 13 from mlflow.entities.assessment_source import AssessmentSource, AssessmentSourceType 14 from mlflow.entities.issue import IssueStatus 15 from mlflow.exceptions import MlflowException 16 from mlflow.version import IS_TRACING_SDK_ONLY 17 18 _HUMAN_ASSESSMENT_SOURCE = AssessmentSource( 19 source_type=AssessmentSourceType.HUMAN, 20 source_id="bob@example.com", 21 ) 22 23 _LLM_ASSESSMENT_SOURCE = AssessmentSource( 24 source_type=AssessmentSourceType.LLM_JUDGE, 25 source_id="gpt-4o-mini", 26 ) 27 28 _CODE_ASSESSMENT_SOURCE = AssessmentSource( 29 source_type=AssessmentSourceType.CODE, 30 source_id="issue_detector.py", 31 ) 32 33 34 @pytest.fixture 35 def trace_id(): 36 with mlflow.start_span(name="test_span") as span: 37 pass 38 39 mlflow.flush_trace_async_logging() 40 return span.trace_id 41 42 43 @pytest.fixture(params=["file", "sqlalchemy"], autouse=True) 44 def tracking_uri(request, tmp_path, db_uri): 45 """Set an MLflow Tracking URI with different type of backend.""" 46 if request.param == "file": 47 pytest.skip("FileStore is no longer supported.") 48 if "MLFLOW_SKINNY" in os.environ and request.param == "sqlalchemy": 49 pytest.skip("SQLAlchemy store is not available in skinny.") 50 51 if IS_TRACING_SDK_ONLY and request.param == "sqlalchemy": 52 pytest.skip("SQLAlchemy store is not available in tracing SDK only mode.") 53 54 original_tracking_uri = mlflow.get_tracking_uri() 55 56 if request.param == "file": 57 tracking_uri = tmp_path.joinpath("file").as_uri() 58 elif request.param == "sqlalchemy": 59 tracking_uri = db_uri 60 61 # NB: MLflow tracer does not handle the change of tracking URI well, 62 # so we need to reset the tracer to switch the tracking URI during testing. 63 mlflow.tracing.disable() 64 mlflow.set_tracking_uri(tracking_uri) 65 mlflow.tracing.enable() 66 67 yield tracking_uri 68 69 # Reset tracking URI 70 mlflow.set_tracking_uri(original_tracking_uri) 71 72 73 @pytest.mark.parametrize("legacy_api", [True, False]) 74 def test_log_expectation(trace_id, legacy_api): 75 if legacy_api: 76 mlflow.log_expectation( 77 trace_id=trace_id, 78 name="expected_answer", 79 value="MLflow", 80 source=_HUMAN_ASSESSMENT_SOURCE, 81 metadata={"key": "value"}, 82 ) 83 else: 84 feedback = Expectation( 85 name="expected_answer", 86 value="MLflow", 87 source=_HUMAN_ASSESSMENT_SOURCE, 88 metadata={"key": "value"}, 89 ) 90 mlflow.log_assessment(trace_id=trace_id, assessment=feedback) 91 92 trace = mlflow.get_trace(trace_id) 93 assert len(trace.info.assessments) == 1 94 assessment = trace.info.assessments[0] 95 assert isinstance(assessment, Expectation) 96 assert assessment.trace_id == trace_id 97 assert assessment.name == "expected_answer" 98 assert assessment.value == "MLflow" 99 assert assessment.trace_id == trace_id 100 assert assessment.span_id is None 101 assert assessment.source == _HUMAN_ASSESSMENT_SOURCE 102 assert assessment.create_time_ms is not None 103 assert assessment.last_update_time_ms is not None 104 assert assessment.expectation.value == "MLflow" 105 assert assessment.rationale is None 106 assert assessment.metadata == {"key": "value"} 107 108 109 def test_log_expectation_invalid_parameters(): 110 with pytest.raises(MlflowException, match=r"The `value` field must be specified."): 111 Expectation( 112 name="expected_answer", 113 value=None, 114 source=_HUMAN_ASSESSMENT_SOURCE, 115 ) 116 117 118 def test_update_expectation(trace_id): 119 assessment_id = mlflow.log_expectation( 120 trace_id=trace_id, 121 name="expected_answer", 122 value="MLflow", 123 ).assessment_id 124 125 updated_assessment = Expectation( 126 name="expected_answer", 127 value="Spark", 128 metadata={"reason": "human override"}, 129 ) 130 131 mlflow.update_assessment( 132 assessment_id=assessment_id, 133 trace_id=trace_id, 134 assessment=updated_assessment, 135 ) 136 137 trace = mlflow.get_trace(trace_id) 138 assert len(trace.info.assessments) == 1 139 assessment = trace.info.assessments[0] 140 assert assessment.trace_id == trace_id 141 assert assessment.name == "expected_answer" 142 assert assessment.expectation.value == "Spark" 143 assert assessment.feedback is None 144 assert assessment.rationale is None 145 assert assessment.metadata == {"reason": "human override"} 146 147 148 @pytest.mark.parametrize("legacy_api", [True, False]) 149 def test_log_feedback(trace_id, legacy_api): 150 if legacy_api: 151 mlflow.log_feedback( 152 trace_id=trace_id, 153 name="faithfulness", 154 value=1.0, 155 source=_LLM_ASSESSMENT_SOURCE, 156 rationale="This answer is very faithful.", 157 metadata={"model": "gpt-4o-mini"}, 158 ) 159 else: 160 feedback = Feedback( 161 name="faithfulness", 162 value=1.0, 163 source=_LLM_ASSESSMENT_SOURCE, 164 rationale="This answer is very faithful.", 165 metadata={"model": "gpt-4o-mini"}, 166 ) 167 mlflow.log_assessment(trace_id=trace_id, assessment=feedback) 168 169 trace = mlflow.get_trace(trace_id) 170 assert len(trace.info.assessments) == 1 171 assessment = trace.info.assessments[0] 172 assert isinstance(assessment, Feedback) 173 assert assessment.trace_id == trace_id 174 assert assessment.name == "faithfulness" 175 assert assessment.span_id is None 176 assert assessment.source == _LLM_ASSESSMENT_SOURCE 177 assert assessment.create_time_ms is not None 178 assert assessment.last_update_time_ms is not None 179 assert assessment.feedback.value == 1.0 180 assert assessment.feedback.error is None 181 assert assessment.expectation is None 182 assert assessment.rationale == "This answer is very faithful." 183 assert assessment.metadata == {"model": "gpt-4o-mini"} 184 185 186 @pytest.mark.parametrize("legacy_api", [True, False]) 187 def test_log_feedback_with_error(trace_id, legacy_api): 188 if legacy_api: 189 mlflow.log_feedback( 190 trace_id=trace_id, 191 name="faithfulness", 192 source=_LLM_ASSESSMENT_SOURCE, 193 error=AssessmentError( 194 error_code="RATE_LIMIT_EXCEEDED", 195 error_message="Rate limit for the judge exceeded.", 196 ), 197 ) 198 else: 199 feedback = Feedback( 200 name="faithfulness", 201 value=None, 202 source=_LLM_ASSESSMENT_SOURCE, 203 error=AssessmentError( 204 error_code="RATE_LIMIT_EXCEEDED", 205 error_message="Rate limit for the judge exceeded.", 206 ), 207 ) 208 mlflow.log_assessment(trace_id=trace_id, assessment=feedback) 209 210 trace = mlflow.get_trace(trace_id) 211 assert len(trace.info.assessments) == 1 212 assessment = trace.info.assessments[0] 213 assert assessment.name == "faithfulness" 214 assert assessment.trace_id == trace_id 215 assert assessment.span_id is None 216 assert assessment.source == _LLM_ASSESSMENT_SOURCE 217 assert assessment.create_time_ms is not None 218 assert assessment.last_update_time_ms is not None 219 assert assessment.expectation is None 220 assert assessment.feedback.value is None 221 assert assessment.feedback.error.error_code == "RATE_LIMIT_EXCEEDED" 222 assert assessment.feedback.error.error_message == "Rate limit for the judge exceeded." 223 assert assessment.rationale is None 224 225 226 @pytest.mark.parametrize("legacy_api", [True, False]) 227 def test_log_feedback_with_exception_object(trace_id, legacy_api): 228 test_exception = ValueError("Test exception message") 229 230 if legacy_api: 231 mlflow.log_feedback( 232 trace_id=trace_id, 233 name="faithfulness", 234 source=_LLM_ASSESSMENT_SOURCE, 235 error=test_exception, 236 ) 237 else: 238 feedback = Feedback( 239 name="faithfulness", 240 value=None, 241 source=_LLM_ASSESSMENT_SOURCE, 242 error=test_exception, 243 ) 244 mlflow.log_assessment(trace_id=trace_id, assessment=feedback) 245 246 trace = mlflow.get_trace(trace_id) 247 assert len(trace.info.assessments) == 1 248 assessment = trace.info.assessments[0] 249 assert assessment.name == "faithfulness" 250 assert assessment.trace_id == trace_id 251 assert assessment.span_id is None 252 assert assessment.source == _LLM_ASSESSMENT_SOURCE 253 assert assessment.create_time_ms is not None 254 assert assessment.last_update_time_ms is not None 255 assert assessment.expectation is None 256 assert assessment.feedback.value is None 257 # Exception should be converted to AssessmentError 258 assert assessment.feedback.error.error_code == "ValueError" 259 assert assessment.feedback.error.error_message == "Test exception message" 260 assert assessment.feedback.error.stack_trace is not None 261 assert assessment.rationale is None 262 263 264 @pytest.mark.parametrize("legacy_api", [True, False]) 265 def test_log_feedback_with_value_and_error(trace_id, legacy_api): 266 if legacy_api: 267 mlflow.log_feedback( 268 trace_id=trace_id, 269 name="faithfulness", 270 source=_LLM_ASSESSMENT_SOURCE, 271 value=0.5, 272 error=AssessmentError( 273 error_code="RATE_LIMIT_EXCEEDED", 274 error_message="Rate limit for the judge exceeded.", 275 ), 276 ) 277 else: 278 feedback = Feedback( 279 name="faithfulness", 280 value=0.5, 281 source=_LLM_ASSESSMENT_SOURCE, 282 error=AssessmentError( 283 error_code="RATE_LIMIT_EXCEEDED", 284 error_message="Rate limit for the judge exceeded.", 285 ), 286 ) 287 mlflow.log_assessment(trace_id=trace_id, assessment=feedback) 288 289 trace = mlflow.get_trace(trace_id) 290 assert len(trace.info.assessments) == 1 291 assessment = trace.info.assessments[0] 292 assert assessment.name == "faithfulness" 293 assert assessment.trace_id == trace_id 294 assert assessment.span_id is None 295 assert assessment.source == _LLM_ASSESSMENT_SOURCE 296 assert assessment.create_time_ms is not None 297 assert assessment.last_update_time_ms is not None 298 assert assessment.expectation is None 299 assert assessment.feedback.value == 0.5 300 assert assessment.feedback.error.error_code == "RATE_LIMIT_EXCEEDED" 301 assert assessment.feedback.error.error_message == "Rate limit for the judge exceeded." 302 assert assessment.rationale is None 303 304 305 def test_log_feedback_none_value(trace_id): 306 mlflow.log_feedback( 307 trace_id=trace_id, 308 name="faithfulness", 309 value=None, 310 source=_LLM_ASSESSMENT_SOURCE, 311 ) 312 313 trace = mlflow.get_trace(trace_id) 314 assert len(trace.info.assessments) == 1 315 assessment = trace.info.assessments[0] 316 assert isinstance(assessment, Feedback) 317 assert assessment.trace_id == trace_id 318 assert assessment.name == "faithfulness" 319 assert assessment.feedback.value is None 320 assert assessment.feedback.error is None 321 322 323 def test_log_feedback_invalid_parameters(): 324 # Test with a non-AssessmentSource object that is not None 325 with pytest.raises(MlflowException, match=r"`source` must be an instance of"): 326 Feedback( 327 trace_id="1234", 328 name="faithfulness", 329 value=1.0, 330 source="invalid_source_type", 331 ) 332 333 334 def test_update_feedback(trace_id): 335 assessment_id = mlflow.log_feedback( 336 trace_id=trace_id, 337 name="faithfulness", 338 value=1.0, 339 rationale="This answer is very faithful.", 340 metadata={"model": "gpt-4o-mini"}, 341 ).assessment_id 342 343 updated_feedback = Feedback( 344 name="faithfulness", 345 value=0, 346 rationale="This answer is not faithful.", 347 metadata={"reason": "human override"}, 348 ) 349 mlflow.update_assessment( 350 assessment_id=assessment_id, 351 trace_id=trace_id, 352 assessment=updated_feedback, 353 ) 354 355 trace = mlflow.get_trace(trace_id) 356 assert len(trace.info.assessments) == 1 357 assessment = trace.info.assessments[0] 358 assert assessment.name == "faithfulness" 359 assert assessment.trace_id == trace_id 360 assert assessment.feedback.value == 0 361 assert assessment.feedback.error is None 362 assert assessment.rationale == "This answer is not faithful." 363 assert assessment.metadata == { 364 "model": "gpt-4o-mini", 365 "reason": "human override", 366 } 367 368 369 def test_override_feedback(trace_id): 370 assessment_id = mlflow.log_feedback( 371 trace_id=trace_id, 372 name="faithfulness", 373 value=0.5, 374 source=_LLM_ASSESSMENT_SOURCE, 375 rationale="Original feedback", 376 metadata={"model": "gpt-3.5"}, 377 ).assessment_id 378 379 new_assessment_id = mlflow.override_feedback( 380 trace_id=trace_id, 381 assessment_id=assessment_id, 382 value=1.0, 383 source=_LLM_ASSESSMENT_SOURCE, 384 rationale="This answer is very faithful.", 385 metadata={"model": "gpt-4o-mini"}, 386 ).assessment_id 387 388 # New assessment should have the same trace_id as the original assessment 389 assessment = mlflow.get_assessment(trace_id, new_assessment_id) 390 assert assessment.trace_id == trace_id 391 assert assessment.name == "faithfulness" 392 assert assessment.span_id is None 393 assert assessment.source == _LLM_ASSESSMENT_SOURCE 394 assert assessment.create_time_ms is not None 395 assert assessment.last_update_time_ms is not None 396 assert assessment.value == 1.0 397 assert assessment.error is None 398 assert assessment.rationale == "This answer is very faithful." 399 assert assessment.metadata == {"model": "gpt-4o-mini"} 400 assert assessment.overrides == assessment_id 401 assert assessment.valid is True 402 403 # Original assessment should be invalidated 404 original_assessment = mlflow.get_assessment(trace_id, assessment_id) 405 assert original_assessment.valid is False 406 assert original_assessment.feedback.value == 0.5 407 408 409 def test_delete_assessment(trace_id): 410 assessment_id = mlflow.log_feedback( 411 trace_id=trace_id, 412 name="faithfulness", 413 value=1.0, 414 ).assessment_id 415 416 mlflow.delete_assessment(trace_id=trace_id, assessment_id=assessment_id) 417 418 with pytest.raises(MlflowException, match=r"Assessment with ID"): 419 assert mlflow.get_assessment(trace_id, assessment_id) is None 420 421 # Assessment should be deleted from the trace 422 trace = mlflow.get_trace(trace_id) 423 assert len(trace.info.assessments) == 0 424 425 426 def test_log_feedback_default_source(trace_id): 427 # Test that the default CODE source is used when no source is provided 428 feedback = Feedback( 429 trace_id=trace_id, 430 name="faithfulness", 431 value=1.0, 432 ) 433 mlflow.log_assessment(trace_id=trace_id, assessment=feedback) 434 435 trace = mlflow.get_trace(trace_id) 436 assert len(trace.info.assessments) == 1 437 assessment = trace.info.assessments[0] 438 assert assessment.name == "faithfulness" 439 assert assessment.trace_id == trace_id 440 assert assessment.source.source_type == AssessmentSourceType.CODE 441 assert assessment.source.source_id == "default" 442 assert assessment.feedback.value == 1.0 443 444 445 def test_log_expectation_default_source(trace_id): 446 # Test that the default HUMAN source is used when no source is provided 447 expectation = Expectation( 448 trace_id=trace_id, 449 name="expected_answer", 450 value="MLflow", 451 ) 452 mlflow.log_assessment(trace_id=trace_id, assessment=expectation) 453 454 trace = mlflow.get_trace(trace_id) 455 assert len(trace.info.assessments) == 1 456 assessment = trace.info.assessments[0] 457 assert assessment.name == "expected_answer" 458 assert assessment.trace_id == trace_id 459 assert assessment.source.source_type == AssessmentSourceType.HUMAN 460 assert assessment.source.source_id == "default" 461 assert assessment.expectation.value == "MLflow" 462 463 464 def test_log_issue(trace_id, tracking_uri): 465 if tracking_uri.startswith("file:"): 466 pytest.skip("Issue APIs are not supported with file-based tracking URI") 467 468 tracing_client = mlflow.MlflowClient()._tracing_client 469 issue = tracing_client._create_issue( 470 experiment_id="0", 471 name="timeout_error", 472 description="Timeout errors in API calls", 473 status=IssueStatus.PENDING, 474 ) 475 476 mlflow.log_issue( 477 trace_id=trace_id, 478 issue_id=issue.issue_id, 479 issue_name="timeout_error", 480 source=_CODE_ASSESSMENT_SOURCE, 481 rationale="Request exceeded 30 second timeout", 482 metadata={"severity": "high", "affected_count": "150"}, 483 ) 484 485 trace = mlflow.get_trace(trace_id) 486 assert len(trace.info.assessments) == 1 487 assessment = trace.info.assessments[0] 488 assert isinstance(assessment, IssueReference) 489 assert assessment.trace_id == trace_id 490 assert assessment.name == issue.issue_id 491 assert assessment.issue_id == issue.issue_id 492 assert assessment.issue_name == "timeout_error" 493 assert assessment.span_id is None 494 assert assessment.source == _CODE_ASSESSMENT_SOURCE 495 assert assessment.create_time_ms is not None 496 assert assessment.last_update_time_ms is not None 497 assert assessment.issue.issue_name == "timeout_error" 498 assert assessment.rationale == "Request exceeded 30 second timeout" 499 assert assessment.metadata == {"severity": "high", "affected_count": "150"} 500 501 502 def test_log_issue_default_source(trace_id, tracking_uri): 503 if tracking_uri.startswith("file:"): 504 pytest.skip("Issue APIs are not supported with file-based tracking URI") 505 506 tracing_client = mlflow.MlflowClient()._tracing_client 507 issue = tracing_client._create_issue( 508 experiment_id="0", 509 name="connection_issue", 510 description="Connection issues", 511 status=IssueStatus.PENDING, 512 ) 513 514 # Test that the default LLM_JUDGE source is used when no source is provided 515 mlflow.log_issue( 516 trace_id=trace_id, 517 issue_id=issue.issue_id, 518 issue_name="connection_issue", 519 ) 520 521 trace = mlflow.get_trace(trace_id) 522 assert len(trace.info.assessments) == 1 523 assessment = trace.info.assessments[0] 524 assert assessment.name == issue.issue_id 525 assert assessment.issue_id == issue.issue_id 526 assert assessment.issue_name == "connection_issue" 527 assert assessment.trace_id == trace_id 528 assert assessment.source.source_type == AssessmentSourceType.LLM_JUDGE 529 assert assessment.source.source_id == "default" 530 531 532 def test_log_issue_with_run_id_and_span_id(tracking_uri): 533 if tracking_uri.startswith("file:"): 534 pytest.skip("Issue APIs are not supported with file-based tracking URI") 535 536 tracing_client = mlflow.MlflowClient()._tracing_client 537 issue = tracing_client._create_issue( 538 experiment_id="0", 539 name="data_quality_issue", 540 description="Data quality issues", 541 status=IssueStatus.PENDING, 542 ) 543 544 with mlflow.start_span(name="test_span") as span: 545 mlflow.log_issue( 546 trace_id=span.trace_id, 547 issue_id=issue.issue_id, 548 issue_name="data_quality_issue", 549 source=_LLM_ASSESSMENT_SOURCE, 550 run_id="run-12345", 551 rationale="Input data contains missing values in critical fields", 552 metadata={"category": "data", "priority": "high"}, 553 span_id=span.span_id, 554 ) 555 mlflow.flush_trace_async_logging() 556 trace = mlflow.get_trace(span.trace_id) 557 assert len(trace.info.assessments) == 1 558 assessment = trace.info.assessments[0] 559 assert assessment.issue_id == issue.issue_id 560 assert assessment.issue_name == "data_quality_issue" 561 assert assessment.trace_id == span.trace_id 562 assert assessment.span_id == span.span_id 563 assert assessment.source == _LLM_ASSESSMENT_SOURCE 564 assert assessment.rationale == "Input data contains missing values in critical fields" 565 assert assessment.metadata == {"category": "data", "priority": "high"} 566 567 568 def test_log_issue_without_issue_name(trace_id, tracking_uri): 569 if tracking_uri.startswith("file:"): 570 pytest.skip("Issue APIs are not supported with file-based tracking URI") 571 572 tracing_client = mlflow.MlflowClient()._tracing_client 573 issue = tracing_client._create_issue( 574 experiment_id="0", 575 name="timeout_error", 576 description="Request exceeded 30 second timeout", 577 status=IssueStatus.PENDING, 578 ) 579 mlflow.log_issue( 580 trace_id=trace_id, 581 issue_id=issue.issue_id, 582 source=_CODE_ASSESSMENT_SOURCE, 583 rationale="Request exceeded 30 second timeout", 584 metadata={"severity": "high", "affected_count": "150"}, 585 ) 586 trace = mlflow.get_trace(trace_id) 587 assessment = trace.info.assessments[0] 588 assert assessment.issue_id == issue.issue_id 589 assert assessment.issue_name == "timeout_error" 590 591 fetched_issue = tracing_client._get_issue(issue.issue_id) 592 assert fetched_issue.name == issue.name 593 assert fetched_issue.description == issue.description 594 assert fetched_issue.status == issue.status 595 596 597 def test_log_issue_with_invalid_issue_name(trace_id, tracking_uri): 598 if tracking_uri.startswith("file:"): 599 pytest.skip("Issue APIs are not supported with file-based tracking URI") 600 601 tracing_client = mlflow.MlflowClient()._tracing_client 602 issue = tracing_client._create_issue( 603 experiment_id="0", 604 name="timeout_error", 605 description="Timeout errors in API calls", 606 status=IssueStatus.PENDING, 607 ) 608 609 with pytest.raises( 610 MlflowException, match=r"Provided issue name 'wrong_name' does not match the issue name" 611 ): 612 mlflow.log_issue( 613 trace_id=trace_id, 614 issue_id=issue.issue_id, 615 issue_name="wrong_name", 616 ) 617 618 619 def test_log_issue_with_invalid_issue_id(trace_id, tracking_uri): 620 if tracking_uri.startswith("file:"): 621 pytest.skip("Issue APIs are not supported with file-based tracking URI") 622 623 with pytest.raises( 624 MlflowException, 625 match=r"Issue with ID 'iss-nonexistent' not found", 626 ): 627 mlflow.log_issue( 628 trace_id=trace_id, 629 issue_id="iss-nonexistent", 630 issue_name="some_issue", 631 ) 632 633 634 def test_log_feedback_and_exception_blocks_positional_args(): 635 with pytest.raises(TypeError, match=r"log_feedback\(\) takes 0 positional"): 636 mlflow.log_feedback("tr-1234", "faithfulness", 1.0) 637 638 with pytest.raises(TypeError, match=r"log_expectation\(\) takes 0 positional"): 639 mlflow.log_expectation("tr-1234", "expected_answer", "MLflow") 640 641 with pytest.raises(TypeError, match=r"log_issue\(\) takes 0 positional"): 642 mlflow.log_issue("tr-1234", "iss-12345", "timeout_error") 643 644 645 @pytest.mark.parametrize("legacy_api", [True, False]) 646 def test_log_assessment_on_in_progress_trace(trace_id, legacy_api): 647 @mlflow.trace 648 def func(x: int, y: int) -> int: 649 active_trace_id = mlflow.get_active_trace_id() 650 if legacy_api: 651 mlflow.log_assessment(active_trace_id, Feedback(name="feedback", value=1.0)) 652 mlflow.log_assessment(active_trace_id, Expectation(name="expectation", value="MLflow")) 653 mlflow.log_assessment(trace_id, Feedback(name="other", value=2.0)) 654 else: 655 mlflow.log_feedback(trace_id=active_trace_id, name="feedback", value=1.0) 656 mlflow.log_expectation(trace_id=active_trace_id, name="expectation", value="MLflow") 657 mlflow.log_feedback(trace_id=trace_id, name="other", value=2.0) 658 return x + y 659 660 assert func(1, 2) == 3 661 662 mlflow.flush_trace_async_logging() 663 664 # Two assessments should be logged as a part of StartTraceV3 call 665 trace = mlflow.get_trace(mlflow.get_last_active_trace_id()) 666 assert len(trace.info.assessments) == 2 667 assessments = {a.name: a for a in trace.info.assessments} 668 assert assessments["feedback"].value == 1.0 669 assert assessments["expectation"].value == "MLflow" 670 671 # Assessment on the other trace 672 trace = mlflow.get_trace(trace_id) 673 assert len(trace.info.assessments) == 1 674 assert trace.info.assessments[0].name == "other" 675 assert trace.info.assessments[0].feedback.value == 2.0 676 677 678 @pytest.mark.asyncio 679 async def test_log_assessment_on_in_progress_trace_async(): 680 @mlflow.trace 681 async def func(x: int, y: int) -> int: 682 trace_id = mlflow.get_active_trace_id() 683 mlflow.log_assessment(trace_id, Feedback(name="feedback", value=1.0)) 684 mlflow.log_assessment(trace_id, Expectation(name="expectation", value="MLflow")) 685 return x + y 686 687 assert (await func(1, 2)) == 3 688 689 mlflow.flush_trace_async_logging() 690 691 trace = mlflow.get_trace(mlflow.get_last_active_trace_id()) 692 trace_info = trace.info 693 assert len(trace_info.assessments) == 2 694 assessments = {a.name: a for a in trace_info.assessments} 695 assert assessments["feedback"].feedback.value == 1.0 696 assert assessments["expectation"].expectation.value == "MLflow" 697 698 699 def test_log_assessment_on_in_progress_with_span_id(): 700 with mlflow.start_span(name="test_span") as span: 701 # Only proceed if we have a real span (not NO_OP) 702 if span.span_id is not None and span.trace_id != "MLFLOW_NO_OP_SPAN_TRACE_ID": 703 mlflow.log_assessment( 704 trace_id=span.trace_id, 705 assessment=Feedback(name="feedback", value=1.0, span_id=span.span_id), 706 ) 707 708 mlflow.flush_trace_async_logging() 709 710 # Two assessments should be logged as a part of StartTraceV3 call 711 trace = mlflow.get_trace(mlflow.get_last_active_trace_id()) 712 trace_info = trace.info 713 assert len(trace_info.assessments) == 1 714 assert trace_info.assessments[0].name == "feedback" 715 assert trace_info.assessments[0].feedback.value == 1.0 716 assert trace_info.assessments[0].span_id == span.span_id 717 718 719 def test_log_assessment_on_in_progress_trace_works_when_tracing_is_disabled(): 720 # Calling log_assessment to an active trace should not fail when tracing is disabled. 721 mlflow.tracing.disable() 722 723 @mlflow.trace 724 def func(x: int, y: int): 725 trace_id = mlflow.get_active_trace_id() 726 mlflow.log_assessment(trace_id=trace_id, assessment=Feedback(name="feedback", value=1.0)) 727 return x + y 728 729 assert func(1, 2) == 3 730 731 mlflow.flush_trace_async_logging() 732 733 734 def test_get_assessment(trace_id): 735 assessment_id = mlflow.log_feedback( 736 trace_id=trace_id, 737 name="faithfulness", 738 value=1.0, 739 ).assessment_id 740 741 result = mlflow.get_assessment(trace_id, assessment_id) 742 743 assert isinstance(result, Feedback) 744 assert result.name == "faithfulness" 745 assert result.trace_id == trace_id 746 assert result.value == 1.0 747 assert result.error is None 748 assert result.source.source_type == AssessmentSourceType.CODE 749 assert result.source.source_id == "default" 750 assert result.create_time_ms is not None 751 assert result.last_update_time_ms is not None 752 assert result.valid is True 753 assert result.overrides is None 754 755 756 def test_search_traces_with_assessments(): 757 # Create traces with assessments 758 with mlflow.start_span(name="trace_1") as span_1: 759 mlflow.log_feedback( 760 trace_id=span_1.trace_id, 761 name="feedback_1", 762 value=1.0, 763 ) 764 mlflow.log_expectation( 765 trace_id=span_1.trace_id, 766 name="expectation_1", 767 value="test", 768 source=AssessmentSource(source_id="test", source_type=AssessmentSourceType.LLM_JUDGE), 769 ) 770 with mlflow.start_span(name="child") as span_1_child: 771 mlflow.log_feedback( 772 trace_id=span_1_child.trace_id, 773 name="feedback_2", 774 value=1.0, 775 span_id=span_1_child.span_id, 776 ) 777 778 with mlflow.start_span(name="trace_2") as span_2: 779 mlflow.log_feedback( 780 trace_id=span_2.trace_id, 781 name="feedback_3", 782 value=1.0, 783 ) 784 785 mlflow.flush_trace_async_logging() 786 787 traces = mlflow.search_traces( 788 locations=["0"], 789 max_results=2, 790 return_type="list", 791 order_by=["timestamp_ms"], 792 ) 793 # Verify the results 794 assert len(traces) == 2 795 assert len(traces[0].info.assessments) == 3 796 797 assessments = {a.name: a for a in traces[0].info.assessments} 798 assert assessments["feedback_1"].trace_id == span_1.trace_id 799 assert assessments["feedback_1"].name == "feedback_1" 800 assert assessments["feedback_1"].value == 1.0 801 assert assessments["expectation_1"].trace_id == span_1.trace_id 802 assert assessments["expectation_1"].name == "expectation_1" 803 assert assessments["expectation_1"].value == "test" 804 assert assessments["feedback_2"].trace_id == span_1_child.trace_id 805 assert assessments["feedback_2"].name == "feedback_2" 806 assert assessments["feedback_2"].value == 1.0 807 808 assert len(traces[1].info.assessments) == 1 809 assessment = traces[1].info.assessments[0] 810 assert assessment.trace_id == span_2.trace_id 811 assert assessment.name == "feedback_3" 812 assert assessment.value == 1.0 813 814 815 @pytest.mark.parametrize("source_type", ["AI_JUDGE", AssessmentSourceType.AI_JUDGE]) 816 def test_log_feedback_ai_judge_deprecation_warning(trace_id, source_type): 817 with pytest.warns(FutureWarning, match="AI_JUDGE is deprecated. Use LLM_JUDGE instead."): 818 ai_judge_source = AssessmentSource(source_type=source_type, source_id="gpt-4") 819 820 mlflow.log_feedback( 821 trace_id=trace_id, 822 name="quality", 823 value=0.8, 824 source=ai_judge_source, 825 rationale="AI evaluation", 826 ) 827 828 trace = mlflow.get_trace(trace_id) 829 assert len(trace.info.assessments) == 1 830 assessment = trace.info.assessments[0] 831 assert assessment.source.source_type == AssessmentSourceType.LLM_JUDGE 832 assert assessment.source.source_id == "gpt-4" 833 assert assessment.name == "quality" 834 assert assessment.feedback.value == 0.8 835 assert assessment.rationale == "AI evaluation" 836 837 838 def test_log_issue_reference(trace_id, tracking_uri): 839 if tracking_uri.startswith("file:"): 840 pytest.skip("Issue APIs are not supported with file-based tracking URI") 841 842 tracing_client = mlflow.MlflowClient()._tracing_client 843 issue = tracing_client._create_issue( 844 experiment_id="0", 845 name="timeout_error", 846 description="Timeout errors in API calls", 847 status=IssueStatus.PENDING, 848 ) 849 850 issue_ref = IssueReference( 851 issue_id=issue.issue_id, 852 issue_name="timeout_error", 853 source=_CODE_ASSESSMENT_SOURCE, 854 metadata={"severity": "high", "affected_count": "150"}, 855 ) 856 mlflow.log_assessment(trace_id=trace_id, assessment=issue_ref) 857 858 trace = mlflow.get_trace(trace_id) 859 assert len(trace.info.assessments) == 1 860 assessment = trace.info.assessments[0] 861 assert isinstance(assessment, IssueReference) 862 assert assessment.trace_id == trace_id 863 assert assessment.name == issue.issue_id 864 assert assessment.issue_id == issue.issue_id 865 assert assessment.issue_name == "timeout_error" 866 assert assessment.span_id is None 867 assert assessment.source == _CODE_ASSESSMENT_SOURCE 868 assert assessment.create_time_ms is not None 869 assert assessment.last_update_time_ms is not None 870 assert assessment.issue.issue_name == "timeout_error" 871 assert assessment.expectation is None 872 assert assessment.feedback is None 873 assert assessment.metadata == {"severity": "high", "affected_count": "150"} 874 875 876 def test_log_issue_reference_invalid_parameters(): 877 with pytest.raises(MlflowException, match=r"The `issue_id` field must be specified"): 878 IssueReference( 879 issue_id=None, 880 issue_name="test_issue", 881 source=_CODE_ASSESSMENT_SOURCE, 882 ) 883 884 885 def test_log_issue_reference_default_source(trace_id, tracking_uri): 886 if tracking_uri.startswith("file:"): 887 pytest.skip("Issue APIs are not supported with file-based tracking URI") 888 889 tracing_client = mlflow.MlflowClient()._tracing_client 890 issue = tracing_client._create_issue( 891 experiment_id="0", 892 name="connection_issue", 893 description="Connection issues", 894 status=IssueStatus.PENDING, 895 ) 896 897 issue_ref = IssueReference( 898 issue_id=issue.issue_id, 899 issue_name="connection_issue", 900 ) 901 mlflow.log_assessment(trace_id=trace_id, assessment=issue_ref) 902 903 trace = mlflow.get_trace(trace_id) 904 assert len(trace.info.assessments) == 1 905 assessment = trace.info.assessments[0] 906 assert assessment.name == issue.issue_id 907 assert assessment.issue_name == "connection_issue" 908 assert assessment.trace_id == trace_id 909 assert assessment.issue_id == issue.issue_id 910 assert assessment.source.source_type == AssessmentSourceType.LLM_JUDGE 911 assert assessment.source.source_id == "default" 912 913 914 def test_get_issue_reference_assessment(trace_id): 915 issue_ref = IssueReference( 916 issue_id="iss-55555", 917 issue_name="performance_issue", 918 metadata={"category": "latency"}, 919 ) 920 assessment_id = mlflow.log_assessment(trace_id=trace_id, assessment=issue_ref).assessment_id 921 922 result = mlflow.get_assessment(trace_id, assessment_id) 923 924 assert isinstance(result, IssueReference) 925 assert result.name == "iss-55555" 926 assert result.issue_name == "performance_issue" 927 assert result.trace_id == trace_id 928 assert result.issue_id == "iss-55555" 929 assert result.source.source_type == AssessmentSourceType.LLM_JUDGE 930 assert result.source.source_id == "default" 931 assert result.create_time_ms is not None 932 assert result.last_update_time_ms is not None 933 assert result.metadata == {"category": "latency"} 934 935 936 def test_log_multiple_assessment_types(trace_id): 937 feedback = Feedback( 938 name="accuracy", 939 value=0.95, 940 source=_LLM_ASSESSMENT_SOURCE, 941 ) 942 mlflow.log_assessment(trace_id=trace_id, assessment=feedback) 943 944 expectation = Expectation( 945 name="expected_output", 946 value="MLflow", 947 source=_HUMAN_ASSESSMENT_SOURCE, 948 ) 949 mlflow.log_assessment(trace_id=trace_id, assessment=expectation) 950 951 issue_ref = IssueReference( 952 issue_id="iss-11111", 953 issue_name="data_quality_issue", 954 source=_CODE_ASSESSMENT_SOURCE, 955 ) 956 mlflow.log_assessment(trace_id=trace_id, assessment=issue_ref) 957 958 trace = mlflow.get_trace(trace_id) 959 assert len(trace.info.assessments) == 3 960 961 assessments_by_type = {} 962 for a in trace.info.assessments: 963 if isinstance(a, Feedback): 964 assessments_by_type["feedback"] = a 965 elif isinstance(a, Expectation): 966 assessments_by_type["expectation"] = a 967 elif isinstance(a, IssueReference): 968 assessments_by_type["issue"] = a 969 970 assert assessments_by_type["feedback"].name == "accuracy" 971 assert assessments_by_type["feedback"].value == 0.95 972 973 assert assessments_by_type["expectation"].name == "expected_output" 974 assert assessments_by_type["expectation"].value == "MLflow" 975 976 assert assessments_by_type["issue"].name == "iss-11111" 977 assert assessments_by_type["issue"].issue_name == "data_quality_issue" 978 assert assessments_by_type["issue"].issue_id == "iss-11111"