# hpo_mnist.py
"""
Hyperparameter Optimization Example with Pure PyTorch and MLflow

This example demonstrates:
- Using MLflow to track hyperparameter optimization trials
- Parent/child run structure for organizing HPO experiments
- Pure PyTorch training (no Lightning dependencies)
- Simple MNIST classification with configurable hyperparameters

Run with: python hpo_mnist.py --n-trials 5 --max-epochs 3
"""

import argparse

import optuna
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import mlflow


class SimpleNet(nn.Module):
    """Two-layer MLP for MNIST: 784 -> hidden_size -> 10, with dropout."""

    def __init__(self, hidden_size, dropout_rate):
        super().__init__()
        self.fc1 = nn.Linear(784, hidden_size)
        self.dropout = nn.Dropout(dropout_rate)
        self.fc2 = nn.Linear(hidden_size, 10)

    def forward(self, x):
        # Flatten (N, 1, 28, 28) image batches into (N, 784) vectors.
        x = x.view(-1, 784)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        # Log-probabilities: pairs with F.nll_loss in train_epoch/evaluate.
        return F.log_softmax(x, dim=1)


def train_epoch(model, device, train_loader, optimizer):
    """Run one optimization pass over train_loader, updating model in place."""
    model.train()
    for data, target in train_loader:
        data = data.to(device)
        target = target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()


def evaluate(model, device, test_loader):
    """Return (mean NLL loss, accuracy) for `model` over test_loader."""
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data = data.to(device)
            target = target.to(device)
            output = model(data)
            # Sum per-sample losses so the final mean is over samples,
            # not over (possibly unevenly sized) batches.
            test_loss += F.nll_loss(output, target, reduction="sum").item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    accuracy = correct / len(test_loader.dataset)
    return test_loss, accuracy


def objective(trial, args, train_loader, test_loader, device):
    """Optuna objective: train one model with trial-suggested hyperparameters.

    Logs the sampled hyperparameters and per-epoch metrics to a nested
    MLflow run, and returns the final-epoch test accuracy (the study
    maximizes this value).
    """
    # Suggest hyperparameters
    lr = trial.suggest_float("lr", 1e-4, 1e-1, log=True)
    hidden_size = trial.suggest_int("hidden_size", 64, 512, step=64)
    dropout_rate = trial.suggest_float("dropout_rate", 0.1, 0.5)
    batch_size = trial.suggest_categorical("batch_size", [32, 64, 128])

    # Recreate data loaders with the trial's batch size (datasets are reused).
    train_loader = DataLoader(train_loader.dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_loader.dataset, batch_size=batch_size, shuffle=False)

    # Start nested MLflow run for this trial (child of the HPO parent run).
    with mlflow.start_run(nested=True, run_name=f"trial_{trial.number}"):
        # Log hyperparameters
        mlflow.log_params({
            "lr": lr,
            "hidden_size": hidden_size,
            "dropout_rate": dropout_rate,
            "batch_size": batch_size,
        })

        # Create model and optimizer
        model = SimpleNet(hidden_size, dropout_rate).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)

        # Training loop
        for epoch in range(args.max_epochs):
            train_epoch(model, device, train_loader, optimizer)
            test_loss, accuracy = evaluate(model, device, test_loader)

            # Log metrics for each epoch
            mlflow.log_metrics({"test_loss": test_loss, "accuracy": accuracy}, step=epoch)

        # Return final accuracy for optimization
        return accuracy


def main():
    """Parse CLI args, load MNIST, and run the MLflow-tracked Optuna study."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--n-trials", type=int, default=10, help="Number of HPO trials")
    parser.add_argument("--max-epochs", type=int, default=5, help="Epochs per trial")
    parser.add_argument("--batch-size", type=int, default=64, help="Initial batch size")
    args = parser.parse_args()

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load MNIST data
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    train_dataset = datasets.MNIST("./data", train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST("./data", train=False, transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)

    # Start parent MLflow run; each trial logs to a nested child run.
    with mlflow.start_run(run_name="HPO_Parent"):
        mlflow.log_params({"n_trials": args.n_trials, "max_epochs": args.max_epochs})

        # Create Optuna study
        study = optuna.create_study(direction="maximize", study_name="mnist_hpo")

        # Run optimization
        study.optimize(
            lambda trial: objective(trial, args, train_loader, test_loader, device),
            n_trials=args.n_trials,
        )

        # Log best results to parent run
        mlflow.log_metrics({
            "best_accuracy": study.best_value,
            "best_trial": study.best_trial.number,
        })
        # Log best hyperparameters with 'best_' prefix to avoid conflicts
        best_params = {f"best_{k}": v for k, v in study.best_params.items()}
        mlflow.log_params(best_params)

        print(f"\nBest trial: {study.best_trial.number}")
        print(f"Best accuracy: {study.best_value:.4f}")
        print(f"Best params: {study.best_params}")


if __name__ == "__main__":
    main()