# train.py
"""Fine-tune a UNet for powerline segmentation.

Uses a hybrid BCE+Dice loss (pos_weight to counter class imbalance of thin
powerlines), gradient accumulation to emulate a larger batch, ReduceLROnPlateau
scheduling, and IoU-based checkpointing with early stopping.
"""

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2
import numpy as np
import matplotlib.pyplot as plt
import cv2
import os
import random
import torch.nn.functional as F
from model import UNet
from data import PowerlineDataset

# Not using Google Colab: data/models live one directory above this script.
google_drive_base = os.path.abspath(os.pardir)

# Dataset directory layout: images and ground-truth masks for each split.
img_root = os.path.join(google_drive_base, 'data', 'train')
train_imgs = os.path.join(img_root, 'train', 'train_imgs')
val_imgs = os.path.join(img_root, 'val', 'val_imgs')
train_gt = os.path.join(img_root, 'train', 'train_gt')
val_gt = os.path.join(img_root, 'val', 'val_gt')
test_imgs = os.path.join(img_root, 'test', 'test_imgs')
test_gt = os.path.join(img_root, 'test', 'test_gt')

# Sanity check: print the file count of each split before training.
print(len(os.listdir(train_imgs)), len(os.listdir(val_imgs)),
      len(os.listdir(train_gt)), len(os.listdir(val_gt)))

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model: binary segmentation (1 output channel) from RGB input.
model = UNet(in_channels=3, out_channels=1).to(device)

# Load pre-trained weights from best_finetuned_model.pth for transfer learning.
# map_location=device is required so a GPU-saved checkpoint loads on
# CPU-only machines (previously torch.load() would raise in that case).
pretrained_path = os.path.join(google_drive_base, 'models', 'best_finetuned_model.pth')
model.load_state_dict(torch.load(pretrained_path, map_location=device))
print(f"Loaded pre-trained weights from {pretrained_path}")

model_save_path = os.path.join(google_drive_base, 'models', 'trained_model_5_9_25.pth')

# Optional: freeze encoder layers (e.g. model.enc1..enc4) to fine-tune only
# the decoder -- set param.requires_grad = False on the encoder parameters
# and verify with model.named_parameters() before training.


def dice_loss(pred, target, smooth=1):
    """Soft Dice loss computed from logits.

    Args:
        pred: raw logits from the model (any shape).
        target: ground-truth mask with values in [0, 1], same shape as pred.
        smooth: additive smoothing that avoids division by zero and
            stabilises gradients for near-empty masks.

    Returns:
        Scalar tensor in [0, 1]; 0 means perfect overlap.
    """
    pred = torch.sigmoid(pred)
    intersection = (pred * target).sum()
    return 1 - (2. * intersection + smooth) / (pred.sum() + target.sum() + smooth)


class HybridLoss(nn.Module):
    """Weighted combination of BCE-with-logits and soft Dice loss.

    pos_weight up-weights the (rare) positive powerline pixels in the BCE
    term; dice_weight controls the BCE/Dice mix.
    """

    def __init__(self, pos_weight=20.0, dice_weight=0.3):
        super().__init__()
        # Registered as a buffer so it moves with .to(device) and is
        # resolved against the input's device in forward() -- the previous
        # version hard-coded device='cuda', which crashed on CPU-only runs.
        self.register_buffer('pos_weight', torch.tensor([float(pos_weight)]))
        self.dice_weight = dice_weight

    def forward(self, pred, target):
        bce = F.binary_cross_entropy_with_logits(
            pred, target, pos_weight=self.pos_weight.to(pred.device))
        dice = dice_loss(pred, target)
        return (1 - self.dice_weight) * bce + self.dice_weight * dice


criterion = HybridLoss(pos_weight=15.5)

# Optimizer and scheduler.  SGD+momentum was chosen over Adam (kept as a
# comment for reference); the scheduler decays LR when train loss plateaus.
# optimizer = optim.Adam(model.parameters(), lr=0.00005)
optimizer = optim.SGD(model.parameters(), lr=0.00015, momentum=0.9)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.7, patience=20)

batch_size = 16

