tests/tracing/test_assessment.py
import os

import pytest

import mlflow
from mlflow.entities.assessment import (
    AssessmentError,
    Expectation,
    Feedback,
    IssueReference,
)
from mlflow.entities.assessment_source import AssessmentSource, AssessmentSourceType
from mlflow.entities.issue import IssueStatus
from mlflow.exceptions import MlflowException
from mlflow.version import IS_TRACING_SDK_ONLY

_HUMAN_ASSESSMENT_SOURCE = AssessmentSource(
    source_type=AssessmentSourceType.HUMAN,
    source_id="bob@example.com",
)

_LLM_ASSESSMENT_SOURCE = AssessmentSource(
    source_type=AssessmentSourceType.LLM_JUDGE,
    source_id="gpt-4o-mini",
)

_CODE_ASSESSMENT_SOURCE = AssessmentSource(
    source_type=AssessmentSourceType.CODE,
    source_id="issue_detector.py",
)


@pytest.fixture
def trace_id():
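    """Create a trace with a single span and return its trace ID after async logging flushes."""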
    with mlflow.start_span(name="test_span") as span:
        pass

    mlflow.flush_trace_async_logging()
    return span.trace_id


@pytest.fixture(params=["file", "sqlalchemy"], autouse=True)
def tracking_uri(request, tmp_path, db_uri):
    """Set an MLflow Tracking URI with different types of backends."""
    if "MLFLOW_SKINNY" in os.environ and request.param == "sqlalchemy":
        pytest.skip("SQLAlchemy store is not available in skinny.")

    if IS_TRACING_SDK_ONLY and request.param == "sqlalchemy":
        pytest.skip("SQLAlchemy store is not available in tracing SDK only mode.")

    original_tracking_uri = mlflow.get_tracking_uri()

    if request.param == "file":
        tracking_uri = tmp_path.joinpath("file").as_uri()
    elif request.param == "sqlalchemy":
        tracking_uri = db_uri

    # NB: The MLflow tracer does not handle tracking URI changes well, so reset
    # the tracer when switching the tracking URI during testing.
    mlflow.tracing.disable()
    mlflow.set_tracking_uri(tracking_uri)
    mlflow.tracing.enable()

    yield tracking_uri

    # Reset tracking URI
    mlflow.set_tracking_uri(original_tracking_uri)


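# The tests below are parametrized over `legacy_api`: True exercises the standalone
# mlflow.log_expectation / mlflow.log_feedback helpers, while False exercises
# mlflow.log_assessment with Expectation / Feedback entity objects.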
@pytest.mark.parametrize("legacy_api", [True, False])
def test_log_expectation(trace_id, legacy_api):
    if legacy_api:
        mlflow.log_expectation(
            trace_id=trace_id,
            name="expected_answer",
            value="MLflow",
            source=_HUMAN_ASSESSMENT_SOURCE,
            metadata={"key": "value"},
        )
    else:
        expectation = Expectation(
            name="expected_answer",
            value="MLflow",
            source=_HUMAN_ASSESSMENT_SOURCE,
            metadata={"key": "value"},
        )
        mlflow.log_assessment(trace_id=trace_id, assessment=expectation)

    trace = mlflow.get_trace(trace_id)
    assert len(trace.info.assessments) == 1
    assessment = trace.info.assessments[0]
    assert isinstance(assessment, Expectation)
    assert assessment.trace_id == trace_id
    assert assessment.name == "expected_answer"
    assert assessment.value == "MLflow"
    assert assessment.span_id is None
    assert assessment.source == _HUMAN_ASSESSMENT_SOURCE
    assert assessment.create_time_ms is not None
    assert assessment.last_update_time_ms is not None
    assert assessment.expectation.value == "MLflow"
    assert assessment.rationale is None
    assert assessment.metadata == {"key": "value"}


def test_log_expectation_invalid_parameters():
    with pytest.raises(MlflowException, match=r"The `value` field must be specified."):
        Expectation(
            name="expected_answer",
            value=None,
            source=_HUMAN_ASSESSMENT_SOURCE,
        )


def test_update_expectation(trace_id):
    assessment_id = mlflow.log_expectation(
        trace_id=trace_id,
        name="expected_answer",
        value="MLflow",
    ).assessment_id

    updated_assessment = Expectation(
        name="expected_answer",
        value="Spark",
        metadata={"reason": "human override"},
    )

    mlflow.update_assessment(
        assessment_id=assessment_id,
        trace_id=trace_id,
        assessment=updated_assessment,
    )

    trace = mlflow.get_trace(trace_id)
    assert len(trace.info.assessments) == 1
    assessment = trace.info.assessments[0]
    assert assessment.trace_id == trace_id
    assert assessment.name == "expected_answer"
    assert assessment.expectation.value == "Spark"
    assert assessment.feedback is None
    assert assessment.rationale is None
    assert assessment.metadata == {"reason": "human override"}


@pytest.mark.parametrize("legacy_api", [True, False])
def test_log_feedback(trace_id, legacy_api):
    if legacy_api:
        mlflow.log_feedback(
            trace_id=trace_id,
            name="faithfulness",
            value=1.0,
            source=_LLM_ASSESSMENT_SOURCE,
            rationale="This answer is very faithful.",
            metadata={"model": "gpt-4o-mini"},
        )
    else:
        feedback = Feedback(
            name="faithfulness",
            value=1.0,
            source=_LLM_ASSESSMENT_SOURCE,
            rationale="This answer is very faithful.",
            metadata={"model": "gpt-4o-mini"},
        )
        mlflow.log_assessment(trace_id=trace_id, assessment=feedback)

    trace = mlflow.get_trace(trace_id)
    assert len(trace.info.assessments) == 1
    assessment = trace.info.assessments[0]
    assert isinstance(assessment, Feedback)
    assert assessment.trace_id == trace_id
    assert assessment.name == "faithfulness"
    assert assessment.span_id is None
    assert assessment.source == _LLM_ASSESSMENT_SOURCE
    assert assessment.create_time_ms is not None
    assert assessment.last_update_time_ms is not None
    assert assessment.feedback.value == 1.0
    assert assessment.feedback.error is None
    assert assessment.expectation is None
    assert assessment.rationale == "This answer is very faithful."
    assert assessment.metadata == {"model": "gpt-4o-mini"}


@pytest.mark.parametrize("legacy_api", [True, False])
def test_log_feedback_with_error(trace_id, legacy_api):
    if legacy_api:
        mlflow.log_feedback(
            trace_id=trace_id,
            name="faithfulness",
            source=_LLM_ASSESSMENT_SOURCE,
            error=AssessmentError(
                error_code="RATE_LIMIT_EXCEEDED",
                error_message="Rate limit for the judge exceeded.",
            ),
        )
    else:
        feedback = Feedback(
            name="faithfulness",
            value=None,
            source=_LLM_ASSESSMENT_SOURCE,
            error=AssessmentError(
                error_code="RATE_LIMIT_EXCEEDED",
                error_message="Rate limit for the judge exceeded.",
            ),
        )
        mlflow.log_assessment(trace_id=trace_id, assessment=feedback)

    trace = mlflow.get_trace(trace_id)
    assert len(trace.info.assessments) == 1
    assessment = trace.info.assessments[0]
    assert assessment.name == "faithfulness"
    assert assessment.trace_id == trace_id
    assert assessment.span_id is None
    assert assessment.source == _LLM_ASSESSMENT_SOURCE
    assert assessment.create_time_ms is not None
    assert assessment.last_update_time_ms is not None
    assert assessment.expectation is None
    assert assessment.feedback.value is None
    assert assessment.feedback.error.error_code == "RATE_LIMIT_EXCEEDED"
    assert assessment.feedback.error.error_message == "Rate limit for the judge exceeded."
    assert assessment.rationale is None


@pytest.mark.parametrize("legacy_api", [True, False])
def test_log_feedback_with_exception_object(trace_id, legacy_api):
    test_exception = ValueError("Test exception message")

    if legacy_api:
        mlflow.log_feedback(
            trace_id=trace_id,
            name="faithfulness",
            source=_LLM_ASSESSMENT_SOURCE,
            error=test_exception,
        )
    else:
        feedback = Feedback(
            name="faithfulness",
            value=None,
            source=_LLM_ASSESSMENT_SOURCE,
            error=test_exception,
        )
        mlflow.log_assessment(trace_id=trace_id, assessment=feedback)

    trace = mlflow.get_trace(trace_id)
    assert len(trace.info.assessments) == 1
    assessment = trace.info.assessments[0]
    assert assessment.name == "faithfulness"
    assert assessment.trace_id == trace_id
    assert assessment.span_id is None
    assert assessment.source == _LLM_ASSESSMENT_SOURCE
    assert assessment.create_time_ms is not None
    assert assessment.last_update_time_ms is not None
    assert assessment.expectation is None
    assert assessment.feedback.value is None
    # Exception should be converted to AssessmentError
    assert assessment.feedback.error.error_code == "ValueError"
    assert assessment.feedback.error.error_message == "Test exception message"
    assert assessment.feedback.error.stack_trace is not None
    assert assessment.rationale is None


@pytest.mark.parametrize("legacy_api", [True, False])
def test_log_feedback_with_value_and_error(trace_id, legacy_api):
    if legacy_api:
        mlflow.log_feedback(
            trace_id=trace_id,
            name="faithfulness",
            source=_LLM_ASSESSMENT_SOURCE,
            value=0.5,
            error=AssessmentError(
                error_code="RATE_LIMIT_EXCEEDED",
                error_message="Rate limit for the judge exceeded.",
            ),
        )
    else:
        feedback = Feedback(
            name="faithfulness",
            value=0.5,
            source=_LLM_ASSESSMENT_SOURCE,
            error=AssessmentError(
                error_code="RATE_LIMIT_EXCEEDED",
                error_message="Rate limit for the judge exceeded.",
            ),
        )
        mlflow.log_assessment(trace_id=trace_id, assessment=feedback)

    trace = mlflow.get_trace(trace_id)
    assert len(trace.info.assessments) == 1
    assessment = trace.info.assessments[0]
    assert assessment.name == "faithfulness"
    assert assessment.trace_id == trace_id
    assert assessment.span_id is None
    assert assessment.source == _LLM_ASSESSMENT_SOURCE
    assert assessment.create_time_ms is not None
    assert assessment.last_update_time_ms is not None
    assert assessment.expectation is None
    assert assessment.feedback.value == 0.5
    assert assessment.feedback.error.error_code == "RATE_LIMIT_EXCEEDED"
    assert assessment.feedback.error.error_message == "Rate limit for the judge exceeded."
    assert assessment.rationale is None


def test_log_feedback_none_value(trace_id):
    mlflow.log_feedback(
        trace_id=trace_id,
        name="faithfulness",
        value=None,
        source=_LLM_ASSESSMENT_SOURCE,
    )

    trace = mlflow.get_trace(trace_id)
    assert len(trace.info.assessments) == 1
    assessment = trace.info.assessments[0]
    assert isinstance(assessment, Feedback)
    assert assessment.trace_id == trace_id
    assert assessment.name == "faithfulness"
    assert assessment.feedback.value is None
    assert assessment.feedback.error is None


def test_log_feedback_invalid_parameters():
    # Test with a non-AssessmentSource object that is not None
    with pytest.raises(MlflowException, match=r"`source` must be an instance of"):
        Feedback(
            trace_id="1234",
            name="faithfulness",
            value=1.0,
            source="invalid_source_type",
        )


def test_update_feedback(trace_id):
    assessment_id = mlflow.log_feedback(
        trace_id=trace_id,
        name="faithfulness",
        value=1.0,
        rationale="This answer is very faithful.",
        metadata={"model": "gpt-4o-mini"},
    ).assessment_id

    updated_feedback = Feedback(
        name="faithfulness",
        value=0,
        rationale="This answer is not faithful.",
        metadata={"reason": "human override"},
    )
    mlflow.update_assessment(
        assessment_id=assessment_id,
        trace_id=trace_id,
        assessment=updated_feedback,
    )

    trace = mlflow.get_trace(trace_id)
    assert len(trace.info.assessments) == 1
    assessment = trace.info.assessments[0]
    assert assessment.name == "faithfulness"
    assert assessment.trace_id == trace_id
    assert assessment.feedback.value == 0
    assert assessment.feedback.error is None
    assert assessment.rationale == "This answer is not faithful."
    assert assessment.metadata == {
        "model": "gpt-4o-mini",
        "reason": "human override",
    }


def test_override_feedback(trace_id):
    assessment_id = mlflow.log_feedback(
        trace_id=trace_id,
        name="faithfulness",
        value=0.5,
        source=_LLM_ASSESSMENT_SOURCE,
        rationale="Original feedback",
        metadata={"model": "gpt-3.5"},
    ).assessment_id

    new_assessment_id = mlflow.override_feedback(
        trace_id=trace_id,
        assessment_id=assessment_id,
        value=1.0,
        source=_LLM_ASSESSMENT_SOURCE,
        rationale="This answer is very faithful.",
        metadata={"model": "gpt-4o-mini"},
    ).assessment_id

    # New assessment should have the same trace_id as the original assessment
    assessment = mlflow.get_assessment(trace_id, new_assessment_id)
    assert assessment.trace_id == trace_id
    assert assessment.name == "faithfulness"
    assert assessment.span_id is None
    assert assessment.source == _LLM_ASSESSMENT_SOURCE
    assert assessment.create_time_ms is not None
    assert assessment.last_update_time_ms is not None
    assert assessment.value == 1.0
    assert assessment.error is None
    assert assessment.rationale == "This answer is very faithful."
    assert assessment.metadata == {"model": "gpt-4o-mini"}
    assert assessment.overrides == assessment_id
    assert assessment.valid is True

    # Original assessment should be invalidated
    original_assessment = mlflow.get_assessment(trace_id, assessment_id)
    assert original_assessment.valid is False
    assert original_assessment.feedback.value == 0.5


def test_delete_assessment(trace_id):
    assessment_id = mlflow.log_feedback(
        trace_id=trace_id,
        name="faithfulness",
        value=1.0,
    ).assessment_id

    mlflow.delete_assessment(trace_id=trace_id, assessment_id=assessment_id)

    with pytest.raises(MlflowException, match=r"Assessment with ID"):
        mlflow.get_assessment(trace_id, assessment_id)

    # Assessment should be deleted from the trace
    trace = mlflow.get_trace(trace_id)
    assert len(trace.info.assessments) == 0


def test_log_feedback_default_source(trace_id):
    # Test that the default CODE source is used when no source is provided
    feedback = Feedback(
        trace_id=trace_id,
        name="faithfulness",
        value=1.0,
    )
    mlflow.log_assessment(trace_id=trace_id, assessment=feedback)

    trace = mlflow.get_trace(trace_id)
    assert len(trace.info.assessments) == 1
    assessment = trace.info.assessments[0]
    assert assessment.name == "faithfulness"
    assert assessment.trace_id == trace_id
    assert assessment.source.source_type == AssessmentSourceType.CODE
    assert assessment.source.source_id == "default"
    assert assessment.feedback.value == 1.0


def test_log_expectation_default_source(trace_id):
    # Test that the default HUMAN source is used when no source is provided
    expectation = Expectation(
        trace_id=trace_id,
        name="expected_answer",
        value="MLflow",
    )
    mlflow.log_assessment(trace_id=trace_id, assessment=expectation)

    trace = mlflow.get_trace(trace_id)
    assert len(trace.info.assessments) == 1
    assessment = trace.info.assessments[0]
    assert assessment.name == "expected_answer"
    assert assessment.trace_id == trace_id
    assert assessment.source.source_type == AssessmentSourceType.HUMAN
    assert assessment.source.source_id == "default"
    assert assessment.expectation.value == "MLflow"


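# The issue-related tests below create an Issue through the internal tracing client and
# attach it to a trace via mlflow.log_issue / IssueReference. Issue APIs are not supported
# by the file-based backend, so these tests skip themselves on file: tracking URIs.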
def test_log_issue(trace_id, tracking_uri):
    if tracking_uri.startswith("file:"):
        pytest.skip("Issue APIs are not supported with file-based tracking URI")

    tracing_client = mlflow.MlflowClient()._tracing_client
    issue = tracing_client._create_issue(
        experiment_id="0",
        name="timeout_error",
        description="Timeout errors in API calls",
        status=IssueStatus.PENDING,
    )

    mlflow.log_issue(
        trace_id=trace_id,
        issue_id=issue.issue_id,
        issue_name="timeout_error",
        source=_CODE_ASSESSMENT_SOURCE,
        rationale="Request exceeded 30 second timeout",
        metadata={"severity": "high", "affected_count": "150"},
    )

    trace = mlflow.get_trace(trace_id)
    assert len(trace.info.assessments) == 1
    assessment = trace.info.assessments[0]
    assert isinstance(assessment, IssueReference)
    assert assessment.trace_id == trace_id
    assert assessment.name == issue.issue_id
    assert assessment.issue_id == issue.issue_id
    assert assessment.issue_name == "timeout_error"
    assert assessment.span_id is None
    assert assessment.source == _CODE_ASSESSMENT_SOURCE
    assert assessment.create_time_ms is not None
    assert assessment.last_update_time_ms is not None
    assert assessment.issue.issue_name == "timeout_error"
    assert assessment.rationale == "Request exceeded 30 second timeout"
    assert assessment.metadata == {"severity": "high", "affected_count": "150"}


def test_log_issue_default_source(trace_id, tracking_uri):
    if tracking_uri.startswith("file:"):
        pytest.skip("Issue APIs are not supported with file-based tracking URI")

    tracing_client = mlflow.MlflowClient()._tracing_client
    issue = tracing_client._create_issue(
        experiment_id="0",
        name="connection_issue",
        description="Connection issues",
        status=IssueStatus.PENDING,
    )

    # Test that the default LLM_JUDGE source is used when no source is provided
    mlflow.log_issue(
        trace_id=trace_id,
        issue_id=issue.issue_id,
        issue_name="connection_issue",
    )

    trace = mlflow.get_trace(trace_id)
    assert len(trace.info.assessments) == 1
    assessment = trace.info.assessments[0]
    assert assessment.name == issue.issue_id
    assert assessment.issue_id == issue.issue_id
    assert assessment.issue_name == "connection_issue"
    assert assessment.trace_id == trace_id
    assert assessment.source.source_type == AssessmentSourceType.LLM_JUDGE
    assert assessment.source.source_id == "default"


def test_log_issue_with_run_id_and_span_id(tracking_uri):
    if tracking_uri.startswith("file:"):
        pytest.skip("Issue APIs are not supported with file-based tracking URI")

    tracing_client = mlflow.MlflowClient()._tracing_client
    issue = tracing_client._create_issue(
        experiment_id="0",
        name="data_quality_issue",
        description="Data quality issues",
        status=IssueStatus.PENDING,
    )

    with mlflow.start_span(name="test_span") as span:
        mlflow.log_issue(
            trace_id=span.trace_id,
            issue_id=issue.issue_id,
            issue_name="data_quality_issue",
            source=_LLM_ASSESSMENT_SOURCE,
            run_id="run-12345",
            rationale="Input data contains missing values in critical fields",
            metadata={"category": "data", "priority": "high"},
            span_id=span.span_id,
        )
    mlflow.flush_trace_async_logging()
    trace = mlflow.get_trace(span.trace_id)
    assert len(trace.info.assessments) == 1
    assessment = trace.info.assessments[0]
    assert assessment.issue_id == issue.issue_id
    assert assessment.issue_name == "data_quality_issue"
    assert assessment.trace_id == span.trace_id
    assert assessment.span_id == span.span_id
    assert assessment.source == _LLM_ASSESSMENT_SOURCE
    assert assessment.rationale == "Input data contains missing values in critical fields"
    assert assessment.metadata == {"category": "data", "priority": "high"}


def test_log_issue_without_issue_name(trace_id, tracking_uri):
    if tracking_uri.startswith("file:"):
        pytest.skip("Issue APIs are not supported with file-based tracking URI")

    tracing_client = mlflow.MlflowClient()._tracing_client
    issue = tracing_client._create_issue(
        experiment_id="0",
        name="timeout_error",
        description="Request exceeded 30 second timeout",
        status=IssueStatus.PENDING,
    )
    mlflow.log_issue(
        trace_id=trace_id,
        issue_id=issue.issue_id,
        source=_CODE_ASSESSMENT_SOURCE,
        rationale="Request exceeded 30 second timeout",
        metadata={"severity": "high", "affected_count": "150"},
    )
    trace = mlflow.get_trace(trace_id)
    assessment = trace.info.assessments[0]
    assert assessment.issue_id == issue.issue_id
    assert assessment.issue_name == "timeout_error"

    fetched_issue = tracing_client._get_issue(issue.issue_id)
    assert fetched_issue.name == issue.name
    assert fetched_issue.description == issue.description
    assert fetched_issue.status == issue.status


def test_log_issue_with_invalid_issue_name(trace_id, tracking_uri):
    if tracking_uri.startswith("file:"):
        pytest.skip("Issue APIs are not supported with file-based tracking URI")

    tracing_client = mlflow.MlflowClient()._tracing_client
    issue = tracing_client._create_issue(
        experiment_id="0",
        name="timeout_error",
        description="Timeout errors in API calls",
        status=IssueStatus.PENDING,
    )

    with pytest.raises(
        MlflowException, match=r"Provided issue name 'wrong_name' does not match the issue name"
    ):
        mlflow.log_issue(
            trace_id=trace_id,
            issue_id=issue.issue_id,
            issue_name="wrong_name",
        )


def test_log_issue_with_invalid_issue_id(trace_id, tracking_uri):
    if tracking_uri.startswith("file:"):
        pytest.skip("Issue APIs are not supported with file-based tracking URI")

    with pytest.raises(
        MlflowException,
        match=r"Issue with ID 'iss-nonexistent' not found",
    ):
        mlflow.log_issue(
            trace_id=trace_id,
            issue_id="iss-nonexistent",
            issue_name="some_issue",
        )


def test_log_feedback_and_expectation_blocks_positional_args():
    with pytest.raises(TypeError, match=r"log_feedback\(\) takes 0 positional"):
        mlflow.log_feedback("tr-1234", "faithfulness", 1.0)

    with pytest.raises(TypeError, match=r"log_expectation\(\) takes 0 positional"):
        mlflow.log_expectation("tr-1234", "expected_answer", "MLflow")

    with pytest.raises(TypeError, match=r"log_issue\(\) takes 0 positional"):
        mlflow.log_issue("tr-1234", "iss-12345", "timeout_error")


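# The in-progress-trace tests below log assessments while the traced function is still
# running; per the assertions, those assessments are delivered together with the trace
# as part of the StartTraceV3 export.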
@pytest.mark.parametrize("legacy_api", [True, False])
def test_log_assessment_on_in_progress_trace(trace_id, legacy_api):
    @mlflow.trace
    def func(x: int, y: int) -> int:
        active_trace_id = mlflow.get_active_trace_id()
        if legacy_api:
            mlflow.log_feedback(trace_id=active_trace_id, name="feedback", value=1.0)
            mlflow.log_expectation(trace_id=active_trace_id, name="expectation", value="MLflow")
            mlflow.log_feedback(trace_id=trace_id, name="other", value=2.0)
        else:
            mlflow.log_assessment(active_trace_id, Feedback(name="feedback", value=1.0))
            mlflow.log_assessment(active_trace_id, Expectation(name="expectation", value="MLflow"))
            mlflow.log_assessment(trace_id, Feedback(name="other", value=2.0))
        return x + y

    assert func(1, 2) == 3

    mlflow.flush_trace_async_logging()

    # Two assessments should be logged as part of the StartTraceV3 call
    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    assert len(trace.info.assessments) == 2
    assessments = {a.name: a for a in trace.info.assessments}
    assert assessments["feedback"].value == 1.0
    assert assessments["expectation"].value == "MLflow"

    # Assessment on the other trace
    trace = mlflow.get_trace(trace_id)
    assert len(trace.info.assessments) == 1
    assert trace.info.assessments[0].name == "other"
    assert trace.info.assessments[0].feedback.value == 2.0


@pytest.mark.asyncio
async def test_log_assessment_on_in_progress_trace_async():
    @mlflow.trace
    async def func(x: int, y: int) -> int:
        trace_id = mlflow.get_active_trace_id()
        mlflow.log_assessment(trace_id, Feedback(name="feedback", value=1.0))
        mlflow.log_assessment(trace_id, Expectation(name="expectation", value="MLflow"))
        return x + y

    assert (await func(1, 2)) == 3

    mlflow.flush_trace_async_logging()

    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    trace_info = trace.info
    assert len(trace_info.assessments) == 2
    assessments = {a.name: a for a in trace_info.assessments}
    assert assessments["feedback"].feedback.value == 1.0
    assert assessments["expectation"].expectation.value == "MLflow"


def test_log_assessment_on_in_progress_with_span_id():
    with mlflow.start_span(name="test_span") as span:
        # Only proceed if we have a real span (not NO_OP)
        if span.span_id is not None and span.trace_id != "MLFLOW_NO_OP_SPAN_TRACE_ID":
            mlflow.log_assessment(
                trace_id=span.trace_id,
                assessment=Feedback(name="feedback", value=1.0, span_id=span.span_id),
            )

    mlflow.flush_trace_async_logging()

    # The assessment should be logged as part of the StartTraceV3 call
    trace = mlflow.get_trace(mlflow.get_last_active_trace_id())
    trace_info = trace.info
    assert len(trace_info.assessments) == 1
    assert trace_info.assessments[0].name == "feedback"
    assert trace_info.assessments[0].feedback.value == 1.0
    assert trace_info.assessments[0].span_id == span.span_id


def test_log_assessment_on_in_progress_trace_works_when_tracing_is_disabled():
    # Calling log_assessment on an active trace should not fail when tracing is disabled.
    mlflow.tracing.disable()

    @mlflow.trace
    def func(x: int, y: int):
        trace_id = mlflow.get_active_trace_id()
        mlflow.log_assessment(trace_id=trace_id, assessment=Feedback(name="feedback", value=1.0))
        return x + y

    assert func(1, 2) == 3

    mlflow.flush_trace_async_logging()


def test_get_assessment(trace_id):
    assessment_id = mlflow.log_feedback(
        trace_id=trace_id,
        name="faithfulness",
        value=1.0,
    ).assessment_id

    result = mlflow.get_assessment(trace_id, assessment_id)

    assert isinstance(result, Feedback)
    assert result.name == "faithfulness"
    assert result.trace_id == trace_id
    assert result.value == 1.0
    assert result.error is None
    assert result.source.source_type == AssessmentSourceType.CODE
    assert result.source.source_id == "default"
    assert result.create_time_ms is not None
    assert result.last_update_time_ms is not None
    assert result.valid is True
    assert result.overrides is None


def test_search_traces_with_assessments():
    # Create traces with assessments
    with mlflow.start_span(name="trace_1") as span_1:
        mlflow.log_feedback(
            trace_id=span_1.trace_id,
            name="feedback_1",
            value=1.0,
        )
        mlflow.log_expectation(
            trace_id=span_1.trace_id,
            name="expectation_1",
            value="test",
            source=AssessmentSource(source_id="test", source_type=AssessmentSourceType.LLM_JUDGE),
        )
        with mlflow.start_span(name="child") as span_1_child:
            mlflow.log_feedback(
                trace_id=span_1_child.trace_id,
                name="feedback_2",
                value=1.0,
                span_id=span_1_child.span_id,
            )

    with mlflow.start_span(name="trace_2") as span_2:
        mlflow.log_feedback(
            trace_id=span_2.trace_id,
            name="feedback_3",
            value=1.0,
        )

    mlflow.flush_trace_async_logging()

    traces = mlflow.search_traces(
        locations=["0"],
        max_results=2,
        return_type="list",
        order_by=["timestamp_ms"],
    )
    # Verify the results
    assert len(traces) == 2
    assert len(traces[0].info.assessments) == 3

    assessments = {a.name: a for a in traces[0].info.assessments}
    assert assessments["feedback_1"].trace_id == span_1.trace_id
    assert assessments["feedback_1"].name == "feedback_1"
    assert assessments["feedback_1"].value == 1.0
    assert assessments["expectation_1"].trace_id == span_1.trace_id
    assert assessments["expectation_1"].name == "expectation_1"
    assert assessments["expectation_1"].value == "test"
    assert assessments["feedback_2"].trace_id == span_1_child.trace_id
    assert assessments["feedback_2"].name == "feedback_2"
    assert assessments["feedback_2"].value == 1.0

    assert len(traces[1].info.assessments) == 1
    assessment = traces[1].info.assessments[0]
    assert assessment.trace_id == span_2.trace_id
    assert assessment.name == "feedback_3"
    assert assessment.value == 1.0


@pytest.mark.parametrize("source_type", ["AI_JUDGE", AssessmentSourceType.AI_JUDGE])
def test_log_feedback_ai_judge_deprecation_warning(trace_id, source_type):
    with pytest.warns(FutureWarning, match="AI_JUDGE is deprecated. Use LLM_JUDGE instead."):
        ai_judge_source = AssessmentSource(source_type=source_type, source_id="gpt-4")

    mlflow.log_feedback(
        trace_id=trace_id,
        name="quality",
        value=0.8,
        source=ai_judge_source,
        rationale="AI evaluation",
    )

    trace = mlflow.get_trace(trace_id)
    assert len(trace.info.assessments) == 1
    assessment = trace.info.assessments[0]
    assert assessment.source.source_type == AssessmentSourceType.LLM_JUDGE
    assert assessment.source.source_id == "gpt-4"
    assert assessment.name == "quality"
    assert assessment.feedback.value == 0.8
    assert assessment.rationale == "AI evaluation"


def test_log_issue_reference(trace_id, tracking_uri):
    if tracking_uri.startswith("file:"):
        pytest.skip("Issue APIs are not supported with file-based tracking URI")

    tracing_client = mlflow.MlflowClient()._tracing_client
    issue = tracing_client._create_issue(
        experiment_id="0",
        name="timeout_error",
        description="Timeout errors in API calls",
        status=IssueStatus.PENDING,
    )

    issue_ref = IssueReference(
        issue_id=issue.issue_id,
        issue_name="timeout_error",
        source=_CODE_ASSESSMENT_SOURCE,
        metadata={"severity": "high", "affected_count": "150"},
    )
    mlflow.log_assessment(trace_id=trace_id, assessment=issue_ref)

    trace = mlflow.get_trace(trace_id)
    assert len(trace.info.assessments) == 1
    assessment = trace.info.assessments[0]
    assert isinstance(assessment, IssueReference)
    assert assessment.trace_id == trace_id
    assert assessment.name == issue.issue_id
    assert assessment.issue_id == issue.issue_id
    assert assessment.issue_name == "timeout_error"
    assert assessment.span_id is None
    assert assessment.source == _CODE_ASSESSMENT_SOURCE
    assert assessment.create_time_ms is not None
    assert assessment.last_update_time_ms is not None
    assert assessment.issue.issue_name == "timeout_error"
    assert assessment.expectation is None
    assert assessment.feedback is None
    assert assessment.metadata == {"severity": "high", "affected_count": "150"}


def test_log_issue_reference_invalid_parameters():
    with pytest.raises(MlflowException, match=r"The `issue_id` field must be specified"):
        IssueReference(
            issue_id=None,
            issue_name="test_issue",
            source=_CODE_ASSESSMENT_SOURCE,
        )


def test_log_issue_reference_default_source(trace_id, tracking_uri):
    if tracking_uri.startswith("file:"):
        pytest.skip("Issue APIs are not supported with file-based tracking URI")

    tracing_client = mlflow.MlflowClient()._tracing_client
    issue = tracing_client._create_issue(
        experiment_id="0",
        name="connection_issue",
        description="Connection issues",
        status=IssueStatus.PENDING,
    )

    issue_ref = IssueReference(
        issue_id=issue.issue_id,
        issue_name="connection_issue",
    )
    mlflow.log_assessment(trace_id=trace_id, assessment=issue_ref)

    trace = mlflow.get_trace(trace_id)
    assert len(trace.info.assessments) == 1
    assessment = trace.info.assessments[0]
    assert assessment.name == issue.issue_id
    assert assessment.issue_name == "connection_issue"
    assert assessment.trace_id == trace_id
    assert assessment.issue_id == issue.issue_id
    assert assessment.source.source_type == AssessmentSourceType.LLM_JUDGE
    assert assessment.source.source_id == "default"


def test_get_issue_reference_assessment(trace_id):
    issue_ref = IssueReference(
        issue_id="iss-55555",
        issue_name="performance_issue",
        metadata={"category": "latency"},
    )
    assessment_id = mlflow.log_assessment(trace_id=trace_id, assessment=issue_ref).assessment_id

    result = mlflow.get_assessment(trace_id, assessment_id)

    assert isinstance(result, IssueReference)
    assert result.name == "iss-55555"
    assert result.issue_name == "performance_issue"
    assert result.trace_id == trace_id
    assert result.issue_id == "iss-55555"
    assert result.source.source_type == AssessmentSourceType.LLM_JUDGE
    assert result.source.source_id == "default"
    assert result.create_time_ms is not None
    assert result.last_update_time_ms is not None
    assert result.metadata == {"category": "latency"}


def test_log_multiple_assessment_types(trace_id):
    feedback = Feedback(
        name="accuracy",
        value=0.95,
        source=_LLM_ASSESSMENT_SOURCE,
    )
    mlflow.log_assessment(trace_id=trace_id, assessment=feedback)

    expectation = Expectation(
        name="expected_output",
        value="MLflow",
        source=_HUMAN_ASSESSMENT_SOURCE,
    )
    mlflow.log_assessment(trace_id=trace_id, assessment=expectation)

    issue_ref = IssueReference(
        issue_id="iss-11111",
        issue_name="data_quality_issue",
        source=_CODE_ASSESSMENT_SOURCE,
    )
    mlflow.log_assessment(trace_id=trace_id, assessment=issue_ref)

    trace = mlflow.get_trace(trace_id)
    assert len(trace.info.assessments) == 3

    assessments_by_type = {}
    for a in trace.info.assessments:
        if isinstance(a, Feedback):
            assessments_by_type["feedback"] = a
        elif isinstance(a, Expectation):
            assessments_by_type["expectation"] = a
        elif isinstance(a, IssueReference):
            assessments_by_type["issue"] = a

    assert assessments_by_type["feedback"].name == "accuracy"
    assert assessments_by_type["feedback"].value == 0.95

    assert assessments_by_type["expectation"].name == "expected_output"
    assert assessments_by_type["expectation"].value == "MLflow"

    assert assessments_by_type["issue"].name == "iss-11111"
    assert assessments_by_type["issue"].issue_name == "data_quality_issue"
    assert assessments_by_type["issue"].issue_id == "iss-11111"