# tests/cli/test_eval.py
  1  import re
  2  from unittest import mock
  3  
  4  import click
  5  import pandas as pd
  6  import pytest
  7  
  8  import mlflow
  9  from mlflow.cli.eval import evaluate_traces
 10  from mlflow.entities import Trace, TraceInfo
 11  from mlflow.genai.scorers.base import scorer
 12  
 13  
 14  def test_evaluate_traces_with_single_trace_table_output():
 15      experiment_id = mlflow.create_experiment("test_experiment")
 16  
 17      mock_trace = mock.Mock(spec=Trace)
 18      mock_trace.info = mock.Mock(spec=TraceInfo)
 19      mock_trace.info.trace_id = "tr-test-123"
 20      mock_trace.info.experiment_id = experiment_id
 21  
 22      mock_results = mock.Mock()
 23      mock_results.run_id = "run-eval-456"
 24      mock_results.result_df = pd.DataFrame([
 25          {
 26              "trace_id": "tr-test-123",
 27              "assessments": [
 28                  {
 29                      "assessment_name": "RelevanceToQuery",
 30                      "feedback": {"value": "yes"},
 31                      "rationale": "The answer is relevant",
 32                      "metadata": {"mlflow.assessment.sourceRunId": "run-eval-456"},
 33                  }
 34              ],
 35          }
 36      ])
 37  
 38      with (
 39          mock.patch(
 40              "mlflow.cli.eval.MlflowClient.get_trace", return_value=mock_trace
 41          ) as mock_get_trace,
 42          mock.patch("mlflow.cli.eval.evaluate", return_value=mock_results) as mock_evaluate,
 43      ):
 44          evaluate_traces(
 45              experiment_id=experiment_id,
 46              trace_ids="tr-test-123",
 47              scorers="RelevanceToQuery",
 48              output_format="table",
 49          )
 50  
 51          mock_get_trace.assert_called_once_with("tr-test-123", display=False)
 52  
 53          assert mock_evaluate.call_count == 1
 54          call_args = mock_evaluate.call_args
 55          assert "data" in call_args.kwargs
 56  
 57          expected_df = pd.DataFrame([{"trace_id": "tr-test-123", "trace": mock_trace}])
 58          pd.testing.assert_frame_equal(call_args.kwargs["data"], expected_df)
 59  
 60          assert "scorers" in call_args.kwargs
 61          assert len(call_args.kwargs["scorers"]) == 1
 62          assert call_args.kwargs["scorers"][0].__class__.__name__ == "RelevanceToQuery"
 63  
 64  
 65  def test_evaluate_traces_with_multiple_traces_json_output():
 66      experiment = mlflow.create_experiment("test_experiment_multi")
 67  
 68      mock_trace1 = mock.Mock(spec=Trace)
 69      mock_trace1.info = mock.Mock(spec=TraceInfo)
 70      mock_trace1.info.trace_id = "tr-test-1"
 71      mock_trace1.info.experiment_id = experiment
 72  
 73      mock_trace2 = mock.Mock(spec=Trace)
 74      mock_trace2.info = mock.Mock(spec=TraceInfo)
 75      mock_trace2.info.trace_id = "tr-test-2"
 76      mock_trace2.info.experiment_id = experiment
 77  
 78      mock_results = mock.Mock()
 79      mock_results.run_id = "run-eval-789"
 80      mock_results.result_df = pd.DataFrame([
 81          {
 82              "trace_id": "tr-test-1",
 83              "assessments": [
 84                  {
 85                      "assessment_name": "Correctness",
 86                      "feedback": {"value": "correct"},
 87                      "rationale": "Content is correct",
 88                      "metadata": {"mlflow.assessment.sourceRunId": "run-eval-789"},
 89                  }
 90              ],
 91          },
 92          {
 93              "trace_id": "tr-test-2",
 94              "assessments": [
 95                  {
 96                      "assessment_name": "Correctness",
 97                      "feedback": {"value": "correct"},
 98                      "rationale": "Also correct",
 99                      "metadata": {"mlflow.assessment.sourceRunId": "run-eval-789"},
100                  }
101              ],
102          },
103      ])
104  
105      with (
106          mock.patch(
107              "mlflow.cli.eval.MlflowClient.get_trace",
108              side_effect=[mock_trace1, mock_trace2],
109          ) as mock_get_trace,
110          mock.patch("mlflow.cli.eval.evaluate", return_value=mock_results) as mock_evaluate,
111      ):
112          evaluate_traces(
113              experiment_id=experiment,
114              trace_ids="tr-test-1,tr-test-2",
115              scorers="Correctness",
116              output_format="json",
117          )
118  
119          assert mock_get_trace.call_count == 2
120          mock_get_trace.assert_any_call("tr-test-1", display=False)
121          mock_get_trace.assert_any_call("tr-test-2", display=False)
122  
123          assert mock_evaluate.call_count == 1
124          call_args = mock_evaluate.call_args
125          expected_df = pd.DataFrame([
126              {"trace_id": "tr-test-1", "trace": mock_trace1},
127              {"trace_id": "tr-test-2", "trace": mock_trace2},
128          ])
129          pd.testing.assert_frame_equal(call_args.kwargs["data"], expected_df)
130  
131  
def test_evaluate_traces_with_nonexistent_trace():
    """An unresolvable trace id surfaces as a click.UsageError."""
    experiment = mlflow.create_experiment("test_experiment_error")

    get_trace_patch = mock.patch("mlflow.cli.eval.MlflowClient.get_trace", return_value=None)
    with get_trace_patch as mock_get_trace:
        with pytest.raises(click.UsageError, match="Trace with ID 'tr-nonexistent' not found"):
            evaluate_traces(
                experiment_id=experiment,
                trace_ids="tr-nonexistent",
                scorers="RelevanceToQuery",
                output_format="table",
            )

        # The lookup happened exactly once before the error was raised.
        mock_get_trace.assert_called_once_with("tr-nonexistent", display=False)
