data.py
1 import os 2 import random 3 import torchvision.transforms as transforms 4 from torch.utils.data import Dataset, DataLoader 5 6 class PowerlineDataset(Dataset): 7 def __init__(self, img_dir, mask_dir, transform=None): 8 self.img_dir = img_dir 9 self.mask_dir = mask_dir 10 self.transform = transform 11 self.palette = np.array([[120,120,120], [0,255,255]]) # background, powerline 12 13 # List of image filenames 14 self.images = [f for f in os.listdir(img_dir) if f.lower().endswith(".jpg")] 15 16 # # Define augmentation pipeline 17 if transform is None: 18 self.transform = A.Compose([ 19 A.HorizontalFlip(p=0.5), 20 A.VerticalFlip(p=0.5), 21 A.Rotate(limit=90, p=0.9), 22 A.RandomScale(scale_limit=0.2, p=0.5), 23 A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=30, p=0.5), 24 A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.7), 25 A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=0.5), 26 A.RandomGamma(gamma_limit=(80, 120), p=0.5), 27 A.GaussNoise(var_limit=(10.0, 50.0), p=0.3), 28 A.Resize(512, 512, interpolation=1), 29 A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 30 A.InvertImg(p=0.4), 31 ToTensorV2(), 32 ]) 33 34 def __len__(self): 35 return len(self.images) 36 37 def _convert_mask(self, mask): 38 mask = np.array(mask) # Convert PIL image to numpy array 39 mask_binary = np.all(np.equal(mask, self.palette[1]), axis=-1).astype(np.float32) # 1 for powerline, 0 for background 40 return mask_binary # No tensor conversion here, let albumentations handle it 41 42 43 def __getitem__(self, idx): 44 img_path = os.path.join(self.img_dir, self.images[idx]) 45 mask_path = os.path.join(self.mask_dir, os.path.splitext(self.images[idx])[0] + ".png") 46 47 # Load image and mask as numpy arrays 48 image = np.array(Image.open(img_path).convert("RGB")) 49 mask = np.array(Image.open(mask_path).convert("RGB")) 50 51 # Convert mask to binary 52 mask_binary = self._convert_mask(mask) 53 54 # Apply augmentations to both image and mask 55 if self.transform: 56 augmented = self.transform(image=image, mask=mask_binary) 57 image = augmented['image'] # Already a tensor 58 mask = augmented['mask'].unsqueeze(0) # Add channel dimension (1, H, W), still numpy -> tensor 59 60 return image, mask 61 62 63 def plot_sample(self, mask, pred): 64 mask = mask.squeeze() 65 pred = pred.squeeze() 66 67 # Move tensors to CPU and convert to numpy for matplotlib 68 mask_np = mask.cpu().numpy() 69 pred_np = pred.cpu().numpy() 70 71 # Create an empty RGB array for the color-coded result 72 height, width = mask_np.shape 73 colored_mask = np.zeros((height, width, 3), dtype=np.uint8) 74 75 # Define the colors 76 WHITE = [255, 255, 255] # True Positive 77 BLACK = [0, 0, 0] # True Negative 78 RED = [255, 0, 0] # False Positive 79 GREY = [128, 128, 128] # False Negative 80 81 colored_mask[(pred_np == 1) & (mask_np == 1)] = WHITE 82 colored_mask[(pred_np == 0) & (mask_np == 0)] = BLACK 83 colored_mask[(pred_np == 1) & (mask_np == 0)] = RED 84 colored_mask[(pred_np == 0) & (mask_np == 1)] = GREY 85 86 # Display the results 87 plt.figure(figsize=(15, 5)) 88 plt.subplot(1, 3, 1); 89 plt.imshow(image.permute(1, 2, 0).cpu()); 90 plt.title("Original Image") 91 92 plt.subplot(1, 3, 2); 93 plt.imshow(mask_np, cmap='gray'); 94 plt.title("Target Mask") 95 96 plt.subplot(1, 3, 3); 97 plt.imshow(colored_mask); 98 plt.title("TP: White | TN: Black | FP: Red | FN: Grey") 99 100 plt.show() 101 102