/ gasnet / data_utils.py
data_utils.py
  1  #coding=utf-8
  2  import sys
  3  sys.path.append('/home/dvalsamis/Documents/projects/GAS-Net-main')
  4  from os.path import join
  5  import torch
  6  from PIL import Image
  7  from torch.utils.data.dataset import Dataset
  8  import numpy as np
  9  import torchvision.transforms as transforms
 10  import os
 11  
 12  def is_image_file(filename):
 13      return any(filename.endswith(extension) for extension in ['.png', '.jpg', '.jpeg', '.PNG', '.JPG', '.JPEG'])
 14  
 15  def calMetric_iou(predict, label):
 16      # print(np.logical_and(predict == 1, label == 1))
 17      # print(np.logical_and(predict == 1, label == 1).shape)
 18      tp = np.sum(np.logical_and(predict == 1, label == 1))
 19      # print(tp)
 20      fp = np.sum(predict==1)
 21      fn = np.sum(label == 1)
 22      return tp,fp+fn-tp
 23  
 24  def nearest_interpolation(lr_feature_high, fake_image):
 25      scale = fake_image.size(2)//lr_feature_high.size(2)
 26      batch_size = lr_feature_high.size(0)
 27      channels = lr_feature_high.size(1)
 28  
 29      tmp_feature = fake_image
 30      for m in range(batch_size):
 31          for n in range(channels):
 32              new_lr_feature_high = lr_feature_high[m][n].unsqueeze(0).unsqueeze(0)
 33              a = torch.nn.functional.interpolate(new_lr_feature_high, scale_factor = scale,  mode='nearest', align_corners=None)
 34              tmp_feature[m][n] = a.squeeze()
 35      return tmp_feature
 36  
 37  
 38  def make_one_hot(input, num_classes):
 39      """Convert class index tensor to one hot encoding tensor.
 40      Args:
 41           input: A tensor of shape [N, 1, *]
 42           num_classes: An int of number of class
 43      Returns:
 44          A tensor of shape [N, num_classes, *]
 45      """
 46      shape = np.array(input.shape)
 47      shape[1] = num_classes
 48      shape = tuple(shape)
 49      result = torch.zeros(shape)
 50      result = result.scatter_(1, input.cpu(), 1)
 51      return result
 52  
 53  
 54  def get_transform(convert=True, normalize=False):
 55      transform_list = []
 56      if convert:
 57          transform_list += [transforms.ToTensor()]
 58      if normalize:
 59          transform_list += [transforms.Normalize((0.5, 0.5, 0.5),
 60                                                  (0.5, 0.5, 0.5))]
 61      return transforms.Compose(transform_list)
 62  
 63  
 64  class LoadDatasetFromFolder(Dataset):
 65      def __init__(self, args, hr1_path, lr2_path, hr2_path, lab_path):
 66          super(LoadDatasetFromFolder, self).__init__()
 67          datalist = [name for name in os.listdir(hr1_path) for item in args.suffix if
 68                        os.path.splitext(name)[1] == item]
 69  
 70          self.hr1_filenames = [join(hr1_path, x) for x in datalist if is_image_file(x)]
 71          self.lr2_filenames = [join(lr2_path, x) for x in datalist if is_image_file(x)]
 72          self.hr2_filenames = [join(hr2_path, x) for x in datalist if is_image_file(x)]
 73          self.lab_filenames = [join(lab_path, x) for x in datalist if is_image_file(x)]
 74  
 75          self.transform = get_transform(convert=True, normalize= False)
 76          self.label_transform = get_transform()
 77  
 78      def __getitem__(self, index):
 79          hr1_img = self.transform(Image.open(self.hr1_filenames[index]).convert('RGB'))
 80          lr2_img = self.transform(Image.open(self.lr2_filenames[index]).convert('RGB'))
 81          hr2_img = self.transform(Image.open(self.hr2_filenames[index]).convert('RGB'))
 82  
 83          label = self.label_transform(Image.open(self.lab_filenames[index]))
 84          label = make_one_hot(label.unsqueeze(0).long(), 2).squeeze(0)
 85  
 86          return hr1_img, lr2_img, hr2_img, label
 87  
 88      def __len__(self):
 89          return len(self.hr1_filenames)
 90  
 91  class LoadDatasetFromFolder_CD(Dataset):
 92      def __init__(self, args, hr1_path, hr2_path, lab_path):
 93          super(LoadDatasetFromFolder_CD, self).__init__()
 94          datalist = [name for name in os.listdir(hr1_path) for item in args.suffix if
 95                        os.path.splitext(name)[1] == item]
 96  
 97          self.hr1_filenames = [join(hr1_path, x) for x in datalist if is_image_file(x)]
 98          self.hr2_filenames = [join(hr2_path, x) for x in datalist if is_image_file(x)]
 99          self.lab_filenames = [join(lab_path, x) for x in datalist if is_image_file(x)]
100  
101          self.transform = get_transform(convert=True, normalize= True)
102          self.label_transform = get_transform()
103  
104      def __getitem__(self, index):
105          hr1_img = self.transform(Image.open(self.hr1_filenames[index]).convert('RGB'))
106          hr2_img = self.transform(Image.open(self.hr2_filenames[index]).convert('RGB'))
107  
108          label = self.label_transform(Image.open(self.lab_filenames[index]))
109          label = make_one_hot(label.unsqueeze(0).long(), 2).squeeze(0)
110  
111          return hr1_img, hr2_img, label
112  
113      def __len__(self):
114          return len(self.hr1_filenames)