# src/train.py
  1  import torch
  2  import torch.nn as nn
  3  import torch.optim as optim
  4  from torch.utils.data import Dataset, DataLoader
  5  from torchvision import transforms
  6  import torchvision.transforms as transforms
  7  from PIL import Image
  8  import albumentations as A
  9  from albumentations.pytorch import ToTensorV2
 10  import numpy as np
 11  import matplotlib.pyplot as plt
 12  import cv2
 13  import os
 14  import random
 15  import torch.nn.functional as F
 16  from model import UNet
 17  from data import PowerlineDataset
 18  
 19  # not using google colab
 20  google_drive_base = os.path.abspath(os.pardir)
 21  
 22  # declare global variables
 23  # Use forward slashes for paths in Google Colab
 24  img_root = google_drive_base + '/data/train'
 25  train_imgs = img_root + '/train/train_imgs'
 26  val_imgs = img_root + '/val/val_imgs'
 27  train_gt = img_root + '/train/train_gt'
 28  val_gt = img_root + '/val/val_gt'
 29  test_imgs = img_root + '/test/test_imgs'
 30  test_gt = img_root + '/test/test_gt'
 31  
 32  # print the file count
 33  print(len(os.listdir(train_imgs)), len(os.listdir(val_imgs)),
 34        len(os.listdir(train_gt)), len(os.listdir(val_gt)))
 35  
 36  # Device
 37  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 38  
 39  # Model
 40  model = UNet(in_channels=3, out_channels=1).to(device)
 41  
 42  
 43  
 44  # Load pre-trained weights from best_model.pth for transfer learning
 45  pretrained_path = google_drive_base + '/models/best_finetuned_model.pth'
 46  model.load_state_dict(torch.load(pretrained_path))
 47  print(f"Loaded pre-trained weights from {pretrained_path}")
 48  
 49  model_save_path = google_drive_base + '/models/trained_model_5_9_25.pth'
 50  
 51  # Optional: Freeze some layers (e.g., encoder) if you only want to fine-tune later layers
 52  # For UNet, you might freeze the encoder part (adjust based on your UNet implementation)
 53  # for name, param in model.named_parameters():
 54  #      if "encoder" in name:  # Adjust this condition based on your UNet structure
 55  #          param.requires_grad = False
 56  
 57  
 58  # # Freeze encoder layers (enc1, enc2, enc3, enc4)
 59  # for layer in [model.enc1, model.enc2, model.enc3, model.enc4]:
 60  #     for param in layer.parameters():
 61  #         param.requires_grad = False
 62  
 63  # # Verify which parameters are frozen
 64  # print("Parameters with requires_grad=True (trainable):")
 65  # for name, param in model.named_parameters():
 66  #     if param.requires_grad:
 67  #         print(name)
 68  
 69  
 70  
 71  def dice_loss(pred, target, smooth=1):
 72      pred = torch.sigmoid(pred)
 73      intersection = (pred * target).sum()
 74      return 1 - (2. * intersection + smooth) / (pred.sum() + target.sum() + smooth)
 75  
 76  class HybridLoss(nn.Module):
 77      def __init__(self, pos_weight=20.0):
 78          super().__init__()
 79          self.bce = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight], device='cuda'))
 80          self.dice_weight = 0.3
 81  
 82      def forward(self, pred, target):
 83          bce = self.bce(pred, target)
 84          dice = dice_loss(pred, target)
 85          return (1 - self.dice_weight) * bce + self.dice_weight * dice
 86  
# Hybrid BCE+Dice criterion; pos_weight up-weights the rare foreground pixels.
criterion = HybridLoss(pos_weight=15.5)


# Optimizer and Scheduler
# optimizer = optim.Adam(model.parameters(), lr=0.00005)
# SGD with momentum at a low LR, suited to fine-tuning pretrained weights.
optimizer = optim.SGD(model.parameters(), lr=0.00015, momentum=0.9)
# Shrink the LR by 30% after 20 epochs without improvement of the monitored
# quantity (the training loop passes train_loss to scheduler.step).
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.7, patience=20)

batch_size = 16