145  
146  
def test_evaluate_traces_with_trace_from_wrong_experiment():
    """A trace owned by a different experiment is rejected with a usage error."""
    experiment1 = mlflow.create_experiment("test_experiment_1")
    experiment2 = mlflow.create_experiment("test_experiment_2")

    # The stub trace claims membership in experiment2 while we target experiment1.
    foreign_trace = mock.Mock(spec=Trace)
    foreign_trace.info = mock.Mock(spec=TraceInfo)
    foreign_trace.info.trace_id = "tr-test-123"
    foreign_trace.info.experiment_id = experiment2

    get_trace_patch = mock.patch(
        "mlflow.cli.eval.MlflowClient.get_trace", return_value=foreign_trace
    )
    with get_trace_patch as mock_get_trace:
        with pytest.raises(click.UsageError, match="belongs to experiment"):
            evaluate_traces(
                experiment_id=experiment1,
                trace_ids="tr-test-123",
                scorers="RelevanceToQuery",
                output_format="table",
            )

        mock_get_trace.assert_called_once_with("tr-test-123", display=False)
168  
169  
def test_evaluate_traces_integration():
    """End-to-end run against real traces, with only scorer resolution mocked."""
    experiment_id = mlflow.create_experiment("test_experiment_integration")
    mlflow.set_experiment(experiment_id=experiment_id)

    # Create three real traces, each with distinct inputs/outputs.
    trace_ids = []
    for idx in range(3):
        with mlflow.start_span(name=f"test_span_{idx}") as span:
            span.set_inputs({"question": f"What is test {idx}?"})
            span.set_outputs(f"This is answer {idx}")
            trace_ids.append(span.trace_id)

    @scorer
    def simple_scorer(outputs):
        """Extract the digit from the output string and return it as the score"""
        found = re.search(r"\d+", outputs)
        return float(found.group()) if found else 0.0

    resolve_patch = mock.patch(
        "mlflow.cli.eval.resolve_scorers", return_value=[simple_scorer]
    )
    with resolve_patch as mock_resolve:
        evaluate_traces(
            experiment_id=experiment_id,
            trace_ids=",".join(trace_ids),
            # The scorer name is intercepted by the resolve_scorers mock above.
            scorers="simple_scorer",
            output_format="table",
        )
        mock_resolve.assert_called_once()

    # All three traces should now carry evaluation assessments.
    traces = mlflow.search_traces(locations=[experiment_id], return_type="list")
    assert len(traces) == 3

    # Order by span output so trace index i lines up with expected score i.
    traces = sorted(traces, key=lambda t: t.data.spans[0].outputs)

    for expected_score, trace in enumerate(traces):
        assert len(trace.info.assessments) > 0

        from_scorer = [a for a in trace.info.assessments if a.name == "simple_scorer"]
        assert len(from_scorer) == 1

        # The scorer extracted the digit embedded in "This is answer {i}".
        assert from_scorer[0].value == float(expected_score)