tests/evaluate/test_default_evaluator_delta.py
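"""Tests for writing default-evaluator results tables to Delta via the
``eval_results_path`` / ``eval_results_mode`` keys of ``evaluator_config``."""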
import tempfile

import pandas as pd
import pytest
from pyspark.sql import SparkSession

import mlflow
from mlflow.exceptions import MlflowException


def language_model(inputs: list[str]) -> list[str]:
    return inputs


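# With no active SparkSession, setting "eval_results_path" must fail: writing
# evaluation results to Delta is only supported in a Spark environment.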
def test_write_to_delta_fails_without_spark():
    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="model", python_model=language_model, input_example=["a", "b"]
        )
        data = pd.DataFrame({"text": ["Hello world", "My name is MLflow"]})
        with pytest.raises(
            MlflowException,
            match="eval_results_path is only supported in Spark environment",
        ):
            mlflow.evaluate(
                model_info.model_uri,
                data,
                extra_metrics=[mlflow.metrics.latency()],
                evaluators="default",
                evaluator_config={
                    "eval_results_path": "my_path",
                    "eval_results_mode": "overwrite",
                },
            )


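# Local SparkSession with Delta Lake enabled. The pinned delta-spark artifact must
# match the installed PySpark (io.delta:delta-spark_2.13:4.0.0 targets Spark 4.x on
# Scala 2.13); Delta tables land under the temporary warehouse directory.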
@pytest.fixture
def spark_session_with_delta():
    with tempfile.TemporaryDirectory() as tmpdir:
        with (
            SparkSession.builder.master("local[*]")
            .config("spark.jars.packages", "io.delta:delta-spark_2.13:4.0.0")
            .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
            .config(
                "spark.sql.catalog.spark_catalog",
                "org.apache.spark.sql.delta.catalog.DeltaCatalog",
            )
            .config("spark.sql.warehouse.dir", tmpdir)
            .getOrCreate() as spark
        ):
            yield spark, tmpdir


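# "eval_results_mode" accepts only 'overwrite' or 'append'; any other value is rejected.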
def test_write_to_delta_fails_with_invalid_mode(spark_session_with_delta):
    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="model", python_model=language_model, input_example=["a", "b"]
        )
        data = pd.DataFrame({"text": ["Hello world", "My name is MLflow"]})
        with pytest.raises(
            MlflowException,
            match="eval_results_mode can only be 'overwrite' or 'append'",
        ):
            mlflow.evaluate(
                model_info.model_uri,
                data,
                extra_metrics=[mlflow.metrics.latency()],
                evaluators="default",
                evaluator_config={
                    "eval_results_path": "my_path",
                    "eval_results_mode": "invalid_mode",
                },
            )


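# The eval results table logged with the run should round-trip through the Delta
# table at f"{tmpdir}/my_path". Rows are sorted by "text" before comparing, since
# Spark gives no ordering guarantee when the table is read back.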
def test_write_eval_table_to_delta(spark_session_with_delta):
    spark_session, tmpdir = spark_session_with_delta
    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="model", python_model=language_model, input_example=["a", "b"]
        )
        data = pd.DataFrame({"text": ["Hello world", "My name is MLflow"]})
        results = mlflow.evaluate(
            model_info.model_uri,
            data,
            extra_metrics=[mlflow.metrics.latency()],
            evaluators="default",
            evaluator_config={
                "eval_results_path": "my_path",
                "eval_results_mode": "overwrite",
            },
        )

        eval_table = (
            results.tables["eval_results_table"].sort_values("text").reset_index(drop=True)
        )

        eval_table_from_delta = (
            spark_session.read.format("delta")
            .load(f"{tmpdir}/my_path")
            .toPandas()
            .sort_values("text")
            .reset_index(drop=True)
        )

        pd.testing.assert_frame_equal(eval_table_from_delta, eval_table)


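# Evaluating twice over the same 2-row dataset, first in "overwrite" mode and then
# in "append" mode, should leave both result sets in the Delta table (4 rows total).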
def test_write_eval_table_to_delta_append(spark_session_with_delta):
    spark_session, tmpdir = spark_session_with_delta
    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="model", python_model=language_model, input_example=["a", "b"]
        )
        data = pd.DataFrame({"text": ["Hello world", "My name is MLflow"]})
        mlflow.evaluate(
            model_info.model_uri,
            data,
            extra_metrics=[mlflow.metrics.latency()],
            evaluators="default",
            evaluator_config={
                "eval_results_path": "my_path",
                "eval_results_mode": "overwrite",
            },
        )

        mlflow.evaluate(
            model_info.model_uri,
            data,
            extra_metrics=[mlflow.metrics.latency()],
            evaluators="default",
            evaluator_config={
                "eval_results_path": "my_path",
                "eval_results_mode": "append",
            },
        )

        eval_table_from_delta = spark_session.read.format("delta").load(f"{tmpdir}/my_path")

        assert eval_table_from_delta.count() == 4