test_default_evaluator_delta.py
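"""Tests for the default evaluator's Delta-table output.

These tests exercise the ``eval_results_path`` / ``eval_results_mode`` options of
``mlflow.evaluate``: failing fast outside a Spark environment or with an invalid
mode, and writing or appending the evaluation results table to a Delta path.
"""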
import tempfile

import pandas as pd
import pytest
from pyspark.sql import SparkSession

import mlflow
from mlflow.exceptions import MlflowException


def language_model(inputs: list[str]) -> list[str]:
    return inputs


def test_write_to_delta_fails_without_spark():
    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="model", python_model=language_model, input_example=["a", "b"]
        )
        data = pd.DataFrame({"text": ["Hello world", "My name is MLflow"]})
        with pytest.raises(
            MlflowException,
            match="eval_results_path is only supported in Spark environment",
        ):
            mlflow.evaluate(
                model_info.model_uri,
                data,
                extra_metrics=[mlflow.metrics.latency()],
                evaluators="default",
                evaluator_config={
                    "eval_results_path": "my_path",
                    "eval_results_mode": "overwrite",
                },
            )


@pytest.fixture
def spark_session_with_delta():
    with tempfile.TemporaryDirectory() as tmpdir:
        with (
            SparkSession.builder
            .master("local[*]")
            .config("spark.jars.packages", "io.delta:delta-spark_2.13:4.0.0")
            .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
            .config(
                "spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog"
            )
            .config("spark.sql.warehouse.dir", tmpdir)
            .getOrCreate() as spark
        ):
            yield spark, tmpdir


def test_write_to_delta_fails_with_invalid_mode(spark_session_with_delta):
    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="model", python_model=language_model, input_example=["a", "b"]
        )
        data = pd.DataFrame({"text": ["Hello world", "My name is MLflow"]})
        with pytest.raises(
            MlflowException,
            match="eval_results_mode can only be 'overwrite' or 'append'",
        ):
            mlflow.evaluate(
                model_info.model_uri,
                data,
                extra_metrics=[mlflow.metrics.latency()],
                evaluators="default",
                evaluator_config={
                    "eval_results_path": "my_path",
                    "eval_results_mode": "invalid_mode",
                },
            )


def test_write_eval_table_to_delta(spark_session_with_delta):
    spark_session, tmpdir = spark_session_with_delta
    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="model", python_model=language_model, input_example=["a", "b"]
        )
        data = pd.DataFrame({"text": ["Hello world", "My name is MLflow"]})
        results = mlflow.evaluate(
            model_info.model_uri,
            data,
            extra_metrics=[mlflow.metrics.latency()],
            evaluators="default",
            evaluator_config={
                "eval_results_path": "my_path",
                "eval_results_mode": "overwrite",
            },
        )

    eval_table = results.tables["eval_results_table"].sort_values("text").reset_index(drop=True)

    eval_table_from_delta = (
        spark_session.read
        .format("delta")
        .load(f"{tmpdir}/my_path")
        .toPandas()
        .sort_values("text")
        .reset_index(drop=True)
    )

    pd.testing.assert_frame_equal(eval_table_from_delta, eval_table)


def test_write_eval_table_to_delta_append(spark_session_with_delta):
    spark_session, tmpdir = spark_session_with_delta
    with mlflow.start_run():
        model_info = mlflow.pyfunc.log_model(
            name="model", python_model=language_model, input_example=["a", "b"]
        )
        data = pd.DataFrame({"text": ["Hello world", "My name is MLflow"]})
        mlflow.evaluate(
            model_info.model_uri,
            data,
            extra_metrics=[mlflow.metrics.latency()],
            evaluators="default",
            evaluator_config={
                "eval_results_path": "my_path",
                "eval_results_mode": "overwrite",
            },
        )
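
        # The second evaluation appends the same two rows to the table written
        # above, so the Delta table should end up with four rows in total.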
        mlflow.evaluate(
            model_info.model_uri,
            data,
            extra_metrics=[mlflow.metrics.latency()],
            evaluators="default",
            evaluator_config={
                "eval_results_path": "my_path",
                "eval_results_mode": "append",
            },
        )

    eval_table_from_delta = spark_session.read.format("delta").load(f"{tmpdir}/my_path")

    assert eval_table_from_delta.count() == 4