# examples/spark_udf/spark_udf.py
 1  from pyspark.sql import SparkSession
 2  from sklearn import datasets
 3  from sklearn.neighbors import KNeighborsClassifier
 4  
 5  import mlflow
 6  from mlflow.models import infer_signature
 7  
 8  with SparkSession.builder.getOrCreate() as spark:
 9      X, y = datasets.load_iris(as_frame=True, return_X_y=True)
10      model = KNeighborsClassifier()
11      model.fit(X, y)
12      predictions = model.predict(X)
13      signature = infer_signature(X, predictions)
14  
15      with mlflow.start_run():
16          model_info = mlflow.sklearn.log_model(model, name="model", signature=signature)
17  
18      infer_spark_df = spark.createDataFrame(X)
19  
20      pyfunc_udf = mlflow.pyfunc.spark_udf(spark, model_info.model_uri, env_manager="conda")
21      result = infer_spark_df.select(pyfunc_udf(*X.columns).alias("predictions")).toPandas()
22  
23      print(result)