# examples/spark_udf/spark_udf_datetime.py
 1  import datetime
 2  import random
 3  
 4  from pyspark.sql import SparkSession
 5  from sklearn.compose import ColumnTransformer
 6  from sklearn.datasets import load_iris
 7  from sklearn.neighbors import KNeighborsClassifier
 8  from sklearn.pipeline import Pipeline
 9  from sklearn.preprocessing import FunctionTransformer
10  
11  import mlflow
12  
13  
def print_with_title(title, *args):
    """Print *title* as a banner line, then each extra argument on its own line."""
    banner = f"\n===== {title} =====\n"
    print(banner)
    for item in args:
        print(item)
18  
19  
def extract_month(df):
    """Return *df* with an added ``month`` column taken from its ``timestamp`` column.

    Logs the incoming and outgoing frames (head + dtypes) so the transformation
    is visible when this runs inside the sklearn pipeline.
    """
    print_with_title("extract_month input", df.head(), df.dtypes)
    result = df.assign(month=df["timestamp"].dt.month)
    print_with_title("extract_month output", result.head(), result.dtypes)
    return result
25  
26  
def main():
    """End-to-end example: train a pipeline on data with a datetime column,
    log it to MLflow, and score it from Spark via ``mlflow.pyfunc.spark_udf``."""
    # Load iris as DataFrames and attach a synthetic timestamp column
    # (first day of a random month in 2022) so the model must handle datetimes.
    X, y = load_iris(as_frame=True, return_X_y=True)
    X = X.assign(
        timestamp=[datetime.datetime(2022, random.randint(1, 12), 1) for _ in range(len(X))]
    )
    print_with_title("Ran input", X.head(30), X.dtypes)

    # Inferred signature includes the timestamp column, so the Spark UDF
    # later knows to pass it through with a datetime type.
    signature = mlflow.models.infer_signature(X, y)
    print_with_title("Signature", signature)

    # validate=False keeps the input as a DataFrame (no array coercion),
    # which extract_month needs for the .dt accessor.
    month_extractor = FunctionTransformer(extract_month, validate=False)
    # Pass through every original feature column and drop the raw timestamp
    # before the classifier (the derived 'month' column is dropped too, since
    # only the original columns are selected — presumably intentional for this
    # demo of datetime handling; confirm if 'month' should be kept).
    timestamp_remover = ColumnTransformer(
        [("selector", "passthrough", X.columns.drop("timestamp"))], remainder="drop"
    )
    model = Pipeline([
        ("month_extractor", month_extractor),
        ("timestamp_remover", timestamp_remover),
        ("knn", KNeighborsClassifier()),
    ])
    model.fit(X, y)

    # Log the fitted pipeline (with signature) and keep its URI for the UDF.
    with mlflow.start_run():
        model_info = mlflow.sklearn.log_model(model, name="model", signature=signature)

    # SparkSession used as a context manager so it is stopped on exit.
    with SparkSession.builder.getOrCreate() as spark:
        infer_spark_df = spark.createDataFrame(X.sample(n=10, random_state=42))
        # NOTE: _jdf is a private JVM handle, used here only for pretty-printing.
        print_with_title(
            "Inference input",
            infer_spark_df._jdf.showString(5, 20, False),  # numRows, truncate, vertical
            infer_spark_df._jdf.schema().treeString(),
        )

        # Restore the logged model in an isolated conda environment and apply
        # it as a UDF; all original columns (including timestamp) are passed.
        pyfunc_udf = mlflow.pyfunc.spark_udf(spark, model_info.model_uri, env_manager="conda")
        result = infer_spark_df.select(pyfunc_udf(*X.columns).alias("predictions")).toPandas()
        print_with_title("Inference result", result)
62  
63  
# Run the example only when executed as a script, not on import.
if __name__ == "__main__":
    main()