"""spark_udf_datetime.py

Exercise a scikit-learn pipeline that reads a datetime column when it is
served through ``mlflow.pyfunc.spark_udf``.

The script trains a KNN classifier on the iris dataset augmented with a
``timestamp`` column, logs it to MLflow with an inferred signature, and then
scores a small Spark DataFrame through a Spark UDF built from the logged model.
``extract_month`` prints its input/output heads and dtypes so the dtype the
``timestamp`` column has at training time can be compared with the dtype it
arrives with at inference time.
"""

import datetime
import random

from pyspark.sql import SparkSession
from sklearn.compose import ColumnTransformer
from sklearn.datasets import load_iris
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import FunctionTransformer

import mlflow

# Seed shared by every source of randomness below so runs are reproducible.
SEED = 42


def print_with_title(title, *args):
    """Print a ``===== title =====`` banner, then each of *args* via ``print``."""
    print(f"\n===== {title} =====\n")
    for a in args:
        print(a)


def extract_month(df):
    """Return a copy of *df* with an added ``month`` column.

    The month is read from ``df["timestamp"]`` through the ``.dt`` accessor, so
    that column must have a datetime dtype. The head and dtypes of the input
    and output are printed so the dtype seen during training can be compared
    with the one seen during Spark UDF inference.
    """
    print_with_title("extract_month input", df.head(), df.dtypes)
    transformed = df.assign(month=df["timestamp"].dt.month)
    print_with_title("extract_month output", transformed.head(), transformed.dtypes)
    return transformed


def main():
    """Train the pipeline, log it to MLflow, and score it through a Spark UDF."""
    X, y = load_iris(as_frame=True, return_X_y=True)
    # A dedicated seeded RNG makes the synthetic timestamps (and therefore the
    # fitted model) reproducible, matching the fixed sampling seed used below.
    rng = random.Random(SEED)
    X = X.assign(
        timestamp=[datetime.datetime(2022, rng.randint(1, 12), 1) for _ in range(len(X))]
    )
    print_with_title("Ran input", X.head(30), X.dtypes)

    signature = mlflow.models.infer_signature(X, y)
    print_with_title("Signature", signature)

    month_extractor = FunctionTransformer(extract_month, validate=False)
    # Drop only the raw timestamp and pass every other column through, so the
    # ``month`` feature created by ``month_extractor`` reaches the classifier.
    # (Selecting ``X.columns.drop("timestamp")`` would also discard ``month``,
    # because it is not one of X's original columns.)
    timestamp_remover = ColumnTransformer(
        [("drop_timestamp", "drop", ["timestamp"])], remainder="passthrough"
    )
    model = Pipeline(
        [
            ("month_extractor", month_extractor),
            ("timestamp_remover", timestamp_remover),
            ("knn", KNeighborsClassifier()),
        ]
    )
    model.fit(X, y)

    with mlflow.start_run():
        model_info = mlflow.sklearn.log_model(model, name="model", signature=signature)

    # The context manager stops the Spark session when the block exits.
    with SparkSession.builder.getOrCreate() as spark:
        infer_spark_df = spark.createDataFrame(X.sample(n=10, random_state=SEED))
        # Public equivalents of ``_jdf.showString(5, 20, False)`` and
        # ``_jdf.schema().treeString()``: same printed text, without relying on
        # the private JVM handle ``_jdf``.
        print_with_title("Inference input")
        infer_spark_df.show(5, truncate=20, vertical=False)
        infer_spark_df.printSchema()

        pyfunc_udf = mlflow.pyfunc.spark_udf(spark, model_info.model_uri, env_manager="conda")
        result = infer_spark_df.select(pyfunc_udf(*X.columns).alias("predictions")).toPandas()
        print_with_title("Inference result", result)


if __name__ == "__main__":
    main()