# src/image_segmenter.py
  1  import os
  2  import torch
  3  import torch.nn.functional as F
  4  import numpy as np
  5  import cv2
  6  from PIL import Image
  7  import albumentations as A
  8  from model import UNet
  9  from torchvision import transforms
 10  
 11  class ImageSegmenter:
 12      """A class for running image segmentation inference with tiling."""
 13      
 14      def __init__(self, model_path=None, tile_size=512, overlap=64, target_size=(256, 256), device=None):
 15          """Initialize the segmenter with model and inference parameters."""
 16          self.tile_size = tile_size
 17          self.overlap = overlap
 18          self.target_size = target_size
 19          self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else torch.device(device)
 20          self.model = self._load_model(model_path)
 21          self.transform = None  # Set externally if needed
 22          print(f"Initialized ImageSegmenter on device: {self.device}")
 23  
 24      def _load_model(self, model_path):
 25          """Load a trained PyTorch model from the specified path."""
 26          if model_path is None:
 27              basedir = os.path.abspath(os.pardir)
 28              model_path = os.path.join(basedir, 'models', 'trained_model.pth')
 29          if not os.path.exists(model_path):
 30              raise FileNotFoundError(f"Model file {model_path} not found.")
 31          
 32          # Initialize your model architecture (replace UNet with your actual model)
 33          model = UNet(in_channels=3, out_channels=1)  # Adjust in_channels, out_channels
 34          state_dict = torch.load(model_path, map_location=self.device)
 35          model.load_state_dict(state_dict)  # Load weights into model
 36          model.to(self.device)
 37          model.eval()
 38          return model
 39  
 40      def _detect_bright_sky(self, image_np, threshold=0.3, brightness_threshold=180,
 41                             blue_threshold=150, very_bright_threshold=0.02):
 42          """
 43          Detect if the image contains a bright sky by focusing on the upper half.
 44          Args:
 45              image_np: NumPy array of shape [H, W, 3] (RGB)
 46              threshold: Fraction of upper half pixels that must be bright
 47              brightness_threshold: Minimum intensity for average RGB
 48              blue_threshold: Minimum blue channel value for sky-like regions
 49              very_bright_threshold: Fraction of very bright pixels for sun detection
 50          Returns:
 51              bool: True if bright sky detected
 52          """
 53          height, width, _ = image_np.shape
 54          upper_half = image_np[:height//2, :, :]
 55          avg_intensity = upper_half.mean(axis=2)
 56          blue_intensity = upper_half[:, :, 2]
 57          bright_pixels = np.sum((avg_intensity > brightness_threshold) | (blue_intensity > blue_threshold))
 58          total_pixels = upper_half.shape[0] * upper_half.shape[1]
 59          very_bright_pixels = np.sum(avg_intensity > 240)
 60          has_bright_spot = (very_bright_pixels / total_pixels) > very_bright_threshold        
 61          return ((bright_pixels / total_pixels) > threshold) or has_bright_spot
 62  
 63      def _preprocess_image(self, image_path):
 64          """Preprocess an image for inference, applying CLAHE if bright sky detected and normalization."""
 65          image_np = np.array(Image.open(image_path).convert("RGB"))
 66          is_bright_sky = self._detect_bright_sky(image_np)    
 67          
 68          # Apply CLAHE for bright sky
 69          if is_bright_sky:
 70              clahe_transform = A.Compose([
 71                  A.CLAHE(clip_limit=3.0, tile_grid_size=(8, 8), p=1.0)
 72              ])
 73              image_np = clahe_transform(image=image_np)['image']
 74              # image_np[:] = 255 - image_np
 75              print(f"Bright sky detected in {image_path}, applied CLAHE")
 76  
 77          # Apply main transform if provided (albumentations)
 78          if self.transform:
 79              augmented = self.transform(image=image_np)
 80              image_np = augmented['image']
 81          
 82          # Convert to tensor and normalize
 83          image = torch.from_numpy(image_np.transpose(2, 0, 1)).float() / 255.0
 84          normalize = transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
 85          image = normalize(image)
 86          
 87          return image, image_np.shape, is_bright_sky
 88  
 89      def _predict_single(self, image_path, output_path):
 90          """Run inference on a single image and save the output."""
 91          # Load and preprocess image
 92          image, (H_orig_np, W_orig_np, _), is_bright_sky = self._preprocess_image(image_path)
 93          C, H_orig, W_orig = image.shape
 94          image = image.to(self.device)
 95  
 96          # Handle small images by resizing
 97          if H_orig < self.tile_size or W_orig < self.tile_size:
 98              image = F.interpolate(image.unsqueeze(0), size=(self.tile_size, self.tile_size),
 99                                    mode='bilinear', align_corners=False)
100              H_padded, W_padded = self.tile_size, self.tile_size
101          else:
102              # Pad to ensure full tiles
103              pad_h = self.tile_size - (H_orig % self.tile_size) if H_orig % self.tile_size != 0 else 0
104              pad_w = self.tile_size - (W_orig % self.tile_size) if W_orig % self.tile_size != 0 else 0
105              pad_top = pad_h // 2
106              pad_bottom = pad_h - pad_top
107              pad_left = pad_w // 2
108              pad_right = pad_w - pad_left
109              image = F.pad(image, (pad_left, pad_right, pad_top, pad_bottom), mode='replicate')
110              H_padded, W_padded = image.shape[1], image.shape[2]
111  
112          # Compute stride
113          stride = self.tile_size - self.overlap
114          buffer = None
115          count = torch.zeros((H_padded, W_padded), device=self.device)
116  
117          # Generate start positions
118          start_hs = list(range(0, H_padded - self.tile_size + 1, stride))
119          if start_hs and start_hs[-1] + self.tile_size < H_padded:
120              start_hs.append(H_padded - self.tile_size)
121          start_ws = list(range(0, W_padded - self.tile_size + 1, stride))
122          if start_ws and start_ws[-1] + self.tile_size < W_padded:
123              start_ws.append(W_padded - self.tile_size)
124  
125          # Process tiles
126          with torch.no_grad():
127              if H_padded == self.tile_size and W_padded == self.tile_size:
128                  tile_input = image.unsqueeze(0)
129                  outputs = self.model(tile_input)
130                  prob = torch.sigmoid(outputs)
131                  buffer = prob.squeeze(0)
132                  count[:] = 1
133              else:
134                  for start_h in start_hs:
135                      for start_w in start_ws:
136                          tile = image[:, start_h:start_h + self.tile_size, start_w:start_w + self.tile_size]
137                          tile_input = tile.unsqueeze(0)
138                          outputs = self.model(tile_input)
139                          prob = torch.sigmoid(outputs).squeeze(0)
140                          if buffer is None:
141                              buffer = torch.zeros((prob.shape[0], H_padded, W_padded), device=self.device)
142                          buffer[:, start_h:start_h + self.tile_size, start_w:start_w + self.tile_size] += prob
143                          count[start_h:start_h + self.tile_size, start_w:start_w + self.tile_size] += 1
144  
145          # Average and crop
146          final_prob = buffer / count.unsqueeze(0).clamp(min=1e-6)
147          if H_orig >= self.tile_size and W_orig >= self.tile_size:
148              final_prob = final_prob[:, pad_top:H_orig + pad_top, pad_left:W_orig + pad_left]
149  
150          # Post-process and save
151          output = final_prob.cpu().numpy()
152          threshold = 0.0025 if is_bright_sky else 0.15
153          output = (output > threshold).astype(np.uint8) * 255  # Binary mask
154          # output = (output > 0.01).astype(np.uint8) * 255  # Binary mask
155          output = output.squeeze()  # Adjust based on your model's output
156          os.makedirs(os.path.dirname(output_path), exist_ok=True)
157          cv2.imwrite(output_path, output)
158          print(f"Saved prediction to {output_path}")
159  
160      def predict(self, input_path=None, output_path=None):
161          """Run inference on a single image or batch of images."""
162          if input_path and output_path:
163              # Single image inference
164              self._predict_single(input_path, output_path)
165          else:
166              # Batch inference
167              input_dir = os.path.join(os.path.abspath(os.getcwd()), 'data', 'inference', 'in')
168              output_dir = os.path.join(os.path.abspath(os.getcwd()), 'data', 'inference', 'out')
169              if not os.path.exists(input_dir):
170                  raise FileNotFoundError(f"Input directory {input_dir} does not exist.")
171              for filename in os.listdir(input_dir):
172                  if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
173                      input_path = os.path.join(input_dir, filename)
174                      output_path = os.path.join(output_dir, os.path.splitext(filename)[0] + '_mask.png')
175                      self._predict_single(input_path, output_path)