# examples/pytorch/HPOExample/hpo_mnist.py
  1  """
  2  Hyperparameter Optimization Example with Pure PyTorch and MLflow
  3  
  4  This example demonstrates:
  5  - Using MLflow to track hyperparameter optimization trials
  6  - Parent/child run structure for organizing HPO experiments
  7  - Pure PyTorch training (no Lightning dependencies)
  8  - Simple MNIST classification with configurable hyperparameters
  9  
 10  Run with: python hpo_mnist.py --n-trials 5 --max-epochs 3
 11  """
 12  
 13  import argparse
 14  
 15  import optuna
 16  import torch
 17  import torch.nn.functional as F
 18  from torch import nn
 19  from torch.utils.data import DataLoader
 20  from torchvision import datasets, transforms
 21  
 22  import mlflow
 23  
 24  
 25  class SimpleNet(nn.Module):
 26      def __init__(self, hidden_size, dropout_rate):
 27          super().__init__()
 28          self.fc1 = nn.Linear(784, hidden_size)
 29          self.dropout = nn.Dropout(dropout_rate)
 30          self.fc2 = nn.Linear(hidden_size, 10)
 31  
 32      def forward(self, x):
 33          x = x.view(-1, 784)
 34          x = F.relu(self.fc1(x))
 35          x = self.dropout(x)
 36          x = self.fc2(x)
 37          return F.log_softmax(x, dim=1)
 38  
 39  
 40  def train_epoch(model, device, train_loader, optimizer):
 41      model.train()
 42      for data, target in train_loader:
 43          data = data.to(device)
 44          target = target.to(device)
 45          optimizer.zero_grad()
 46          output = model(data)
 47          loss = F.nll_loss(output, target)
 48          loss.backward()
 49          optimizer.step()
 50  
 51  
 52  def evaluate(model, device, test_loader):
 53      model.eval()
 54      test_loss = 0
 55      correct = 0
 56      with torch.no_grad():
 57          for data, target in test_loader:
 58              data = data.to(device)
 59              target = target.to(device)
 60              output = model(data)
 61              test_loss += F.nll_loss(output, target, reduction="sum").item()
 62              pred = output.argmax(dim=1, keepdim=True)
 63              correct += pred.eq(target.view_as(pred)).sum().item()
 64  
 65      test_loss /= len(test_loader.dataset)
 66      accuracy = correct / len(test_loader.dataset)
 67      return test_loss, accuracy
 68  
 69  
 70  def objective(trial, args, train_loader, test_loader, device):
 71      # Suggest hyperparameters
 72      lr = trial.suggest_float("lr", 1e-4, 1e-1, log=True)
 73      hidden_size = trial.suggest_int("hidden_size", 64, 512, step=64)
 74      dropout_rate = trial.suggest_float("dropout_rate", 0.1, 0.5)
 75      batch_size = trial.suggest_categorical("batch_size", [32, 64, 128])
 76  
 77      # Recreate data loaders with new batch size
 78      train_loader = DataLoader(train_loader.dataset, batch_size=batch_size, shuffle=True)
 79      test_loader = DataLoader(test_loader.dataset, batch_size=batch_size, shuffle=False)
 80  
 81      # Start nested MLflow run for this trial
 82      with mlflow.start_run(nested=True, run_name=f"trial_{trial.number}"):
 83          # Log hyperparameters
 84          mlflow.log_params({
 85              "lr": lr,
 86              "hidden_size": hidden_size,
 87              "dropout_rate": dropout_rate,
 88              "batch_size": batch_size,
 89          })
 90  
 91          # Create model and optimizer
 92          model = SimpleNet(hidden_size, dropout_rate).to(device)
 93          optimizer = torch.optim.Adam(model.parameters(), lr=lr)
 94  
 95          # Training loop
 96          for epoch in range(args.max_epochs):
 97              train_epoch(model, device, train_loader, optimizer)
 98              test_loss, accuracy = evaluate(model, device, test_loader)
 99  
100              # Log metrics for each epoch
101              mlflow.log_metrics({"test_loss": test_loss, "accuracy": accuracy}, step=epoch)
102  
103          # Return final accuracy for optimization
104          return accuracy
105  
106  
def main():
    """Entry point: parse CLI arguments, prepare MNIST loaders, and run the HPO study."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--n-trials", type=int, default=10, help="Number of HPO trials")
    parser.add_argument("--max-epochs", type=int, default=5, help="Epochs per trial")
    parser.add_argument("--batch-size", type=int, default=64, help="Initial batch size")
    args = parser.parse_args()

    # Use a GPU when one is available; otherwise fall back to CPU.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # MNIST pipeline: tensor conversion plus per-channel normalization.
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    train_loader = DataLoader(
        datasets.MNIST("./data", train=True, download=True, transform=transform),
        batch_size=args.batch_size,
        shuffle=True,
    )
    test_loader = DataLoader(
        datasets.MNIST("./data", train=False, transform=transform),
        batch_size=args.batch_size,
        shuffle=False,
    )

    # The parent MLflow run aggregates all trial (child) runs.
    with mlflow.start_run(run_name="HPO_Parent"):
        mlflow.log_params({"n_trials": args.n_trials, "max_epochs": args.max_epochs})

        # Optuna maximizes the accuracy returned by `objective`.
        study = optuna.create_study(direction="maximize", study_name="mnist_hpo")
        study.optimize(
            lambda trial: objective(trial, args, train_loader, test_loader, device),
            n_trials=args.n_trials,
        )

        # Record the winning trial's results on the parent run.
        mlflow.log_metrics(
            {"best_accuracy": study.best_value, "best_trial": study.best_trial.number}
        )
        # 'best_' prefix avoids conflicts with params already logged on this run.
        mlflow.log_params({f"best_{k}": v for k, v in study.best_params.items()})

        print(f"\nBest trial: {study.best_trial.number}")
        print(f"Best accuracy: {study.best_value:.4f}")
        print(f"Best params: {study.best_params}")
155  
# Run the HPO study only when executed as a script (keeps the module importable).
if __name__ == "__main__":
    main()