als.py
 1  """
 2  Trains an Alternating Least Squares (ALS) model for user/movie ratings.
 3  The input is a Parquet ratings dataset (see etl_data.py), and we output
 4  an mlflow artifact called 'als-model'.
 5  """
 6  
 7  import click
 8  import pyspark
 9  from pyspark.ml import Pipeline
10  from pyspark.ml.evaluation import RegressionEvaluator
11  from pyspark.ml.recommendation import ALS
12  
13  import mlflow
14  import mlflow.spark
15  
16  
@click.command()
@click.option("--ratings-data", required=True, help="Path to the Parquet ratings dataset.")
@click.option(
    "--split-prop",
    default=0.8,
    type=float,
    help="Fraction of rows used for training; must be strictly between 0 and 1.",
)
@click.option("--max-iter", default=10, type=int)
@click.option("--reg-param", default=0.1, type=float)
@click.option("--rank", default=12, type=int)
@click.option(
    "--cold-start-strategy",
    default="drop",
    # Spark's ALS only supports these two strategies (matched case-insensitively).
    type=click.Choice(["drop", "nan"], case_sensitive=False),
)
@click.option("--seed", default=42, type=int, help="Seed for the train/test split and for ALS.")
def train_als(ratings_data, split_prop, max_iter, reg_param, rank, cold_start_strategy, seed=42):
    """Train an ALS recommender on a ratings dataset and log the results to MLflow.

    Reads a Parquet dataset with ``userId``, ``movieId`` and ``rating`` columns,
    randomly splits it into training and test sets, fits an ALS model (wrapped
    in a single-stage Pipeline) on the training set, and logs the split row
    counts, the train/test MSE and the fitted model (artifact path
    ``als-model``) to MLflow.

    Args:
        ratings_data: Path to the Parquet ratings dataset.
        split_prop: Fraction of rows assigned to the training set. Must be in
            the open interval (0, 1) so both ``randomSplit`` weights are positive.
        max_iter: Maximum number of ALS iterations.
        reg_param: ALS regularization parameter.
        rank: Number of latent factors in the ALS model.
        cold_start_strategy: How ALS treats users/items unseen during
            training: "drop" or "nan".
        seed: Seed used for both the random split and ALS.

    Raises:
        click.BadParameter: If ``split_prop`` is not strictly between 0 and 1.
    """
    # Validate before starting Spark so a bad CLI value fails fast and clearly.
    if not 0.0 < split_prop < 1.0:
        raise click.BadParameter(
            f"must be strictly between 0 and 1, got {split_prop}",
            param_hint="'--split-prop'",
        )

    with pyspark.sql.SparkSession.builder.getOrCreate() as spark:
        ratings_df = spark.read.parquet(ratings_data)
        training_df, test_df = ratings_df.randomSplit([split_prop, 1 - split_prop], seed=seed)
        # Each split is scanned several times (count, fit, evaluate), so cache it.
        training_df.cache()
        test_df.cache()

        # Count each split once: every count() launches a separate Spark job.
        training_nrows = training_df.count()
        test_nrows = test_df.count()
        mlflow.log_metric("training_nrows", training_nrows)
        mlflow.log_metric("test_nrows", test_nrows)

        print(f"Training: {training_nrows}, test: {test_nrows}")

        als = (
            ALS()
            .setUserCol("userId")
            .setItemCol("movieId")
            .setRatingCol("rating")
            .setPredictionCol("predictions")
            .setMaxIter(max_iter)
            .setSeed(seed)
            .setRegParam(reg_param)
            .setColdStartStrategy(cold_start_strategy)
            .setRank(rank)
        )

        als_model = Pipeline(stages=[als]).fit(training_df)

        reg_eval = RegressionEvaluator(
            predictionCol="predictions", labelCol="rating", metricName="mse"
        )

        # With the "nan" strategy, test rows whose user/movie was unseen in
        # training get NaN predictions, which makes the test MSE NaN too.
        test_mse = reg_eval.evaluate(als_model.transform(test_df))
        train_mse = reg_eval.evaluate(als_model.transform(training_df))

        print(f"The model had a MSE on the test set of {test_mse}")
        print(f"The model had a MSE on the (train) set of {train_mse}")
        mlflow.log_metric("test_mse", test_mse)
        mlflow.log_metric("train_mse", train_mse)
        # Log while the Spark session is still open; saving the model needs it.
        mlflow.spark.log_model(als_model, artifact_path="als-model")
67  
68  
if __name__ == "__main__":
    # Hand control to Click: it parses the command-line options declared on
    # train_als and then invokes the command.
    train_als.main()