test_spark_connect.py
"""Tests for ``mlflow.pyfunc.spark_udf`` running against a Spark Connect session."""

import numpy as np
import pandas as pd
import pytest
from pyspark.sql import SparkSession
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression

import mlflow


@pytest.fixture(scope="module")
def spark():
    """Yield a module-scoped Spark Connect session on a local 2-core cluster."""
    session = SparkSession.builder.remote("local[2]").getOrCreate()
    yield session
    session.stop()


def test_spark_udf_spark_connect(spark):
    """Scoring a logged sklearn model through ``spark_udf`` matches ``predict()``."""
    features, target = load_iris(return_X_y=True)
    classifier = LogisticRegression().fit(features, target)

    # Log the fitted model inside a run so it gets a resolvable model URI.
    with mlflow.start_run():
        model_info = mlflow.sklearn.log_model(classifier, name="model")

    frame = spark.createDataFrame(pd.DataFrame(features, columns=["a", "b", "c", "d"]))
    predict = mlflow.pyfunc.spark_udf(spark, model_info.model_uri, env_manager="local")
    scored = frame.select(predict(*frame.columns).alias("preds")).toPandas()
    np.testing.assert_almost_equal(scored["preds"].to_numpy(), classifier.predict(features))


@pytest.mark.parametrize("env_manager", ["conda", "virtualenv"])
def test_spark_udf_spark_connect_unsupported_env_manager(spark, tmp_path, env_manager):
    """Non-``local`` environment managers are rejected under Spark Connect."""
    with pytest.raises(
        mlflow.MlflowException,
        match=f"Environment manager {env_manager!r} is not supported",
    ):
        mlflow.pyfunc.spark_udf(spark, str(tmp_path), env_manager=env_manager)


def test_spark_udf_spark_connect_with_model_logging(spark, db_uri):
    """``spark_udf`` resolves a model logged to a DB-backed tracking store.

    NOTE(review): the ``db_uri`` fixture is not defined in this file —
    presumably supplied by conftest.py; verify.
    """
    features, target = load_iris(return_X_y=True, as_frame=True)
    classifier = LogisticRegression().fit(features, target)

    mlflow.set_tracking_uri(db_uri)
    mlflow.set_experiment("test")
    with mlflow.start_run():
        signature = mlflow.models.infer_signature(features, target)
        model_info = mlflow.sklearn.log_model(classifier, name="model", signature=signature)

    predict = mlflow.pyfunc.spark_udf(spark, model_info.model_uri, env_manager="local")
    sample = features.head(5)
    frame = spark.createDataFrame(sample)
    predictions = frame.select(predict(*sample.columns).alias("preds")).toPandas()["preds"]
    np.testing.assert_array_almost_equal(predictions, classifier.predict(sample))