tests/openai/test_openai_evaluate.py
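"""Tests that MLflow's evaluate() works with OpenAI-backed models and that trace
logging respects the mlflow.openai autolog configuration."""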
from unittest import mock

import openai
import pandas as pd
import pytest

import mlflow
from mlflow.models.evaluation import evaluate
from mlflow.tracing.constant import TraceMetadataKey

from tests.tracing.helper import get_traces, purge_traces, reset_autolog_state  # noqa: F401

_EVAL_DATA = pd.DataFrame({
    "inputs": [
        "What is MLflow?",
        "What is Spark?",
    ],
    "ground_truth": [
        "MLflow is an open-source platform to manage the ML lifecycle.",
        "Spark is a unified analytics engine for big data processing.",
    ],
})


@pytest.fixture
def client(monkeypatch, mock_openai):
    monkeypatch.setenv("OPENAI_API_KEY", "test")
    monkeypatch.setenv("OPENAI_API_BASE", mock_openai)
    return openai.OpenAI(api_key="test", base_url=mock_openai)


@pytest.mark.parametrize(
    "config",
    [
        None,
        {"log_traces": False},
        {"log_traces": True},
    ],
)
@pytest.mark.usefixtures("reset_autolog_state")
def test_openai_evaluate(client, config):
    if config:
        mlflow.openai.autolog(**config)

    is_trace_disabled = config and not config.get("log_traces", True)
    is_trace_enabled = config and config.get("log_traces", True)

    # Callable model under evaluation: send each input question to the (mocked)
    # OpenAI chat completions endpoint and return the response text.
    def model(inputs):
        return [
            client.chat.completions.create(
                messages=[{"role": "user", "content": question}],
                model="gpt-4o-mini",
                temperature=0.0,
            )
            .choices[0]
            .message.content
            for question in inputs["inputs"]
        ]

    with mock.patch("mlflow.openai.log_model") as log_model_mock:
        with mlflow.start_run() as run:
            evaluate(
                model,
                data=_EVAL_DATA,
                targets="ground_truth",
                extra_metrics=[mlflow.metrics.exact_match()],
            )
        # Evaluating a callable model should not implicitly log it as an MLflow model
        log_model_mock.assert_not_called()

    # Traces should not be logged when tracing is explicitly disabled
    if is_trace_disabled:
        assert len(get_traces()) == 0
    else:
        assert len(get_traces()) == 2
        assert run.info.run_id == get_traces()[0].info.request_metadata[TraceMetadataKey.SOURCE_RUN]

    purge_traces()

    # Verify that the original autolog configuration is restored after evaluate()
    client.chat.completions.create(
        messages=[{"role": "user", "content": "hi"}], model="gpt-4o-mini"
    )

    assert len(get_traces()) == (1 if is_trace_enabled else 0)


@pytest.mark.usefixtures("reset_autolog_state")
def test_openai_pyfunc_evaluate(client):
    with mlflow.start_run() as run:
        model_info = mlflow.openai.log_model(
            "gpt-4o-mini",
            "chat.completions",
            name="model",
            messages=[{"role": "system", "content": "You are an MLflow expert."}],
        )

        evaluate(
            model_info.model_uri,
            data=_EVAL_DATA,
            targets="ground_truth",
            extra_metrics=[mlflow.metrics.exact_match()],
        )
    assert len(get_traces()) == 2
    assert run.info.run_id == get_traces()[0].info.request_metadata[TraceMetadataKey.SOURCE_RUN]


@pytest.mark.parametrize("globally_disabled", [True, False])
@pytest.mark.usefixtures("reset_autolog_state")
def test_openai_evaluate_should_not_log_traces_when_disabled(client, globally_disabled):
    if globally_disabled:
        mlflow.autolog(disable=True)
    else:
        mlflow.openai.autolog(disable=True)

    def model(inputs):
        return [
            client.chat.completions.create(
                messages=[{"role": "user", "content": question}],
                model="gpt-4o-mini",
                temperature=0.0,
            )
            .choices[0]
            .message.content
            for question in inputs["inputs"]
        ]

    with mlflow.start_run():
        evaluate(
            model,
            data=_EVAL_DATA,
            targets="ground_truth",
            extra_metrics=[mlflow.metrics.exact_match()],
        )

    assert len(get_traces()) == 0