# Dataset and dataloaders.
train_dataset = PowerlineDataset(train_imgs, train_gt)
val_dataset = PowerlineDataset(val_imgs, val_gt)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)


def compute_metrics(pred, target, threshold=0.1):
    """Compute IoU, F1, precision and recall for binary segmentation.

    Args:
        pred: raw logits; thresholded after a sigmoid.
        target: ground-truth mask (any numeric/bool dtype; cast to bool).
        threshold: probability cutoff for a positive prediction.  A low
            default (0.1) favours recall on thin structures.

    Returns:
        Tuple of scalar tensors (iou, f1, precision, recall).
    """
    # Binarize both tensors so bitwise set operations are valid.
    pred = (torch.sigmoid(pred) > threshold).to(torch.bool)
    target = target.to(torch.bool)

    intersection = (pred & target).sum()   # true positives
    union = (pred | target).sum()

    # Small epsilon keeps the ratios defined when both masks are empty.
    iou = (intersection + 1e-6) / (union + 1e-6)

    tp = intersection
    fp = (pred & ~target).sum()
    fn = (~pred & target).sum()

    precision = (tp + 1e-6) / (tp + fp + 1e-6)
    recall = (tp + 1e-6) / (tp + fn + 1e-6)
    f1 = 2 * (precision * recall) / (precision + recall + 1e-6)

    return iou, f1, precision, recall


# Training loop with gradient accumulation (4 steps x batch 16 = effective 64).
num_epochs = 150
best_iou = 0.0
patience = 50          # epochs without val-IoU improvement before stopping
trigger_times = 0
accumulation_steps = 4

for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    optimizer.zero_grad()  # clear gradients at the start of the epoch

    for i, (images, masks) in enumerate(train_dataloader):
        images, masks = images.to(device), masks.to(device)

        outputs = model(images)
        loss = criterion(outputs, masks)

        # Normalize so the accumulated gradient matches a single large batch.
        loss = loss / accumulation_steps
        loss.backward()

        # Step only after `accumulation_steps` micro-batches.
        if (i + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        # Undo the normalization for honest loss logging.
        train_loss += loss.item() * accumulation_steps

    # Flush gradients left over when the batch count isn't divisible
    # by accumulation_steps.
    if len(train_dataloader) % accumulation_steps != 0:
        optimizer.step()
        optimizer.zero_grad()

    train_loss /= len(train_dataloader)

    # Validation: average IoU/F1 over all val batches.
    model.eval()
    val_iou = 0.0
    val_f1 = 0.0
    with torch.no_grad():
        for images, masks in val_dataloader:
            images, masks = images.to(device), masks.to(device)
            outputs = model(images)
            iou, f1, _, _ = compute_metrics(outputs, masks, threshold=0.1)
            val_iou += iou.item()
            val_f1 += f1.item()
    val_iou /= len(val_dataloader)
    val_f1 /= len(val_dataloader)

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.4f}, Val IoU: {val_iou:.4f}, Val F1: {val_f1:.4f}")

    # LR scheduling keyed on training loss.
    scheduler.step(train_loss)

    # Checkpoint on best val IoU; early-stop after `patience` stale epochs.
    if val_iou > best_iou:
        best_iou = val_iou
        trigger_times = 0  # reset patience counter on improvement
        torch.save(model.state_dict(), model_save_path)
    else:
        trigger_times += 1
        if trigger_times >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

    # Visualize a sample prediction every 5 epochs.
    if epoch % 5 == 0:
        with torch.no_grad():
            # Take the first image/mask of the first val batch.
            images, masks = next(iter(val_dataloader))
            image = images[0].to(device)
            mask = masks[0].to(device)

            # unsqueeze(0) adds the batch dimension the model expects
            # (e.g. [3, H, W] -> [1, 3, H, W]).
            pred = model(image.unsqueeze(0))

            # NOTE(review): visualization uses a 0.5 threshold while metrics
            # use 0.1 -- kept as-is, confirm this is intentional.
            pred = (torch.sigmoid(pred.squeeze(0)) > 0.5).float()

            # Plot ground truth next to the prediction.
            train_dataset.plot_sample(mask, pred)


# Load best model
# model.load_state_dict(torch.load('best_finetuned_model.pth'))