# examples/hyperparam/search_random.py
"""
Example of hyperparameter search in MLflow using simple random search.

The run method evaluates random combinations of parameters, each in a new MLflow run.

The runs are evaluated based on validation set loss. The test set score is calculated to verify
the results.

Several runs can be executed in parallel.
"""
 11  
 12  from concurrent.futures import ThreadPoolExecutor
 13  
 14  import click
 15  import numpy as np
 16  
 17  import mlflow
 18  import mlflow.projects
 19  from mlflow.tracking import MlflowClient
 20  
# Largest representable float64, used as an "infinite" loss sentinel so that
# any real (finite) loss compares smaller than an unset best/null loss.
_inf = np.finfo(np.float64).max
 22  
 23  
 24  @click.command(help="Perform grid search over train (main entry point).")
 25  @click.option("--max-runs", type=click.INT, default=32, help="Maximum number of runs to evaluate.")
 26  @click.option("--max-p", type=click.INT, default=1, help="Maximum number of parallel runs.")
 27  @click.option("--epochs", type=click.INT, default=32, help="Number of epochs")
 28  @click.option("--metric", type=click.STRING, default="rmse", help="Metric to optimize on.")
 29  @click.option("--seed", type=click.INT, default=97531, help="Seed for the random generator")
 30  @click.argument("training_data")
 31  def run(training_data, max_runs, max_p, epochs, metric, seed):
 32      train_metric = f"train_{metric}"
 33      val_metric = f"val_{metric}"
 34      test_metric = f"test_{metric}"
 35      np.random.seed(seed)
 36      tracking_client = MlflowClient()
 37  
 38      def new_eval(
 39          nepochs, experiment_id, null_train_loss=_inf, null_val_loss=_inf, null_test_loss=_inf
 40      ):
 41          def eval(params):
 42              lr, momentum = params
 43              with mlflow.start_run(nested=True) as child_run:
 44                  p = mlflow.projects.run(
 45                      run_id=child_run.info.run_id,
 46                      uri=".",
 47                      entry_point="train",
 48                      parameters={
 49                          "training_data": training_data,
 50                          "epochs": str(nepochs),
 51                          "learning_rate": str(lr),
 52                          "momentum": str(momentum),
 53                          "seed": str(seed),
 54                      },
 55                      experiment_id=experiment_id,
 56                      synchronous=False,
 57                  )
 58                  succeeded = p.wait()
 59                  mlflow.log_params({"lr": lr, "momentum": momentum})
 60              if succeeded:
 61                  training_run = tracking_client.get_run(p.run_id)
 62                  metrics = training_run.data.metrics
 63                  # cap the loss at the loss of the null model
 64                  train_loss = min(null_train_loss, metrics[train_metric])
 65                  val_loss = min(null_val_loss, metrics[val_metric])
 66                  test_loss = min(null_test_loss, metrics[test_metric])
 67              else:
 68                  # run failed => return null loss
 69                  tracking_client.set_terminated(p.run_id, "FAILED")
 70                  train_loss = null_train_loss
 71                  val_loss = null_val_loss
 72                  test_loss = null_test_loss
 73              mlflow.log_metrics({
 74                  f"train_{metric}": train_loss,
 75                  f"val_{metric}": val_loss,
 76                  f"test_{metric}": test_loss,
 77              })
 78              return p.run_id, train_loss, val_loss, test_loss
 79  
 80          return eval
 81  
 82      with mlflow.start_run() as run:
 83          experiment_id = run.info.experiment_id
 84          _, null_train_loss, null_val_loss, null_test_loss = new_eval(0, experiment_id)((0, 0))
 85          runs = [(np.random.uniform(1e-5, 1e-1), np.random.uniform(0, 1.0)) for _ in range(max_runs)]
 86          with ThreadPoolExecutor(max_workers=max_p) as executor:
 87              _ = executor.map(
 88                  new_eval(epochs, experiment_id, null_train_loss, null_val_loss, null_test_loss),
 89                  runs,
 90              )
 91  
 92          # find the best run, log its metrics as the final metrics of this run.
 93          client = MlflowClient()
 94          runs = client.search_runs(
 95              [experiment_id], f"tags.mlflow.parentRunId = '{run.info.run_id}' "
 96          )
 97          best_val_train = _inf
 98          best_val_valid = _inf
 99          best_val_test = _inf
100          best_run = None
101          for r in runs:
102              if r.data.metrics["val_rmse"] < best_val_valid:
103                  best_run = r
104                  best_val_train = r.data.metrics["train_rmse"]
105                  best_val_valid = r.data.metrics["val_rmse"]
106                  best_val_test = r.data.metrics["test_rmse"]
107          mlflow.set_tag("best_run", best_run.info.run_id)
108          mlflow.log_metrics({
109              f"train_{metric}": best_val_train,
110              f"val_{metric}": best_val_valid,
111              f"test_{metric}": best_val_test,
112          })
113  
114  
if __name__ == "__main__":
    # Script entry point: dispatch to the click command (click parses argv).
    run()