# spark_udf.py
# Example: train a scikit-learn classifier, log it with MLflow, then score a
# Spark DataFrame with the logged model through an MLflow pyfunc Spark UDF.
from pyspark.sql import SparkSession
from sklearn import datasets
from sklearn.neighbors import KNeighborsClassifier

import mlflow
from mlflow.models import infer_signature

# Using the session as a context manager stops it when the block exits.
with SparkSession.builder.getOrCreate() as spark:
    # as_frame=True returns pandas objects, so X carries named feature columns.
    X, y = datasets.load_iris(as_frame=True, return_X_y=True)
    model = KNeighborsClassifier()
    model.fit(X, y)

    # The signature is inferred from the training inputs and the model's own
    # predictions on them, then stored alongside the logged model.
    predictions = model.predict(X)
    signature = infer_signature(X, predictions)

    with mlflow.start_run():
        model_info = mlflow.sklearn.log_model(model, name="model", signature=signature)

    # Same feature data, now as a Spark DataFrame to run inference on.
    infer_spark_df = spark.createDataFrame(X)

    # Wrap the logged model as a Spark UDF.
    # NOTE(review): env_manager="conda" presumably requires conda on the
    # executors to rebuild the model's environment — confirm for your cluster.
    pyfunc_udf = mlflow.pyfunc.spark_udf(spark, model_info.model_uri, env_manager="conda")

    # Pass every feature column (by name) to the UDF and collect the single
    # "predictions" column back to the driver as a pandas DataFrame.
    result = infer_spark_df.select(pyfunc_udf(*X.columns).alias("predictions")).toPandas()

    print(result)