/ src / data.py
data.py
  1  import os
  2  import random
  3  import torchvision.transforms as transforms
  4  from torch.utils.data import Dataset, DataLoader
  5  
  6  class PowerlineDataset(Dataset):
  7      def __init__(self, img_dir, mask_dir, transform=None):
  8          self.img_dir = img_dir
  9          self.mask_dir = mask_dir
 10          self.transform = transform
 11          self.palette = np.array([[120,120,120], [0,255,255]]) # background, powerline
 12  
 13          # List of image filenames
 14          self.images = [f for f in os.listdir(img_dir) if f.lower().endswith(".jpg")]
 15  
 16          # # Define augmentation pipeline
 17          if transform is None:
 18              self.transform = A.Compose([
 19                  A.HorizontalFlip(p=0.5),
 20                  A.VerticalFlip(p=0.5),
 21                  A.Rotate(limit=90, p=0.9),
 22                  A.RandomScale(scale_limit=0.2, p=0.5),
 23                  A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=30, p=0.5),
 24                  A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.7),
 25                  A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=0.5),
 26                  A.RandomGamma(gamma_limit=(80, 120), p=0.5),
 27                  A.GaussNoise(var_limit=(10.0, 50.0), p=0.3),
 28                  A.Resize(512, 512, interpolation=1),
 29                  A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
 30                  A.InvertImg(p=0.4), 
 31                  ToTensorV2(),
 32              ])
 33  
 34      def __len__(self):
 35          return len(self.images)
 36  
 37      def _convert_mask(self, mask):
 38          mask = np.array(mask)  # Convert PIL image to numpy array
 39          mask_binary = np.all(np.equal(mask, self.palette[1]), axis=-1).astype(np.float32)  # 1 for powerline, 0 for background
 40          return mask_binary  # No tensor conversion here, let albumentations handle it
 41  
 42  
 43      def __getitem__(self, idx):
 44          img_path = os.path.join(self.img_dir, self.images[idx])
 45          mask_path = os.path.join(self.mask_dir, os.path.splitext(self.images[idx])[0] + ".png")
 46  
 47          # Load image and mask as numpy arrays
 48          image = np.array(Image.open(img_path).convert("RGB"))
 49          mask = np.array(Image.open(mask_path).convert("RGB"))
 50  
 51          # Convert mask to binary
 52          mask_binary = self._convert_mask(mask)
 53  
 54          # Apply augmentations to both image and mask
 55          if self.transform:
 56              augmented = self.transform(image=image, mask=mask_binary)
 57              image = augmented['image']  # Already a tensor
 58              mask = augmented['mask'].unsqueeze(0)  # Add channel dimension (1, H, W), still numpy -> tensor
 59  
 60          return image, mask
 61  
 62  
 63      def plot_sample(self, mask, pred):
 64          mask = mask.squeeze()
 65          pred = pred.squeeze()
 66  
 67          # Move tensors to CPU and convert to numpy for matplotlib
 68          mask_np = mask.cpu().numpy()
 69          pred_np = pred.cpu().numpy()
 70  
 71          # Create an empty RGB array for the color-coded result
 72          height, width = mask_np.shape
 73          colored_mask = np.zeros((height, width, 3), dtype=np.uint8)
 74  
 75          # Define the colors
 76          WHITE = [255, 255, 255] # True Positive
 77          BLACK = [0, 0, 0]       # True Negative
 78          RED = [255, 0, 0]       # False Positive
 79          GREY = [128, 128, 128]  # False Negative
 80  
 81          colored_mask[(pred_np == 1) & (mask_np == 1)] = WHITE
 82          colored_mask[(pred_np == 0) & (mask_np == 0)] = BLACK
 83          colored_mask[(pred_np == 1) & (mask_np == 0)] = RED
 84          colored_mask[(pred_np == 0) & (mask_np == 1)] = GREY
 85  
 86          # Display the results
 87          plt.figure(figsize=(15, 5))
 88          plt.subplot(1, 3, 1);
 89          plt.imshow(image.permute(1, 2, 0).cpu());
 90          plt.title("Original Image")
 91  
 92          plt.subplot(1, 3, 2);
 93          plt.imshow(mask_np, cmap='gray');
 94          plt.title("Target Mask")
 95  
 96          plt.subplot(1, 3, 3);
 97          plt.imshow(colored_mask);
 98          plt.title("TP: White | TN: Black | FP: Red | FN: Grey")
 99  
100          plt.show()
101  
102