# image_segmenter.py
import os

import cv2
import numpy as np
import torch
import torch.nn.functional as F
import albumentations as A
from PIL import Image
from torchvision import transforms

from model import UNet


class ImageSegmenter:
    """Tiled segmentation inference for a trained PyTorch model.

    Images larger than ``tile_size`` are processed as a grid of overlapping
    tiles whose sigmoid probabilities are averaged in the overlap regions;
    smaller images are upscaled to one tile, predicted, and resized back.
    The averaged probability map is thresholded into a binary PNG mask.
    """

    def __init__(self, model_path=None, tile_size=512, overlap=64,
                 target_size=(256, 256), device=None):
        """Initialize the segmenter with model and inference parameters.

        Args:
            model_path: Path to trained weights; defaults to
                ``<parent-dir>/models/trained_model.pth``.
            tile_size: Side length in pixels of the square inference tiles.
            overlap: Overlap in pixels between adjacent tiles.
            target_size: Stored for external callers; not used internally.
            device: Explicit torch device; auto-selects CUDA when ``None``.
        """
        self.tile_size = tile_size
        self.overlap = overlap
        self.target_size = target_size
        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = torch.device(device)
        self.model = self._load_model(model_path)
        self.transform = None  # Optional albumentations pipeline, set externally if needed.
        print(f"Initialized ImageSegmenter on device: {self.device}")

    def _load_model(self, model_path):
        """Load trained weights into a UNet and move it to the target device.

        Args:
            model_path: Weights file; when ``None``, falls back to
                ``../models/trained_model.pth`` relative to the CWD.

        Returns:
            The model in eval mode on ``self.device``.

        Raises:
            FileNotFoundError: If the weights file does not exist.
        """
        if model_path is None:
            basedir = os.path.abspath(os.pardir)
            model_path = os.path.join(basedir, 'models', 'trained_model.pth')
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file {model_path} not found.")

        # Architecture must match the saved state dict (adjust channels if needed).
        model = UNet(in_channels=3, out_channels=1)
        # NOTE(review): torch.load unpickles arbitrary objects; only load trusted
        # checkpoints (or pass weights_only=True on torch >= 2.0).
        state_dict = torch.load(model_path, map_location=self.device)
        model.load_state_dict(state_dict)
        model.to(self.device)
        model.eval()
        return model

    def _detect_bright_sky(self, image_np, threshold=0.3, brightness_threshold=180,
                           blue_threshold=150, very_bright_threshold=0.02):
        """Detect a bright sky by inspecting the upper half of the image.

        Args:
            image_np: NumPy array of shape [H, W, 3] (RGB).
            threshold: Fraction of upper-half pixels that must be bright.
            brightness_threshold: Minimum mean-RGB intensity to count as bright.
            blue_threshold: Minimum blue-channel value for sky-like pixels.
            very_bright_threshold: Fraction of near-saturated pixels that
                indicates a sun-like bright spot.

        Returns:
            bool: True if a bright sky (or bright spot) is detected.
        """
        height, width, _ = image_np.shape
        upper_half = image_np[:height // 2, :, :]
        total_pixels = upper_half.shape[0] * upper_half.shape[1]
        if total_pixels == 0:
            # Degenerate input (height < 2): nothing to inspect, avoid
            # division by zero below.
            return False
        avg_intensity = upper_half.mean(axis=2)
        blue_intensity = upper_half[:, :, 2]
        bright_pixels = np.sum((avg_intensity > brightness_threshold) |
                               (blue_intensity > blue_threshold))
        very_bright_pixels = np.sum(avg_intensity > 240)
        has_bright_spot = (very_bright_pixels / total_pixels) > very_bright_threshold
        return ((bright_pixels / total_pixels) > threshold) or has_bright_spot

    def _preprocess_image(self, image_path):
        """Load and preprocess an image for inference.

        Applies CLAHE when a bright sky is detected, then the optional
        external albumentations transform, then ImageNet normalization.

        Args:
            image_path: Path to the input image.

        Returns:
            Tuple of (normalized CHW float tensor, HWC numpy shape after
            transforms, bright-sky flag).
        """
        image_np = np.array(Image.open(image_path).convert("RGB"))
        is_bright_sky = self._detect_bright_sky(image_np)

        # CLAHE boosts local contrast so washed-out skies don't dominate.
        if is_bright_sky:
            clahe_transform = A.Compose([
                A.CLAHE(clip_limit=3.0, tile_grid_size=(8, 8), p=1.0)
            ])
            image_np = clahe_transform(image=image_np)['image']
            print(f"Bright sky detected in {image_path}, applied CLAHE")

        # Optional externally supplied albumentations pipeline.
        if self.transform:
            augmented = self.transform(image=image_np)
            image_np = augmented['image']

        # HWC uint8 -> CHW float in [0, 1], then ImageNet normalization.
        image = torch.from_numpy(image_np.transpose(2, 0, 1)).float() / 255.0
        normalize = transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                         std=(0.229, 0.224, 0.225))
        image = normalize(image)

        return image, image_np.shape, is_bright_sky

    def _predict_single(self, image_path, output_path):
        """Run tiled inference on a single image and save a binary mask.

        Args:
            image_path: Path to the input image.
            output_path: Path of the PNG mask to write; parent directories
                are created as needed.
        """
        # Load and preprocess image.
        image, (H_orig_np, W_orig_np, _), is_bright_sky = self._preprocess_image(image_path)
        C, H_orig, W_orig = image.shape
        image = image.to(self.device)

        small_image = H_orig < self.tile_size or W_orig < self.tile_size
        if small_image:
            # Upscale small images to exactly one tile; drop the batch dim
            # again so the tensor stays CHW for the tiling code below.
            # (BUGFIX: the batch dim was previously kept, so the single-tile
            # path fed a 5-D [1,1,C,H,W] tensor to the model.)
            image = F.interpolate(image.unsqueeze(0),
                                  size=(self.tile_size, self.tile_size),
                                  mode='bilinear', align_corners=False).squeeze(0)
            H_padded, W_padded = self.tile_size, self.tile_size
            pad_top = pad_left = 0
        else:
            # Replicate-pad so the padded size is a multiple of tile_size.
            pad_h = (self.tile_size - H_orig % self.tile_size) % self.tile_size
            pad_w = (self.tile_size - W_orig % self.tile_size) % self.tile_size
            pad_top = pad_h // 2
            pad_bottom = pad_h - pad_top
            pad_left = pad_w // 2
            pad_right = pad_w - pad_left
            image = F.pad(image, (pad_left, pad_right, pad_top, pad_bottom), mode='replicate')
            H_padded, W_padded = image.shape[1], image.shape[2]

        # Tiles advance by stride; overlapping regions are averaged later.
        stride = self.tile_size - self.overlap
        buffer = None
        count = torch.zeros((H_padded, W_padded), device=self.device)

        # Start positions, with an extra flush-to-edge tile when the regular
        # grid would otherwise leave a strip at the bottom/right uncovered.
        start_hs = list(range(0, H_padded - self.tile_size + 1, stride))
        if start_hs and start_hs[-1] + self.tile_size < H_padded:
            start_hs.append(H_padded - self.tile_size)
        start_ws = list(range(0, W_padded - self.tile_size + 1, stride))
        if start_ws and start_ws[-1] + self.tile_size < W_padded:
            start_ws.append(W_padded - self.tile_size)

        # Process tiles.
        with torch.no_grad():
            if H_padded == self.tile_size and W_padded == self.tile_size:
                # Single-tile fast path.
                outputs = self.model(image.unsqueeze(0))
                buffer = torch.sigmoid(outputs).squeeze(0)
                count[:] = 1
            else:
                for start_h in start_hs:
                    for start_w in start_ws:
                        tile = image[:, start_h:start_h + self.tile_size,
                                     start_w:start_w + self.tile_size]
                        outputs = self.model(tile.unsqueeze(0))
                        prob = torch.sigmoid(outputs).squeeze(0)
                        if buffer is None:
                            # Channel count is taken from the model output.
                            buffer = torch.zeros((prob.shape[0], H_padded, W_padded),
                                                 device=self.device)
                        buffer[:, start_h:start_h + self.tile_size,
                               start_w:start_w + self.tile_size] += prob
                        count[start_h:start_h + self.tile_size,
                              start_w:start_w + self.tile_size] += 1

        # Average overlaps; clamp guards against any zero-count pixels.
        final_prob = buffer / count.unsqueeze(0).clamp(min=1e-6)
        if small_image:
            # BUGFIX: resize probabilities back so the saved mask matches the
            # input resolution (previously a tile_size mask was written).
            final_prob = F.interpolate(final_prob.unsqueeze(0),
                                       size=(H_orig, W_orig),
                                       mode='bilinear', align_corners=False).squeeze(0)
        else:
            # Crop away the padding added above.
            final_prob = final_prob[:, pad_top:H_orig + pad_top, pad_left:W_orig + pad_left]

        # Threshold into a binary mask; bright-sky images use a much lower
        # threshold (both values presumably tuned empirically — TODO confirm).
        output = final_prob.cpu().numpy()
        threshold = 0.0025 if is_bright_sky else 0.15
        output = (output > threshold).astype(np.uint8) * 255  # Binary mask
        output = output.squeeze()  # Drop the channel dim for single-class output
        out_dir = os.path.dirname(output_path)
        if out_dir:
            # Guard: makedirs('') raises when output_path has no directory part.
            os.makedirs(out_dir, exist_ok=True)
        cv2.imwrite(output_path, output)
        print(f"Saved prediction to {output_path}")

    def predict(self, input_path=None, output_path=None):
        """Run inference on a single image or on the default batch directory.

        Args:
            input_path: Single input image. When either path is omitted,
                every PNG/JPEG in ``./data/inference/in`` is processed and
                masks are written to ``./data/inference/out`` instead.
            output_path: Where to write the single-image mask.

        Raises:
            FileNotFoundError: If the batch input directory is missing.
        """
        if input_path and output_path:
            # Single image inference.
            self._predict_single(input_path, output_path)
        else:
            # Batch inference over the conventional data directories.
            input_dir = os.path.join(os.path.abspath(os.getcwd()), 'data', 'inference', 'in')
            output_dir = os.path.join(os.path.abspath(os.getcwd()), 'data', 'inference', 'out')
            if not os.path.exists(input_dir):
                raise FileNotFoundError(f"Input directory {input_dir} does not exist.")
            for filename in os.listdir(input_dir):
                if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
                    in_path = os.path.join(input_dir, filename)
                    out_path = os.path.join(output_dir, os.path.splitext(filename)[0] + '_mask.png')
                    self._predict_single(in_path, out_path)