main.py
  1  """
  2  Downloads the MovieLens dataset, ETLs it into Parquet, trains an
  3  ALS model, and uses the ALS model to train a Keras neural network.
  4  
  5  See README.md for more details.
  6  """
  7  
  8  import os
  9  
 10  import click
 11  
 12  import mlflow
 13  from mlflow.entities import RunStatus
 14  from mlflow.tracking import MlflowClient
 15  from mlflow.tracking.fluent import _get_experiment_id
 16  from mlflow.utils import mlflow_tags
 17  from mlflow.utils.logging_utils import eprint
 18  
 19  
 20  def _already_ran(entry_point_name, parameters, git_commit, experiment_id=None):
 21      """Best-effort detection of if a run with the given entrypoint name,
 22      parameters, and experiment id already ran. The run must have completed
 23      successfully and have at least the parameters provided.
 24      """
 25      experiment_id = experiment_id if experiment_id is not None else _get_experiment_id()
 26      client = MlflowClient()
 27      all_runs = reversed(client.search_runs([experiment_id]))
 28      for run in all_runs:
 29          tags = run.data.tags
 30          if tags.get(mlflow_tags.MLFLOW_PROJECT_ENTRY_POINT, None) != entry_point_name:
 31              continue
 32          match_failed = False
 33          for param_key, param_value in parameters.items():
 34              run_value = run.data.params.get(param_key)
 35              if run_value != param_value:
 36                  match_failed = True
 37                  break
 38          if match_failed:
 39              continue
 40  
 41          if run.info.to_proto().status != RunStatus.FINISHED:
 42              eprint(
 43                  ("Run matched, but is not FINISHED, so skipping (run_id={}, status={})").format(
 44                      run.info.run_id, run.info.status
 45                  )
 46              )
 47              continue
 48  
 49          previous_version = tags.get(mlflow_tags.MLFLOW_GIT_COMMIT, None)
 50          if git_commit != previous_version:
 51              eprint(
 52                  "Run matched, but has a different source version, so skipping "
 53                  f"(found={previous_version}, expected={git_commit})"
 54              )
 55              continue
 56          return client.get_run(run.info.run_id)
 57      eprint("No matching run has been found.")
 58      return None
 59  
 60  
 61  # TODO(aaron): This is not great because it doesn't account for:
 62  # - changes in code
 63  # - changes in dependent steps
 64  def _get_or_run(entrypoint, parameters, git_commit, use_cache=True):
 65      existing_run = _already_ran(entrypoint, parameters, git_commit)
 66      if use_cache and existing_run:
 67          print(f"Found existing run for entrypoint={entrypoint} and parameters={parameters}")
 68          return existing_run
 69      print(f"Launching new run for entrypoint={entrypoint} and parameters={parameters}")
 70      submitted_run = mlflow.run(".", entrypoint, parameters=parameters, env_manager="local")
 71      return MlflowClient().get_run(submitted_run.run_id)
 72  
 73  
 74  @click.command()
 75  @click.option("--als-max-iter", default=10, type=int)
 76  @click.option("--keras-hidden-units", default=20, type=int)
 77  @click.option("--max-row-limit", default=100000, type=int)
 78  def workflow(als_max_iter, keras_hidden_units, max_row_limit):
 79      # Note: The entrypoint names are defined in MLproject. The artifact directories
 80      # are documented by each step's .py file.
 81      with mlflow.start_run() as active_run:
 82          os.environ["SPARK_CONF_DIR"] = os.path.abspath(".")
 83          git_commit = active_run.data.tags.get(mlflow_tags.MLFLOW_GIT_COMMIT)
 84          load_raw_data_run = _get_or_run("load_raw_data", {}, git_commit)
 85          ratings_csv_uri = os.path.join(load_raw_data_run.info.artifact_uri, "ratings-csv-dir")
 86          etl_data_run = _get_or_run(
 87              "etl_data", {"ratings_csv": ratings_csv_uri, "max_row_limit": max_row_limit}, git_commit
 88          )
 89          ratings_parquet_uri = os.path.join(etl_data_run.info.artifact_uri, "ratings-parquet-dir")
 90  
 91          # We specify a spark-defaults.conf to override the default driver memory. ALS requires
 92          # significant memory. The driver memory property cannot be set by the application itself.
 93          als_run = _get_or_run(
 94              "als", {"ratings_data": ratings_parquet_uri, "max_iter": str(als_max_iter)}, git_commit
 95          )
 96          als_model_uri = os.path.join(als_run.info.artifact_uri, "als-model")
 97  
 98          keras_params = {
 99              "ratings_data": ratings_parquet_uri,
100              "als_model_uri": als_model_uri,
101              "hidden_units": keras_hidden_units,
102          }
103          _get_or_run("train_keras", keras_params, git_commit, use_cache=False)
104  
105  
# Entry point: click parses the CLI options and invokes the workflow.
if __name__ == "__main__":
    workflow()