search_random.py
1 """ 2 Example of hyperparameter search in MLflow using simple random search. 3 4 The run method will evaluate random combinations of parameters in a new MLflow run. 5 6 The runs are evaluated based on validation set loss. Test set score is calculated to verify the 7 results. 8 9 Several runs can be run in parallel. 10 """ 11 12 from concurrent.futures import ThreadPoolExecutor 13 14 import click 15 import numpy as np 16 17 import mlflow 18 import mlflow.projects 19 from mlflow.tracking import MlflowClient 20 21 _inf = np.finfo(np.float64).max 22 23 24 @click.command(help="Perform grid search over train (main entry point).") 25 @click.option("--max-runs", type=click.INT, default=32, help="Maximum number of runs to evaluate.") 26 @click.option("--max-p", type=click.INT, default=1, help="Maximum number of parallel runs.") 27 @click.option("--epochs", type=click.INT, default=32, help="Number of epochs") 28 @click.option("--metric", type=click.STRING, default="rmse", help="Metric to optimize on.") 29 @click.option("--seed", type=click.INT, default=97531, help="Seed for the random generator") 30 @click.argument("training_data") 31 def run(training_data, max_runs, max_p, epochs, metric, seed): 32 train_metric = f"train_{metric}" 33 val_metric = f"val_{metric}" 34 test_metric = f"test_{metric}" 35 np.random.seed(seed) 36 tracking_client = MlflowClient() 37 38 def new_eval( 39 nepochs, experiment_id, null_train_loss=_inf, null_val_loss=_inf, null_test_loss=_inf 40 ): 41 def eval(params): 42 lr, momentum = params 43 with mlflow.start_run(nested=True) as child_run: 44 p = mlflow.projects.run( 45 run_id=child_run.info.run_id, 46 uri=".", 47 entry_point="train", 48 parameters={ 49 "training_data": training_data, 50 "epochs": str(nepochs), 51 "learning_rate": str(lr), 52 "momentum": str(momentum), 53 "seed": str(seed), 54 }, 55 experiment_id=experiment_id, 56 synchronous=False, 57 ) 58 succeeded = p.wait() 59 mlflow.log_params({"lr": lr, "momentum": momentum}) 60 if succeeded: 61 training_run = tracking_client.get_run(p.run_id) 62 metrics = training_run.data.metrics 63 # cap the loss at the loss of the null model 64 train_loss = min(null_train_loss, metrics[train_metric]) 65 val_loss = min(null_val_loss, metrics[val_metric]) 66 test_loss = min(null_test_loss, metrics[test_metric]) 67 else: 68 # run failed => return null loss 69 tracking_client.set_terminated(p.run_id, "FAILED") 70 train_loss = null_train_loss 71 val_loss = null_val_loss 72 test_loss = null_test_loss 73 mlflow.log_metrics({ 74 f"train_{metric}": train_loss, 75 f"val_{metric}": val_loss, 76 f"test_{metric}": test_loss, 77 }) 78 return p.run_id, train_loss, val_loss, test_loss 79 80 return eval 81 82 with mlflow.start_run() as run: 83 experiment_id = run.info.experiment_id 84 _, null_train_loss, null_val_loss, null_test_loss = new_eval(0, experiment_id)((0, 0)) 85 runs = [(np.random.uniform(1e-5, 1e-1), np.random.uniform(0, 1.0)) for _ in range(max_runs)] 86 with ThreadPoolExecutor(max_workers=max_p) as executor: 87 _ = executor.map( 88 new_eval(epochs, experiment_id, null_train_loss, null_val_loss, null_test_loss), 89 runs, 90 ) 91 92 # find the best run, log its metrics as the final metrics of this run. 
        client = MlflowClient()
        runs = client.search_runs(
            [experiment_id], f"tags.mlflow.parentRunId = '{run.info.run_id}'"
        )
        best_val_train = _inf
        best_val_valid = _inf
        best_val_test = _inf
        best_run = None
        for r in runs:
            if val_metric not in r.data.metrics:
                # skip child runs that failed before logging any metrics
                continue
            if r.data.metrics[val_metric] < best_val_valid:
                best_run = r
                best_val_train = r.data.metrics[train_metric]
                best_val_valid = r.data.metrics[val_metric]
                best_val_test = r.data.metrics[test_metric]
        mlflow.set_tag("best_run", best_run.info.run_id)
        mlflow.log_metrics(
            {
                train_metric: best_val_train,
                val_metric: best_val_valid,
                test_metric: best_val_test,
            }
        )


if __name__ == "__main__":
    run()
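
# Example invocation (a sketch: the dataset path and the MLproject entry-point name below are
# assumptions, since the accompanying MLproject file and training data are not shown here).
#
# Run the script directly; the single positional argument is the training data path:
#
#   python search_random.py --max-runs 16 --max-p 4 ./wine-quality.csv
#
# If this script is registered as an entry point (e.g. "random") in an MLproject file, it can
# also be launched as an MLflow project:
#
#   mlflow run . -e random -P training_data=./wine-quality.csv -P max_runs=16 -P max_p=4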