/ tests / llama_index / test_llama_index_evaluate.py
test_llama_index_evaluate.py
  1  import pandas as pd
  2  import pytest
  3  
  4  import mlflow
  5  from mlflow.metrics import latency
  6  from mlflow.tracing.constant import TraceMetadataKey
  7  
  8  from tests.openai.test_openai_evaluate import purge_traces
  9  from tests.tracing.helper import get_traces, reset_autolog_state  # noqa: F401
 10  
 11  _EVAL_DATA = pd.DataFrame({
 12      "inputs": [
 13          "What is MLflow?",
 14          "What is Spark?",
 15      ],
 16      "ground_truth": [
 17          "MLflow is an open-source platform to manage the ML lifecycle.",
 18          "Spark is a unified analytics engine for big data processing.",
 19      ],
 20  })
 21  
 22  
 23  @pytest.mark.parametrize(
 24      "config",
 25      [
 26          None,
 27          {"log_traces": False},
 28          {"log_traces": True},
 29      ],
 30  )
 31  @pytest.mark.usefixtures("reset_autolog_state")
 32  def test_llama_index_evaluate(single_index, config):
 33      if config:
 34          mlflow.llama_index.autolog(**config)
 35          mlflow.openai.autolog(**config)  # Our model contains OpenAI call as well
 36  
 37      is_trace_disabled = config and not config.get("log_traces", True)
 38      is_trace_enabled = config and config.get("log_traces", True)
 39  
 40      engine = single_index.as_query_engine()
 41  
 42      def model(inputs):
 43          return [engine.query(question) for question in inputs["inputs"]]
 44  
 45      with mlflow.start_run() as run:
 46          eval_result = mlflow.evaluate(
 47              model,
 48              data=_EVAL_DATA,
 49              targets="ground_truth",
 50              extra_metrics=[latency()],
 51          )
 52      assert eval_result.metrics["latency/mean"] > 0
 53  
 54      # Traces should not be logged when disabled explicitly
 55      if is_trace_disabled:
 56          assert len(get_traces()) == 0
 57      else:
 58          assert len(get_traces()) == 2
 59          assert run.info.run_id == get_traces()[0].info.request_metadata[TraceMetadataKey.SOURCE_RUN]
 60  
 61      purge_traces()
 62  
 63      # Test original autolog configs is restored
 64      engine.query("text")
 65      assert len(get_traces()) == (1 if is_trace_enabled else 0)
 66  
 67  
 68  @pytest.mark.parametrize("engine_type", ["query", "chat"])
 69  @pytest.mark.usefixtures("reset_autolog_state")
 70  def test_llama_index_pyfunc_evaluate(engine_type, single_index):
 71      with mlflow.start_run() as run:
 72          model_info = mlflow.llama_index.log_model(
 73              single_index,
 74              name="llama_index",
 75              engine_type=engine_type,
 76          )
 77  
 78          eval_result = mlflow.evaluate(
 79              model_info.model_uri,
 80              data=_EVAL_DATA,
 81              targets="ground_truth",
 82              extra_metrics=[latency()],
 83          )
 84      assert eval_result.metrics["latency/mean"] > 0
 85  
 86      # Traces should be automatically enabled during evaluation
 87      assert len(get_traces()) == 2
 88      assert run.info.run_id == get_traces()[0].info.request_metadata[TraceMetadataKey.SOURCE_RUN]
 89  
 90  
 91  @pytest.mark.parametrize("globally_disabled", [True, False])
 92  @pytest.mark.usefixtures("reset_autolog_state")
 93  def test_llama_index_evaluate_should_not_log_traces_when_disabled(single_index, globally_disabled):
 94      if globally_disabled:
 95          mlflow.autolog(disable=True)
 96      else:
 97          mlflow.llama_index.autolog(disable=True)
 98          mlflow.openai.autolog(disable=True)  # Our model contains OpenAI call as well
 99  
100      def model(inputs):
101          engine = single_index.as_query_engine()
102          return [engine.query(question) for question in inputs["inputs"]]
103  
104      with mlflow.start_run():
105          eval_result = mlflow.evaluate(
106              model,
107              data=_EVAL_DATA,
108              targets="ground_truth",
109              extra_metrics=[latency()],
110          )
111      assert eval_result.metrics["latency/mean"] > 0
112      assert len(get_traces()) == 0