data_utils.py
1 #coding=utf-8 2 import sys 3 sys.path.append('/home/dvalsamis/Documents/projects/GAS-Net-main') 4 from os.path import join 5 import torch 6 from PIL import Image 7 from torch.utils.data.dataset import Dataset 8 import numpy as np 9 import torchvision.transforms as transforms 10 import os 11 12 def is_image_file(filename): 13 return any(filename.endswith(extension) for extension in ['.png', '.jpg', '.jpeg', '.PNG', '.JPG', '.JPEG']) 14 15 def calMetric_iou(predict, label): 16 # print(np.logical_and(predict == 1, label == 1)) 17 # print(np.logical_and(predict == 1, label == 1).shape) 18 tp = np.sum(np.logical_and(predict == 1, label == 1)) 19 # print(tp) 20 fp = np.sum(predict==1) 21 fn = np.sum(label == 1) 22 return tp,fp+fn-tp 23 24 def nearest_interpolation(lr_feature_high, fake_image): 25 scale = fake_image.size(2)//lr_feature_high.size(2) 26 batch_size = lr_feature_high.size(0) 27 channels = lr_feature_high.size(1) 28 29 tmp_feature = fake_image 30 for m in range(batch_size): 31 for n in range(channels): 32 new_lr_feature_high = lr_feature_high[m][n].unsqueeze(0).unsqueeze(0) 33 a = torch.nn.functional.interpolate(new_lr_feature_high, scale_factor = scale, mode='nearest', align_corners=None) 34 tmp_feature[m][n] = a.squeeze() 35 return tmp_feature 36 37 38 def make_one_hot(input, num_classes): 39 """Convert class index tensor to one hot encoding tensor. 40 Args: 41 input: A tensor of shape [N, 1, *] 42 num_classes: An int of number of class 43 Returns: 44 A tensor of shape [N, num_classes, *] 45 """ 46 shape = np.array(input.shape) 47 shape[1] = num_classes 48 shape = tuple(shape) 49 result = torch.zeros(shape) 50 result = result.scatter_(1, input.cpu(), 1) 51 return result 52 53 54 def get_transform(convert=True, normalize=False): 55 transform_list = [] 56 if convert: 57 transform_list += [transforms.ToTensor()] 58 if normalize: 59 transform_list += [transforms.Normalize((0.5, 0.5, 0.5), 60 (0.5, 0.5, 0.5))] 61 return transforms.Compose(transform_list) 62 63 64 class LoadDatasetFromFolder(Dataset): 65 def __init__(self, args, hr1_path, lr2_path, hr2_path, lab_path): 66 super(LoadDatasetFromFolder, self).__init__() 67 datalist = [name for name in os.listdir(hr1_path) for item in args.suffix if 68 os.path.splitext(name)[1] == item] 69 70 self.hr1_filenames = [join(hr1_path, x) for x in datalist if is_image_file(x)] 71 self.lr2_filenames = [join(lr2_path, x) for x in datalist if is_image_file(x)] 72 self.hr2_filenames = [join(hr2_path, x) for x in datalist if is_image_file(x)] 73 self.lab_filenames = [join(lab_path, x) for x in datalist if is_image_file(x)] 74 75 self.transform = get_transform(convert=True, normalize= False) 76 self.label_transform = get_transform() 77 78 def __getitem__(self, index): 79 hr1_img = self.transform(Image.open(self.hr1_filenames[index]).convert('RGB')) 80 lr2_img = self.transform(Image.open(self.lr2_filenames[index]).convert('RGB')) 81 hr2_img = self.transform(Image.open(self.hr2_filenames[index]).convert('RGB')) 82 83 label = self.label_transform(Image.open(self.lab_filenames[index])) 84 label = make_one_hot(label.unsqueeze(0).long(), 2).squeeze(0) 85 86 return hr1_img, lr2_img, hr2_img, label 87 88 def __len__(self): 89 return len(self.hr1_filenames) 90 91 class LoadDatasetFromFolder_CD(Dataset): 92 def __init__(self, args, hr1_path, hr2_path, lab_path): 93 super(LoadDatasetFromFolder_CD, self).__init__() 94 datalist = [name for name in os.listdir(hr1_path) for item in args.suffix if 95 os.path.splitext(name)[1] == item] 96 97 self.hr1_filenames = [join(hr1_path, x) for x in datalist if is_image_file(x)] 98 self.hr2_filenames = [join(hr2_path, x) for x in datalist if is_image_file(x)] 99 self.lab_filenames = [join(lab_path, x) for x in datalist if is_image_file(x)] 100 101 self.transform = get_transform(convert=True, normalize= True) 102 self.label_transform = get_transform() 103 104 def __getitem__(self, index): 105 hr1_img = self.transform(Image.open(self.hr1_filenames[index]).convert('RGB')) 106 hr2_img = self.transform(Image.open(self.hr2_filenames[index]).convert('RGB')) 107 108 label = self.label_transform(Image.open(self.lab_filenames[index])) 109 label = make_one_hot(label.unsqueeze(0).long(), 2).squeeze(0) 110 111 return hr1_img, hr2_img, label 112 113 def __len__(self): 114 return len(self.hr1_filenames)