# main.py
"""
Downloads the MovieLens dataset, ETLs it into Parquet, trains an
ALS model, and uses the ALS model to train a Keras neural network.

See README.md for more details.
"""

import os

import click

import mlflow
from mlflow.entities import RunStatus
from mlflow.tracking import MlflowClient
from mlflow.tracking.fluent import _get_experiment_id
from mlflow.utils import mlflow_tags
from mlflow.utils.logging_utils import eprint


def _already_ran(entry_point_name, parameters, git_commit, experiment_id=None):
    """Best-effort detection of if a run with the given entrypoint name,
    parameters, and experiment id already ran. The run must have completed
    successfully and have at least the parameters provided.

    Returns the matching ``Run`` object, or ``None`` if no reusable run exists.
    """
    experiment_id = experiment_id if experiment_id is not None else _get_experiment_id()
    client = MlflowClient()
    all_runs = reversed(client.search_runs([experiment_id]))
    for run in all_runs:
        tags = run.data.tags
        # Only consider runs launched for the same project entry point.
        if tags.get(mlflow_tags.MLFLOW_PROJECT_ENTRY_POINT, None) != entry_point_name:
            continue
        # Every requested parameter must match the logged value exactly
        # (MLflow logs params as strings, so callers pass string values).
        match_failed = False
        for param_key, param_value in parameters.items():
            run_value = run.data.params.get(param_key)
            if run_value != param_value:
                match_failed = True
                break
        if match_failed:
            continue

        # Only successfully completed runs can be reused as a cache hit.
        if run.info.to_proto().status != RunStatus.FINISHED:
            eprint(
                ("Run matched, but is not FINISHED, so skipping (run_id={}, status={})").format(
                    run.info.run_id, run.info.status
                )
            )
            continue

        # A run produced at a different git commit may have different code/artifacts,
        # so it is not a safe cache hit.
        previous_version = tags.get(mlflow_tags.MLFLOW_GIT_COMMIT, None)
        if git_commit != previous_version:
            eprint(
                "Run matched, but has a different source version, so skipping "
                f"(found={previous_version}, expected={git_commit})"
            )
            continue
        return client.get_run(run.info.run_id)
    eprint("No matching run has been found.")
    return None
# TODO(aaron): This is not great because it doesn't account for:
# - changes in code
# - changes in dependent steps
def _get_or_run(entrypoint, parameters, git_commit, use_cache=True):
    """Return a cached run for ``entrypoint`` with matching ``parameters`` and
    ``git_commit`` if one exists and ``use_cache`` is true; otherwise launch
    the entrypoint as a new MLflow run and return the resulting run.
    """
    existing_run = _already_ran(entrypoint, parameters, git_commit)
    if use_cache and existing_run:
        print(f"Found existing run for entrypoint={entrypoint} and parameters={parameters}")
        return existing_run
    print(f"Launching new run for entrypoint={entrypoint} and parameters={parameters}")
    submitted_run = mlflow.run(".", entrypoint, parameters=parameters, env_manager="local")
    return MlflowClient().get_run(submitted_run.run_id)


@click.command()
@click.option("--als-max-iter", default=10, type=int)
@click.option("--keras-hidden-units", default=20, type=int)
@click.option("--max-row-limit", default=100000, type=int)
def workflow(als_max_iter, keras_hidden_units, max_row_limit):
    """Run the multistep pipeline: load raw data -> ETL to Parquet -> train ALS
    -> train the Keras model, reusing cached upstream runs where possible.
    """
    # Note: The entrypoint names are defined in MLproject. The artifact directories
    # are documented by each step's .py file.
    with mlflow.start_run() as active_run:
        os.environ["SPARK_CONF_DIR"] = os.path.abspath(".")
        git_commit = active_run.data.tags.get(mlflow_tags.MLFLOW_GIT_COMMIT)
        load_raw_data_run = _get_or_run("load_raw_data", {}, git_commit)
        ratings_csv_uri = os.path.join(load_raw_data_run.info.artifact_uri, "ratings-csv-dir")
        etl_data_run = _get_or_run(
            "etl_data", {"ratings_csv": ratings_csv_uri, "max_row_limit": max_row_limit}, git_commit
        )
        ratings_parquet_uri = os.path.join(etl_data_run.info.artifact_uri, "ratings-parquet-dir")

        # We specify a spark-defaults.conf to override the default driver memory. ALS requires
        # significant memory. The driver memory property cannot be set by the application itself.
        als_run = _get_or_run(
            "als", {"ratings_data": ratings_parquet_uri, "max_iter": str(als_max_iter)}, git_commit
        )
        als_model_uri = os.path.join(als_run.info.artifact_uri, "als-model")

        keras_params = {
            "ratings_data": ratings_parquet_uri,
            "als_model_uri": als_model_uri,
            "hidden_units": keras_hidden_units,
        }
        # The final training step is always re-run (use_cache=False) so that new
        # upstream artifacts are always reflected in the trained model.
        _get_or_run("train_keras", keras_params, git_commit, use_cache=False)


if __name__ == "__main__":
    workflow()