# feature_maps_Onera.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Mar 10 09:40:56 2022

@author: aleoikon

Load a trained Onera change-detection model, run it on the Onera test patches
and plot/save the intermediate feature maps of selected layers.
"""

import sys
# Must run before the project-local imports below so they can be resolved.
sys.path.append('/home/dvalsamis/Documents/projects/Change_detection_SSL_Siamese')
import time


from skimage.filters import threshold_otsu, threshold_triangle
from tensorflow import keras
from architectures.branch import branches_nopool, branch_cva, branch_cva_aspp,two_branch_cva_with_aspp,two_branch_cva_with_aspp_fmaps
#from tests import change_detection_noup, change_detection_noup_1x1convs
from architectures.similarity_detection import pretext_task_one_nopool
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import plot_model
import pandas as pd
import tensorflow as tf
import numpy as np
import os
import random
from numpy import expand_dims
from skimage.morphology import remove_small_objects
import matplotlib
# The backend has to be selected before pyplot is imported.
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
from utils.layer_select import feature_selector_cva, feature_selector_cva_aspp,two_feature_selector_cva_aspp
from utils.log_params import log_params_cva
from architectures.conv_classifier import conv_classifier, conv_classifier_two,conv_classifier_two_with_aspp
from keras.models import Sequential
from skimage.filters import gaussian
from utils.my_metrics import recall, accuracy, specificity, precision, f_measure, get_roc
import uuid
from datetime import datetime
#os.environ["CUDA_VISIBLE_DEVICES"]="0"


#Euclidean Distance
def calculate_distancemap(f1, f2, axis=-1):
    """
    Calculate the pixelwise Euclidean distance between images with multiple input channels.

    Parameters
    ----------
    f1 : np.ndarray of shape (N, M, D)
        Image 1 with the channels in the third (last) dimension.
    f2 : np.ndarray of shape (N, M, D)
        Image 2, same shape as ``f1``.
    axis : int, optional
        Axis holding the channels; it is reduced by the distance. Defaults to
        -1 (channels-last, as above). Pass ``axis=0`` for channels-first
        arrays of shape (D, N, M).

    Returns
    -------
    np.ndarray of shape (N, M)
        Pixelwise Euclidean distance between image 1 and image 2.
    """
    # Subtract in float so unsigned-integer images cannot wrap around.
    diff = np.asarray(f2, dtype=float) - np.asarray(f1, dtype=float)
    return np.sqrt(np.sum(diff ** 2, axis=axis))


# Band indices (last axis of a raw Onera patch) for each supported selection.
_ONERA_BANDS = {
    'red': [2],
    'green': [1],
    'blue': [0],
    'rgb': [2, 1, 0],
    'rgbvnir': [2, 1, 0, 3],
}


def create_rgb_onera(x, channel):
    """
    Select and reorder the bands of a raw Onera patch.

    Index 2 of the last axis is taken as red, 1 as green, 0 as blue and 3 as
    the VNIR band (see ``_ONERA_BANDS``).

    Parameters
    ----------
    x : np.ndarray of shape (H, W, B)
        Raw patch with the bands in the last dimension.
    channel : {'red', 'green', 'blue', 'rgb', 'rgbvnir'}
        Band selection. Single-band selections keep a trailing axis of length 1.

    Returns
    -------
    np.ndarray of shape (H, W, n_selected_bands)
        The selected bands in the order listed in ``_ONERA_BANDS``; the
        'rgbvnir' selection is cast to float. For any other ``channel`` a
        message is printed and ``x`` is returned unchanged.
    """
    bands = _ONERA_BANDS.get(channel)
    if bands is None:
        # Unknown selection: report it and pass the patch through untouched.
        print("NOT CORRECT CHANNELS")
        return x
    selected = x[:, :, bands]
    if channel == 'rgbvnir':
        selected = selected.astype('float')
    return selected


def normalize(x):
    """
    Standardise ``x`` to zero mean and unit standard deviation over all elements.

    A constant input (std == 0) is only mean-centred, giving zeros instead of
    NaN/inf from a division by zero.
    """
    centred = x - x.mean()
    std = x.std()
    return centred / std if std > 0 else centred


def scaleMinMax(x):
    """
    Linearly stretch ``x`` so its 2nd / 98th percentiles map to 0 / 1.

    NaNs are ignored when computing the percentiles; values outside the range
    are not clipped. If both percentiles coincide the input is only shifted
    (no division by zero).
    """
    low, high = np.nanpercentile(x, [2, 98])
    span = high - low
    if span == 0:
        return x - low
    return (x - low) / span


def create_rgb(x, channels):
    """
    Build a display image with a per-band 2nd/98th-percentile stretch.

    For 'red', 'green' and 'blue', band 0 of ``x`` is used (presumably ``x`` is
    already a single-band patch such as the output of ``create_rgb_onera`` --
    TODO confirm) and a 2-D array is returned. For 'rgb', bands 2, 1 and 0 are
    each stretched and stacked in that order. Any other value returns None.
    """
    if channels in ('red', 'green', 'blue'):
        return scaleMinMax(x[:, :, 0])
    if channels == 'rgb':
        # NOTE(review): this reverses band order; applied to create_rgb_onera('rgb')
        # output it swaps red and blue back -- confirm intended input.
        return np.dstack([scaleMinMax(x[:, :, band]) for band in (2, 1, 0)])
    return None


#----------------------------------------------------------------------------------------------------------------------

depth = 2            # model hyper-parameters passed to the architecture builders
dropout = 0.1
decay = 0.0001
NORM = True          # standardise each patch with normalize() before inference
ImageSize = 96       # patch height and width in pixels
n_ch = 3             # input channels; must match the band count selected by `channel`
channel = 'rgb'      # band selection passed to create_rgb_onera()

# Models ---------------------------------------------------------------------------------------------------------------
source_model = conv_classifier_two_with_aspp(depth, dropout, decay, ImageSize, ImageSize, n_ch)
mtype = 'conv_classifier_two'

cd_model_path = '/home/dvalsamis/Documents/projects/Change_detection_SSL_Siamese/saved_models/CD_Simple_Onera_0a8d.h5'
cd_model_name = 'CD_Simple_Onera_0a8d'
model_id = cd_model_name.split('_')[-1].split('.')[0]  # -> '0a8d'

source_model.load_weights(cd_model_path)

# Model exposing intermediate outputs; two_feature_selector_cva_aspp presumably
# transfers the trained weights of source_model into branch_model (TODO confirm).
branch_model, feature_model = two_branch_cva_with_aspp_fmaps(dropout, decay, depth, ImageSize, ImageSize, n_ch)
branch_model = two_feature_selector_cva_aspp(depth, source_model, branch_model)
#plot_model(branch_model, to_file='/home/dvalsamis/Documents/projects/Change_detection_SSL_Siamese/graphs_cva/'+model_id+'_model_plot.png', show_shapes=True, show_layer_names=True)


# Data -----------------------------------------------------------------------------------------------------------------

onera_train_target = '/home/aleoikon/Documents/data/ssl/onera_npys/patches/downstream/train/'
onera_test_target = '/home/aleoikon/Documents/data/ssl/onera_npys/patches/downstream/test/'

# One row per test sample: .npy file names of both acquisitions and the change mask.
onera_test_df = pd.read_csv(onera_test_target + "dataset_test.csv")

tesize = len(onera_test_df)  # Size of the testing set
dsize = tesize               # Only the test split is loaded by this script
trsize = None                # Training split is not loaded, so its size is unknown here

# Network inputs (normalised when NORM is set), un-normalised copies, and masks.
X1 = np.empty((tesize, ImageSize, ImageSize, n_ch))
X2 = np.empty_like(X1)
nonNorm1 = np.empty_like(X1)
nonNorm2 = np.empty_like(X1)
y = np.empty((tesize, ImageSize, ImageSize))

test_rows = zip(onera_test_df['pair1'], onera_test_df['pair2'], onera_test_df['change_mask'])
for i, (pair1_file, pair2_file, mask_file) in enumerate(test_rows):
    img1 = create_rgb_onera(np.load(os.path.join(onera_test_target, pair1_file)), channel)
    img2 = create_rgb_onera(np.load(os.path.join(onera_test_target, pair2_file)), channel)

    nonNorm1[i] = img1
    nonNorm2[i] = img2
    X1[i] = normalize(img1) if NORM else img1
    X2[i] = normalize(img2) if NORM else img2
    y[i] = np.load(os.path.join(onera_test_target, mask_file))


# Names used for the feature_model outputs, in the order predict() returns them
# (TODO confirm against two_branch_cva_with_aspp_fmaps).
layers_of_interest = ['aspp_reduced_relu', 'abs_diff_2', 'reduced_aspp_output', 'output']

feature_maps = feature_model.predict([X1, X2])

branch_model.summary()


# Feature Maps ---------------------------------------------------------------------------------------------------------

def _plot_inputs_and_masks(directory, initial_inputs, example_index, y_true, y_pred_conv):
    """Plot the inputs and the ground-truth / predicted masks of one example and save the figure."""
    # Two extra panels on the right for the masks.
    fig, axes = plt.subplots(1, len(initial_inputs) + 2, figsize=(15, 5))

    for idx, (ax, images) in enumerate(zip(axes[:-2], initial_inputs)):
        input_image = images[example_index]
        if input_image.shape[-1] == 1:
            # imshow needs a 2-D array for single-channel data.
            input_image = input_image[..., 0]
        ax.imshow(input_image, cmap='viridis')
        ax.set_title(f'Input {idx + 1}')
        ax.axis('off')

    mask_panels = ((axes[-2], y_true, 'Ground Truth'), (axes[-1], y_pred_conv, 'Predicted Mask'))
    for ax, masks, title in mask_panels:
        ax.axis('off')  # also hides the empty panel when a mask batch is not given
        if masks is not None:
            ax.imshow(masks[example_index], cmap='gray')
            ax.set_title(title)

    fig.suptitle(f'Initial Inputs and Masks for Example {example_index}')
    fig.savefig(os.path.join(directory, f'input_and_masks_{example_index}.png'))
    plt.show()
    plt.close(fig)


def _plot_layer_feature_maps(fmap, layer_name, layer_dir, example_index, max_filters=32, cols=8):
    """Plot up to ``max_filters`` filters of one layer for one example; save each filter and the grid as PNG."""
    num_shown = min(fmap.shape[-1], max_filters)
    if num_shown == 0:
        return
    rows = -(-num_shown // cols)  # ceiling division
    fig, axes = plt.subplots(rows, cols, figsize=(2 * cols, 2 * rows), squeeze=False)
    axes = axes.ravel()

    for i in range(num_shown):
        feature_image = fmap[example_index, :, :, i]
        ax = axes[i]
        ax.imshow(feature_image, cmap='viridis')
        ax.set_title(f'Filter {i}', fontsize=10)
        ax.axis('off')
        # Save this filter alone, so the file name matches its content.
        plt.imsave(os.path.join(layer_dir, f'filter_{i}_example_{example_index}.png'),
                   feature_image, cmap='viridis')

    for ax in axes[num_shown:]:
        ax.axis('off')

    fig.suptitle(f'Feature Maps from Layer: {layer_name} for Example {example_index}', fontsize=16)
    fig.tight_layout()
    fig.savefig(os.path.join(layer_dir, f'grid_example_{example_index}.png'))
    plt.show()
    plt.close(fig)


def save_and_visualize_feature_maps(feature_maps, layer_names, directory, initial_inputs, example_index=0, y_true=None, y_pred_conv=None):
    """
    Plot and save the inputs, masks and per-layer feature maps of one example.

    Parameters
    ----------
    feature_maps : list of np.ndarray, or a single np.ndarray
        Outputs of ``feature_model.predict``; each is indexed as
        ``fmap[example, row, col, filter]``.
    layer_names : list of str
        One name per entry of ``feature_maps``, used as sub-directory names.
        ``zip`` pairs them, so entries beyond the shorter sequence are skipped.
    directory : str
        Output directory (created if missing).
    initial_inputs : sequence of np.ndarray
        Input batches to display, each indexed by ``example_index``.
    example_index : int, optional
        Sample of the batch to visualise.
    y_true, y_pred_conv : np.ndarray, optional
        Ground-truth / predicted mask batches; each panel is drawn only when
        its array is given.

    Files written
    -------------
    ``input_and_masks_<k>.png`` in ``directory``; per layer, in
    ``directory/<layer_name>``: ``filter_<i>_example_<k>.png`` for up to 32
    filters, ``grid_example_<k>.png``, and ``filter_<i>_example_<k>.npy``
    for every filter.
    """
    os.makedirs(directory, exist_ok=True)
    _plot_inputs_and_masks(directory, initial_inputs, example_index, y_true, y_pred_conv)

    if not isinstance(feature_maps, (list, tuple)):
        # Keras returns a bare array (not a list) for single-output models;
        # zipping that directly would iterate over the batch axis.
        feature_maps = [feature_maps]

    for fmap, layer_name in zip(feature_maps, layer_names):
        layer_dir = os.path.join(directory, layer_name)
        os.makedirs(layer_dir, exist_ok=True)
        _plot_layer_feature_maps(fmap, layer_name, layer_dir, example_index)

        # Every filter (not only the plotted ones) is saved as a raw array.
        for i in range(fmap.shape[-1]):
            npy_filename = os.path.join(layer_dir, f'filter_{i}_example_{example_index}.npy')
            np.save(npy_filename, fmap[example_index, :, :, i])


# Example usage
feature_maps_dir = '/home/dvalsamis/Documents/projects/Change_detection_SSL_Siamese/fonera/fmaps_57'
# NOTE(review): X1/X2 hold the standardised network inputs; imshow clips float RGB
# data to [0, 1], so the "Input" panels are not natural-colour images.
initial_inputs = [X1, X2]
example_index = 57  # Change this to visualize different examples (must be < tesize)
y_true = y
# argmax over the last axis of the model output (presumably per-pixel class scores).
y_pred_conv = source_model.predict([X1, X2])
y_pred_conv = np.argmax(y_pred_conv, axis=3)
save_and_visualize_feature_maps(feature_maps, layers_of_interest, feature_maps_dir, initial_inputs, example_index, y_true, y_pred_conv)


print("Done")