# Dataset and Dataloader
# NOTE(review): PowerlineDataset presumably yields (image, mask) tensor
# pairs -- confirm against data.py.
train_dataset = PowerlineDataset(train_imgs, train_gt)
val_dataset = PowerlineDataset(val_imgs, val_gt)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
102  
103  
def compute_metrics(pred, target, threshold=0.1):
    """Binary segmentation metrics from logits.

    Thresholds sigmoid(pred) at `threshold`, then returns
    (IoU, F1, precision, recall) as tensors; a small epsilon keeps every
    ratio defined when a mask is empty.
    """
    eps = 1e-6

    # Binarize prediction and ground truth for set-style bitwise operations.
    pred_mask = torch.sigmoid(pred) > threshold
    gt_mask = target.to(torch.bool)

    # Confusion-matrix counts (true positives, false positives, false negatives).
    tp = (pred_mask & gt_mask).sum()
    fp = (pred_mask & ~gt_mask).sum()
    fn = (~pred_mask & gt_mask).sum()

    # IoU: |intersection| / |union|, where union = tp + fp + fn.
    iou = (tp + eps) / (tp + fp + fn + eps)

    precision = (tp + eps) / (tp + fp + eps)
    recall = (tp + eps) / (tp + fn + eps)
    f1 = 2 * (precision * recall) / (precision + recall + eps)

    return iou, f1, precision, recall
127  
# Training Loop with Gradient Accumulation
# Gradients are accumulated over `accumulation_steps` mini-batches before
# each optimizer step, giving an effective batch size of 4 * 16 = 64.
num_epochs = 150
best_iou = 0.0  # best validation IoU so far; checkpoint criterion
patience = 50  # epochs without val-IoU improvement before early stop
trigger_times = 0  # consecutive epochs without improvement
accumulation_steps = 4  # This is the key: 4 * 16 = 64


for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    optimizer.zero_grad() # Clear gradients at the start of the epoch

    # NOTE(review): assumes the loader yields (image, mask) float tensors
    # with mask values in {0, 1} -- confirm against PowerlineDataset.
    for i, (images, masks) in enumerate(train_dataloader):
        images, masks = images.to(device), masks.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, masks)

        # Normalize loss to account for accumulation
        loss = loss / accumulation_steps

        # Backward pass (gradients add up across the accumulation window)
        loss.backward()

        # Update weights only after accumulating gradients for `accumulation_steps`
        if (i + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        train_loss += loss.item() * accumulation_steps  # Revert normalization for correct loss logging

    # Handle any remaining gradients if the dataset size is not divisible by accumulation_steps
    if (len(train_dataloader) % accumulation_steps != 0):
        optimizer.step()
        optimizer.zero_grad()

    train_loss /= len(train_dataloader)

    # Validation: average IoU/F1 over batches; precision/recall from
    # compute_metrics are discarded via the `_` placeholders.
    model.eval()
    val_iou, val_f1, _, _ = 0.0, 0.0, 0.0, 0.0
    with torch.no_grad():
        for images, masks in val_dataloader:
            images, masks = images.to(device), masks.to(device)
            outputs = model(images)
            iou, f1, _, _ = compute_metrics(outputs, masks, threshold=0.1)
            val_iou += iou.item()
            val_f1 += f1.item()
        val_iou /= len(val_dataloader)
        val_f1 /= len(val_dataloader)

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.4f}, Val IoU: {val_iou:.4f}, Val F1: {val_f1:.4f}")

    # Learning Rate Scheduling -- note the scheduler monitors *training*
    # loss (mode='min'), not the validation metric used for early stopping.
    scheduler.step(train_loss)

    # Early Stopping on validation IoU; best weights are checkpointed.
    if val_iou > best_iou:
        best_iou = val_iou
        trigger_times = 0 # Correct logic: reset patience counter
        torch.save(model.state_dict(), model_save_path)
    else:
        trigger_times += 1
        if trigger_times >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

    # Visualize every 5 epochs (runs on epoch 0 as well)
    if epoch % 5 == 0:
        # Use torch.no_grad() to save memory and computations during inference
        with torch.no_grad():
            # Get one batch of data (always the first validation batch,
            # since shuffle=False on the val loader)
            images, masks = next(iter(val_dataloader))

            # Take the first image and mask from the batch
            image = images[0].to(device)
            mask = masks[0].to(device)

            # Add a batch dimension of size 1 for the single image,
            # e.g. [3, H, W] -> [1, 3, H, W]
            pred = model(image.unsqueeze(0))

        # Binarize the prediction.
        # NOTE(review): this uses a 0.5 threshold while compute_metrics
        # above uses 0.1 -- confirm the mismatch is intentional.
        pred = (torch.sigmoid(pred.squeeze(0)) > 0.5).float()

        # plot out the gt and prediction (presumably side by side -- see data.py)
        train_dataset.plot_sample(mask, pred)
219  
220  
221  
222  
223  # Load best model
224  # model.load_state_dict(torch.load('best_finetuned_model.pth'))