# File: changedetection/change_detection_dm.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sun Jun 11 12:02:38 2023

@author: aleoikon

Batch change detection on pairs of RGB rasters.

For every dataset name listed on the first line of ``DATASET_NAMES_FILE`` the
script loads two time frames, runs the convolutional change-detection
classifier, derives two thresholded change maps (Otsu and triangle) from a
feature-space distance map, fuses the three maps, and writes each map as a
PNG figure and as a ``.npy`` array.
"""
import glob
import os
import sys
import traceback

# Make the project packages (architectures, changedetection) importable.
sys.path.append('/home/dvalsamis/Documents/projects/Change_detection_SSL_Siamese')

# CUDA_VISIBLE_DEVICES is only honoured if it is set before CUDA is
# initialised, so it is set before the project architectures are imported.
# NOTE(review): the original comment said this "hides the gpus for timing the
# application", but '1' exposes CUDA device 1 only; use '-1' to hide all GPUs.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

import matplotlib.pyplot as plt
import numpy as np
from numpy import expand_dims
from skimage import io
from skimage.filters import threshold_otsu, threshold_triangle
from skimage.morphology import remove_small_objects

from architectures.branch import branch_cva, branch_cva_with_nspp, two_branch_cva_with_aspp
from architectures.conv_classifier import conv_classifier_two, conv_classifier_two_with_nspp, conv_classifier_two_with_aspp
from changedetection.utils.fusion_maria import fusion
from changedetection.utils.layer_select import feature_selector_cva, feature_selector_cva_with_nspp, two_feature_selector_cva_aspp
from changedetection.utils.visualize import create_rgb

# Input / output locations.
DATA_ROOT = '/home/dvalsamis/Documents/data/CBMI/CBMI_0.3/CBMI 0.3_initial'
OUTPUT_ROOT = '/home/dvalsamis/Documents/data/CBMI/CBMI_0.3/Predictions/Depth_2/Onera_Tr'
SAVED_MODEL = '/home/dvalsamis/Documents/projects/saved_models/depth:2_rgb_0.1_0.0001_0.001_50_convclastwo[0.2 2.25].h5'
DATASET_NAMES_FILE = os.path.join(DATA_ROOT, 'all.txt')

# Model hyper-parameters; they must match the ones used to train SAVED_MODEL.
DEPTH = 2
DROPOUT = 0.1
DECAY = 0.0001

# Number of leading branch feature maps used for the distance map.
N_FEATURE_MAPS = 32
# Connected components smaller than this (in pixels) are removed from the
# thresholded masks.
MIN_OBJECT_SIZE = 55
# Parameter passed through to ``fusion``.
FUSION_PARAM = 3.25

# Band-file suffixes stacked as the red, green and blue channels, in that order.
RGB_BANDS = ('B04', 'B03', 'B02')


def calculate_distancemap(f1, f2):
    """
    Calculate the pixelwise Euclidean distance between two multi-channel images.

    Parameters
    ----------
    f1 : np.ndarray of shape (D, N, M)
        image 1 with the channels in the FIRST dimension
    f2 : np.ndarray of shape (D, N, M)
        image 2, same shape as ``f1``

    Returns
    -------
    np.ndarray of shape (N, M)
        pixelwise Euclidean distance between image 1 and image 2
    """
    diff = np.asarray(f2) - np.asarray(f1)
    # Sum the squared differences over the channel axis, then take the root.
    return np.sqrt(np.sum(diff ** 2, axis=0))


def _read_band(path, band):
    """Read the first (sorted) raster in *path* whose name ends with ``{band}.tif``.

    Raises
    ------
    FileNotFoundError
        if no matching file exists in *path*.
    """
    # glob order is arbitrary; sort so the chosen file is deterministic.
    matches = sorted(glob.glob(os.path.join(path, f'*{band}.tif')))
    print(matches)
    if not matches:
        raise FileNotFoundError(f"No '*{band}.tif' raster found in {path}")
    return io.imread(matches[0])


def read_rasters(path):
    """
    Load the B04, B03 and B02 band rasters in *path* as one 3-channel image.

    Parameters
    ----------
    path : str
        directory containing files matching ``*B04.tif``, ``*B03.tif`` and
        ``*B02.tif``.

    Returns
    -------
    np.ndarray
        the three bands stacked along axis 2 in the order B04, B03, B02
        (used as r, g, b), cast to float64.

    Raises
    ------
    FileNotFoundError
        if one of the band files is missing.
    """
    bands = [_read_band(path, band) for band in RGB_BANDS]
    return np.stack(bands, axis=2).astype('float')


def _standardize(img):
    """Return *img* shifted to zero mean and scaled to unit standard deviation (global statistics)."""
    return (img - img.mean()) / img.std()


def _leading_feature_maps(feature_maps, n_maps=N_FEATURE_MAPS):
    """Return channels ``0..n_maps-1`` of sample 0 of a channels-last batch as a float64 (n_maps, H, W) array.

    Raises
    ------
    ValueError
        if the batch has fewer than *n_maps* channels.
    """
    n_available = feature_maps.shape[-1]
    if n_available < n_maps:
        raise ValueError(f'expected at least {n_maps} feature maps, got {n_available}')
    return np.moveaxis(feature_maps[0, :, :, :n_maps], -1, 0).astype(np.float64)


def _save_prediction(img1_or, img2_or, prediction, title, out_dir, stem):
    """Save a ``t1 | t2 | prediction`` figure as ``{stem}.png`` and *prediction* as ``{stem}.npy`` in *out_dir*."""
    font = 20
    fig, ax = plt.subplots(1, 3, figsize=(20, 10), constrained_layout=True)
    try:
        ax[0].imshow(create_rgb(img1_or))
        ax[0].set_title('t1', fontsize=font)
        ax[0].axis('off')
        ax[1].imshow(create_rgb(img2_or))
        ax[1].set_title('t2', fontsize=font)
        ax[1].axis('off')
        ax[2].imshow(prediction, cmap='gray')
        ax[2].set_title(title, fontsize=font)
        ax[2].axis('off')
        fig.savefig(os.path.join(out_dir, f'{stem}.png'))
    finally:
        # Close every figure: this runs once per map per dataset, and open
        # figures would otherwise accumulate for the whole batch.
        plt.close(fig)
    np.save(os.path.join(out_dir, f'{stem}.npy'), prediction)


def change_detection(dataset_name, input_rasters, output_root=OUTPUT_ROOT, saved_model=SAVED_MODEL):
    """
    Run the full change-detection pipeline on one pair of time frames.

    Parameters
    ----------
    dataset_name : str
        name of the dataset; used for the output folder and file names.
    input_rasters : sequence of str
        one directory per time frame, each readable by ``read_rasters``.
        Every entry is loaded; the first two are compared.
    output_root : str, optional
        directory under which ``<output_root>/<dataset_name>`` is created.
    saved_model : str, optional
        weights file loaded into ``conv_classifier_two``.

    Returns
    -------
    The fused change mask returned by ``fusion`` (its type is whatever
    ``fusion`` returns).

    Side effects
    ------------
    Writes ``{dataset_name}_{conv,otsu,triangle,fusion}_cm.png`` and the
    matching ``.npy`` files into the output folder.

    Raises
    ------
    ValueError
        if fewer than two time frames are given.
    """
    if len(input_rasters) < 2:
        raise ValueError(f'need two time frames, got {len(input_rasters)}')

    # Load images
    imgs = []
    for timeframe in input_rasters:
        print(timeframe)
        imgs.append(read_rasters(timeframe))
    img1_or, img2_or = imgs[0], imgs[1]

    out_dir = os.path.join(output_root, dataset_name)
    os.makedirs(out_dir, exist_ok=True)

    # Load change detection model; input size is taken from the first image.
    size_x, size_y, n_ch = img1_or.shape
    cd_model = conv_classifier_two(DEPTH, DROPOUT, DECAY, size_x, size_y, n_ch)
    cd_model.load_weights(saved_model)

    # Add a batch axis and standardize each image with its own statistics.
    img1 = _standardize(expand_dims(img1_or, axis=0))
    img2 = _standardize(expand_dims(img2_or, axis=0))

    # Conv prediction: arg-max over axis 2 of the first sample's output.
    # NOTE(review): assumes predict returns (batch, X, Y, classes) - confirm.
    prediction = cd_model.predict([img1, img2])
    y_pred_conv = np.argmax(prediction[0], axis=2)
    _save_prediction(img1_or, img2_or, y_pred_conv, 'Conv Prediction',
                     out_dir, f'{dataset_name}_conv_cm')

    # Otsu & triangle: threshold the Euclidean distance between the feature
    # maps that the branch model (built by feature_selector_cva from cd_model)
    # produces for each time frame.
    branch_model = branch_cva(DROPOUT, DECAY, DEPTH, size_x, size_y, n_ch)
    branch_model = feature_selector_cva(DEPTH, cd_model, branch_model)
    left = _leading_feature_maps(branch_model.predict(img1[:, :, :, 0:3]))
    right = _leading_feature_maps(branch_model.predict(img2[:, :, :, 0:3]))
    distmap = calculate_distancemap(left, right)

    y_pred_otsu = remove_small_objects(distmap > threshold_otsu(distmap),
                                       min_size=MIN_OBJECT_SIZE)
    y_pred_triangle = remove_small_objects(distmap > threshold_triangle(distmap),
                                           min_size=MIN_OBJECT_SIZE)

    _save_prediction(img1_or, img2_or, y_pred_otsu, 'Otsu Prediction',
                     out_dir, f'{dataset_name}_otsu_cm')
    _save_prediction(img1_or, img2_or, y_pred_triangle, 'Triangle Prediction',
                     out_dir, f'{dataset_name}_triangle_cm')

    # Fusion of the three change maps.
    y_pred_fusion = fusion(y_pred_triangle, y_pred_otsu, y_pred_conv, FUSION_PARAM)
    _save_prediction(img1_or, img2_or, y_pred_fusion, 'Fusion Prediction',
                     out_dir, f'{dataset_name}_fusion_cm')
    return y_pred_fusion


def _read_dataset_names(names_file):
    """Return the non-empty, whitespace-stripped comma-separated names on the first line of *names_file*."""
    with open(names_file, 'r') as file:
        first_line = file.readline().strip()
    return [name.strip() for name in first_line.split(',') if name.strip()]


def main():
    """Run ``change_detection`` on every dataset listed in ``DATASET_NAMES_FILE``."""
    for dataset_name in _read_dataset_names(DATASET_NAMES_FILE):
        input_rasters = [
            os.path.join(DATA_ROOT, dataset_name, 'img1_cropped', ''),
            os.path.join(DATA_ROOT, dataset_name, 'img2_cropped', ''),
        ]
        try:
            change_detection(dataset_name, input_rasters)
        except Exception as e:
            # Batch boundary: report the failure and continue with the next
            # dataset instead of aborting the whole run.
            print(f"Error occurred in dataset: {dataset_name}")
            print(f"Error message: {str(e)}")
            traceback.print_exc()

    print("Done")


if __name__ == '__main__':
    main()