als.py
"""
Trains an Alternating Least Squares (ALS) model for user/movie ratings.
The input is a Parquet ratings dataset (see etl_data.py), and we output
an mlflow artifact called 'als-model'.
"""

import click
import pyspark
from pyspark.ml import Pipeline
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.recommendation import ALS

import mlflow
import mlflow.spark


@click.command()
@click.option(
    "--ratings-data",
    required=True,
    help="Path of the Parquet ratings dataset to train on.",
)
@click.option(
    "--split-prop",
    default=0.8,
    type=float,
    help="Weight of the training split; the test split gets 1 - SPLIT_PROP.",
)
@click.option("--max-iter", default=10, type=int, help="Value passed to ALS.setMaxIter.")
@click.option("--reg-param", default=0.1, type=float, help="Value passed to ALS.setRegParam.")
@click.option("--rank", default=12, type=int, help="Value passed to ALS.setRank.")
@click.option(
    "--cold-start-strategy",
    default="drop",
    help="Value passed to ALS.setColdStartStrategy.",
)
def train_als(ratings_data, split_prop, max_iter, reg_param, rank, cold_start_strategy):
    """Train an ALS model on the ratings data and log it to MLflow.

    Reads the Parquet dataset at RATINGS_DATA, splits it randomly into a
    training and a test set, fits an ALS pipeline on the training set, and
    logs the row counts, the train/test MSE and the fitted model (artifact
    path 'als-model') to MLflow.
    """
    # Fail fast, before a Spark session is created: a proportion outside
    # (0, 1) yields a negative weight or an empty train/test split.
    if not 0.0 < split_prop < 1.0:
        raise click.BadParameter(
            f"must be strictly between 0 and 1, got {split_prop}",
            param_hint="--split-prop",
        )

    # Fixed seed, shared by the random split and by ALS itself.
    seed = 42

    with pyspark.sql.SparkSession.builder.getOrCreate() as spark:
        ratings_df = spark.read.parquet(ratings_data)
        (training_df, test_df) = ratings_df.randomSplit([split_prop, 1 - split_prop], seed=seed)
        training_df.cache()
        test_df.cache()

        # Count each split once and reuse the numbers for both the logged
        # metrics and the console message.
        training_nrows = training_df.count()
        test_nrows = test_df.count()

        mlflow.log_metric("training_nrows", training_nrows)
        mlflow.log_metric("test_nrows", test_nrows)

        print(f"Training: {training_nrows}, test: {test_nrows}")

        als = (
            ALS()
            .setUserCol("userId")
            .setItemCol("movieId")
            .setRatingCol("rating")
            .setPredictionCol("predictions")
            .setMaxIter(max_iter)
            .setSeed(seed)
            .setRegParam(reg_param)
            .setColdStartStrategy(cold_start_strategy)
            .setRank(rank)
        )

        als_model = Pipeline(stages=[als]).fit(training_df)

        # The evaluator reads the same "predictions" column that ALS writes.
        reg_eval = RegressionEvaluator(
            predictionCol="predictions", labelCol="rating", metricName="mse"
        )

        predicted_test_df = als_model.transform(test_df)

        test_mse = reg_eval.evaluate(predicted_test_df)
        train_mse = reg_eval.evaluate(als_model.transform(training_df))

        print(f"The model had a MSE on the test set of {test_mse}")
        print(f"The model had a MSE on the (train) set of {train_mse}")
        mlflow.log_metric("test_mse", test_mse)
        mlflow.log_metric("train_mse", train_mse)
        mlflow.spark.log_model(als_model, artifact_path="als-model")


if __name__ == "__main__":
    train_als()