# tests/pyfunc/test_spark_connect.py
 1  import numpy as np
 2  import pandas as pd
 3  import pytest
 4  from pyspark.sql import SparkSession
 5  from sklearn.datasets import load_iris
 6  from sklearn.linear_model import LogisticRegression
 7  
 8  import mlflow
 9  
10  
@pytest.fixture(scope="module")
def spark():
    """Yield one Spark session shared by every test in this module.

    The session is obtained through ``SparkSession.builder.remote("local[2]")``
    and is stopped once the module's tests have finished.
    """
    session = SparkSession.builder.remote("local[2]").getOrCreate()
    yield session
    # Teardown: runs after the last test that requested this fixture.
    session.stop()
16  
17  
def test_spark_udf_spark_connect(spark):
    """Verify that a pyfunc Spark UDF matches the sklearn model's own predictions.

    A logistic regression is fitted on the iris data, logged with MLflow, loaded
    back as a Spark UDF with ``env_manager="local"``, and applied to the same
    feature rows through the ``spark`` fixture's session.
    """
    features, labels = load_iris(return_X_y=True)
    classifier = LogisticRegression().fit(features, labels)
    with mlflow.start_run():
        logged = mlflow.sklearn.log_model(classifier, name="model")

    # The four feature columns are named "a".."d" and passed to the UDF positionally.
    iris_pdf = pd.DataFrame(features, columns=list("abcd"))
    sdf = spark.createDataFrame(iris_pdf)
    predict_udf = mlflow.pyfunc.spark_udf(spark, logged.model_uri, env_manager="local")
    scored = sdf.select(predict_udf(*sdf.columns).alias("preds")).toPandas()

    actual = scored["preds"].to_numpy()
    expected = classifier.predict(features)
    np.testing.assert_almost_equal(actual, expected)
27  
28  
@pytest.mark.parametrize("env_manager", ["conda", "virtualenv"])
def test_spark_udf_spark_connect_unsupported_env_manager(spark, tmp_path, env_manager):
    """Expect ``spark_udf`` to raise ``MlflowException`` for conda and virtualenv.

    The model URI is just the empty ``tmp_path`` directory; the test only
    asserts on the exception type and message raised by the call.
    """
    expected_message = f"Environment manager {env_manager!r} is not supported"
    model_location = str(tmp_path)
    with pytest.raises(mlflow.MlflowException, match=expected_message):
        mlflow.pyfunc.spark_udf(spark, model_location, env_manager=env_manager)
36  
37  
def test_spark_udf_spark_connect_with_model_logging(spark, db_uri):
    """Log a signed sklearn model to the ``db_uri`` tracking store and score it via Spark.

    ``db_uri`` is a fixture defined outside this file (presumably a
    database-backed tracking URI -- confirm in conftest). The model is fitted on
    the iris frame, logged with an inferred signature under experiment "test",
    and its Spark UDF predictions on the first five rows are compared with the
    model's direct predictions.
    """
    features, target = load_iris(return_X_y=True, as_frame=True)
    classifier = LogisticRegression().fit(features, target)

    # Point MLflow at the supplied tracking store before opening the run.
    mlflow.set_tracking_uri(db_uri)
    mlflow.set_experiment("test")
    with mlflow.start_run():
        model_signature = mlflow.models.infer_signature(features, target)
        logged = mlflow.sklearn.log_model(
            classifier,
            name="model",
            signature=model_signature,
        )

    predict_udf = mlflow.pyfunc.spark_udf(spark, logged.model_uri, env_manager="local")
    sample = features.head(5)
    sample_sdf = spark.createDataFrame(sample)
    scored = sample_sdf.select(predict_udf(*sample.columns).alias("preds")).toPandas()
    actual = scored["preds"]
    np.testing.assert_array_almost_equal(actual, classifier.predict(sample))
51      preds = sdf.select(udf(*X_test.columns).alias("preds")).toPandas()["preds"]
52      np.testing.assert_array_almost_equal(preds, model.predict(X_test))