/ training / fmaps / feature_maps_S2W.py
feature_maps_S2W.py
  1  #!/usr/bin/env python3
  2  # -*- coding: utf-8 -*-
  3  """
  4  Created on Thu Mar 10 09:40:56 2022
  5  
  6  @author: aleoikon
  7  """
  8  
  9  import sys
 10  sys.path.append('/home/dvalsamis/Documents/projects/Change_detection_SSL_Siamese')
 11  import time
 12  
 13  
 14  from skimage.filters import threshold_otsu, threshold_triangle
 15  from tensorflow import keras
 16  from architectures.branch import branches_nopool, branch_cva, branch_cva_aspp,two_branch_cva_with_aspp,two_branch_cva_with_aspp_fmaps
 17  #from tests import change_detection_noup, change_detection_noup_1x1convs
 18  from architectures.similarity_detection import pretext_task_one_nopool
 19  from tensorflow.keras.optimizers import Adam
 20  from tensorflow.keras.utils import plot_model
 21  import pandas as pd
 22  import tensorflow as tf
 23  import numpy as np
 24  import os 
 25  import random
 26  from numpy import expand_dims
 27  from skimage.morphology import remove_small_objects
 28  import matplotlib
 29  matplotlib.use('TkAgg')
 30  import matplotlib.pyplot as plt
 31  from utils.layer_select import feature_selector_cva, feature_selector_cva_aspp,two_feature_selector_cva_aspp
 32  from utils.log_params import log_params_cva
 33  from architectures.conv_classifier import conv_classifier, conv_classifier_two,conv_classifier_two_with_aspp
 34  from keras.models import Sequential
 35  from skimage.filters import gaussian
 36  from utils.my_metrics import recall, accuracy, specificity, precision, f_measure, get_roc
 37  import uuid
 38  from datetime import datetime
 39  os.environ["CUDA_VISIBLE_DEVICES"]="1"
 40  
 41  
 42  #Euclidean Distance 
 43  def calculate_distancemap(f1, f2):
 44      """
 45      calcualtes pixelwise euclidean distance between images with multiple imput channels
 46  
 47      Parameters
 48      ----------
 49      f1 : np.ndarray of shape (N,M,D)
 50          image 1 with the channels in the third dimension 
 51      f2 : np.ndarray of shape (N,M,D)
 52          image 2 with the channels in the third dimension 
 53          
 54      Returns
 55      -------
 56      np.ndarray of shape(N,M)
 57          pixelwise euclidean distance between image 1 and image 2
 58  
 59      """
 60      dist_per_fmap= [(f2[i,:,:]-f1[i,:,:])**2 for i in range(f1.shape[0])]
 61      
 62      return np.sqrt(sum(dist_per_fmap))
 63  
 64  def create_rgb_onera(x,channel):
 65      if channel == 'red':
 66          r = x[:,:,2]
 67          r = np.expand_dims(r, axis=2)
 68          return r
 69      if channel == 'green':
 70          g = x[:,:,1]
 71          g = np.expand_dims(g, axis=2)
 72          return g
 73      if channel == 'blue':
 74          b = x[:,:,0]
 75          b = np.expand_dims(b, axis=2)
 76          return b
 77      if channel == 'rgb':
 78          r = x[:,:,2]
 79          g = x[:,:,1]
 80          b  = x[:,:,0]
 81          rgb = np.dstack((r,g,b))
 82          return(rgb)
 83      if channel == 'rgbvnir':
 84          r = x[:,:,2]
 85          g = x[:,:,1]
 86          b  = x[:,:,0]
 87          vnir = x[:,:,3]
 88          rgbvnir = np.stack((r,g,b,vnir),axis=2).astype('float')
 89          return(rgbvnir)
 90      else:
 91          return x
 92          print("NOT CORRECT CHANNELS")
 93  
 94  def normalize(x):
 95      img =((x - x.mean()) / x.std())
 96      return img
 97  
 98  def scaleMinMax(x):
 99      return ((x - np.nanpercentile(x,2)) / (np.nanpercentile(x,98) - np.nanpercentile(x,2)))
100  
101  
def create_rgb(x, channels):
    """
    Build a display-ready, percentile-stretched image from ``x``.

    Parameters
    ----------
    x : np.ndarray of shape (H, W, C)
        Channels-last image.
    channels : str
        'red', 'green' or 'blue': band 0 of ``x`` is stretched with
        scaleMinMax and returned as an (H, W) array. All three read band 0,
        presumably because ``x`` is expected to already be a single-band image
        such as the (H, W, 1) output of create_rgb_onera (NOTE(review): confirm).
        'rgb': bands 2, 1 and 0 are stretched independently and stacked, in
        that order, into an (H, W, 3) array.

    Returns
    -------
    np.ndarray or None
        The stretched image, or None for any other ``channels`` value.
    """
    if channels in ('red', 'green', 'blue'):
        return scaleMinMax(x[:, :, 0])
    if channels == 'rgb':
        stretched = [scaleMinMax(x[:, :, band]) for band in (2, 1, 0)]
        return np.dstack(stretched)
    return None
124      
125  
126  
127  #----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
128  
# Hyper-parameters shared by the classifier and the feature-extraction model.
depth = 2
dropout = 0.1
decay = 0.0001
NORM = True        # standardise each image with normalize() before inference
ImageSize = 96     # spatial size (height == width) of every patch
n_ch = 3           # channels per input image
channel = 'rgb'    # band selection passed to create_rgb_onera()

# Models ----------------------------------------------------------------------
# Trained change-detection checkpoint to inspect.
cd_model_name = 'CD_Simple_CBMI_c141'
cd_model_path = '/home/dvalsamis/Documents/projects/Change_detection_SSL_Siamese/saved_models/CD_Simple_CBMI_c141.h5'
model_id = cd_model_name.split('_')[-1].split('.')[0]  # -> 'c141'
mtype = 'conv_classifier_two'

# Rebuild the classifier architecture and restore its trained weights.
source_model = conv_classifier_two_with_aspp(depth, dropout, decay, ImageSize, ImageSize, n_ch)
source_model.load_weights(cd_model_path)

# Two-branch network plus a second model whose predict() output is used as
# `feature_maps` further down.
# NOTE(review): the argument order here (dropout, decay, depth) differs from the
# classifier call above (depth, dropout, decay) -- confirm against the signature.
branch_model, feature_model = two_branch_cva_with_aspp_fmaps(dropout, decay, depth, ImageSize, ImageSize, n_ch)
# Presumably transfers the trained weights of source_model into branch_model
# (see utils.layer_select) -- TODO confirm.
branch_model = two_feature_selector_cva_aspp(depth, source_model, branch_model)
#plot_model(branch_model, to_file='/home/dvalsamis/Documents/projects/Change_detection_SSL_Siamese/graphs_cva/'+model_id+'_model_plot.png', show_shapes=True, show_layer_names=True)
152  
153      
# Data ------------------------------------------------------------------------

# onera_train_target =  '/home/aleoikon/Documents/data/ssl/onera_npys/patches/downstream/train/'
# onera_test_target = '/home/aleoikon/Documents/data/ssl/onera_npys/patches/downstream/test/'

# Directories with the augmented CBMI .npy patches. Only the test split is read below.
onera_train_target = '/home/dvalsamis/Documents/data/CBMI/CBMI_0.3/CBMI_0.3/NPY_dataset/aug_train_data/'
onera_test_target = '/home/dvalsamis/Documents/data/CBMI/CBMI_0.3/CBMI_0.3/NPY_dataset/aug_test_data/'

# Index of test pairs; the loading loop reads its 'pair1', 'pair2' and
# 'change_mask' columns as file names relative to onera_test_target.
onera_test_df = pd.read_csv(onera_test_target + "dataset_test.csv")

dsize = len(onera_test_df)  # number of test pairs
# The training split is never loaded by this script, so no training-set size
# is computed; the test-set size is the number of rows in the test CSV.
tesize = dsize

# Buffers filled pair-by-pair by the loading loop (zero-initialised rather than
# left as uninitialised memory).
_pair_shape = (dsize, ImageSize, ImageSize, n_ch)
X1 = np.zeros(_pair_shape)       # model input, first image of each pair
X2 = np.zeros(_pair_shape)       # model input, second image of each pair
input_1 = np.zeros(_pair_shape)  # band-selected images before normalisation
input_2 = np.zeros(_pair_shape)
y = np.zeros((dsize, ImageSize, ImageSize))  # ground-truth change masks
182  
183  
184  
# Load every test pair listed in the CSV: select bands, keep the raw selection
# in input_1/input_2, the (optionally normalised) model inputs in X1/X2 and
# the change mask in y.
pair_rows = zip(onera_test_df['pair1'], onera_test_df['pair2'], onera_test_df['change_mask'])
for i, (pair1_file, pair2_file, mask_file) in enumerate(pair_rows):
    raw_1 = np.load(onera_test_target + pair1_file)
    raw_2 = np.load(onera_test_target + pair2_file)
    img1 = create_rgb_onera(raw_1, channel)
    img2 = create_rgb_onera(raw_2, channel)

    input_1[i], input_2[i] = img1, img2
    if not NORM:
        X1[i], X2[i] = img1, img2
    else:
        X1[i] = normalize(img1)
        X2[i] = normalize(img2)

    y[i] = np.load(onera_test_target + mask_file)

# Labels for the saved plots; save_and_visualize_feature_maps pairs them
# positionally with feature_model's outputs (NOTE(review): confirm the order).
layers_of_interest = ['aspp_reduced_relu', 'abs_diff_2','reduced_aspp_output','output']

# Extract the intermediate feature maps for the whole test set in one pass.
feature_maps = feature_model.predict([X1, X2])

branch_model.summary()
209  
210  
211  # y_pred_conv = branch_model.predict([X1,X2])
212  # y_pred_conv = np.argmax(y_pred_conv, axis=3)
213  # # Assuming y_pred is a list with one element per output layer
214  # y_pred_single = y[0]  # Assuming the first element is the desired output
215  
216  # pos = 57
217  # fig, ax = plt.subplots(2, 3, figsize=(15, 10), constrained_layout=True)
218  
219  # ax[0, 0].imshow(create_rgb(input_1[pos], channel))
220  # ax[0, 0].set_title('Left Image', fontsize=20)
221  # ax[0, 0].axis('off')
222  
223  # ax[0, 1].imshow(y[pos], cmap='gray')
224  # ax[0, 1].set_title('Ground Truth', fontsize=20)
225  # ax[0, 1].axis('off')
226  
227  # ax[0, 2].imshow(create_rgb(input_2[pos], channel))
228  # ax[0, 2].set_title('Right Image', fontsize=20)
229  # ax[0, 2].axis('off')
230  
231  # ax[1, 0].imshow(create_rgb(input_1[pos], channel))
232  # ax[1, 0].set_title('Left Image', fontsize=20)
233  # ax[1, 0].axis('off')
234  
235  # ax[1, 1].imshow(y[pos], cmap='gray')  # Adjust based on the actual structure
236  # ax[1, 1].set_title('Predicted Mask', fontsize=20)
237  # ax[1, 1].axis('off')
238  
239  # ax[1, 2].imshow(create_rgb(input_2[pos], channel))
240  # ax[1, 2].set_title('Right Image', fontsize=20)
241  # ax[1, 2].axis('off')
242  
243  # plt.show()
244  
245  # print("")
246   
247  
248  
249  # Feature Maps ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
250  
251  # import matplotlib.pyplot as plt
252  # import numpy as np
253  # import os
254  
255  # feature_maps_dir = '/home/dvalsamis/Documents/projects/Change_detection_SSL_Siamese/fmaps'
256  
257  
258  
259  # if not os.path.exists(feature_maps_dir):
260  #     os.makedirs(feature_maps_dir)
261  
262  # # Function to plot a single feature map
263  # def plot_feature_map(feature_map, index, map_index, save_dir):
264  #     plt.imshow(feature_map, cmap='viridis')
265  #     plt.colorbar()
266  #     plt.title(f"Feature Map {map_index} for Sample {index}")
267  #     plt.savefig(os.path.join(save_dir, f"feature_map_{index}_{map_index}.png"))
268  #     plt.close()
269  
270  # # Function to plot a composite image from all feature maps
271  # def plot_composite_feature_map(feature_maps, index, save_dir):
272  #     composite_map = np.mean(feature_maps, axis=-1)
273  #     plt.imshow(composite_map, cmap='viridis')
274  #     plt.colorbar()
275  #     plt.title(f"Composite Feature Map for Sample {index}")
276  #     plt.savefig(os.path.join(save_dir, f"composite_feature_map_{index}.png"))
277  #     plt.close()
278  
279  
280  
281  # for index in range(10):
282  #     img1 = np.expand_dims(X1[index], axis=0)
283  #     img2 = np.expand_dims(X2[index], axis=0)
284      
285  #     # Assuming branch_model.predict now returns ASPP feature maps and final output
286  #     aspp_feature_maps, _ = branch_model.predict([img1, img2])
287  #     feature_maps = aspp_feature_maps[0]  # Extract the first (and only) item in the batch
288  
289  #     feature_maps = feature_model.predict([X1, X2])
290  
291  #     layer_names = [layer.name for layer in source_model.layers if 'relu' in layer.name]
292  #     plot_feature_map(feature_maps, layer_names)
293  
294  #     # Save the feature maps
295  #     feature_map_path = os.path.join(feature_maps_dir, f"feature_map_{index}.npy")
296  #     np.save(feature_map_path, feature_maps)
297  
298  
299  
300  #     # # Optionally, plot each feature map individually
301  #     # for map_index in range(feature_maps.shape[-1]):  # Loop through each channel
302  #     #     plot_feature_map(feature_maps[:, :, map_index], index, map_index, feature_maps_dir)
303      
304  #     # And/or plot a composite image
305  #     plot_composite_feature_map(feature_maps, index, feature_maps_dir)
306  
307  # ----------------------------------------------------------------------------------------------------------------------
308  
309  # def save_and_visualize_feature_maps(feature_maps, layer_names, directory, initial_inputs, example_index=0):
310  #     # Ensure the output directory exists
311  #     if not os.path.exists(directory):
312  #         os.makedirs(directory)
313  
314  #     # Plot and save initial inputs for the specified example index
315  #     fig, axes = plt.subplots(1, len(initial_inputs), figsize=(10, 5))
316  #     for i, ax in enumerate(axes.flat):
317  #         input_image = initial_inputs[i][example_index]  # Select the example_index-th example
318  #         if input_image.shape[-1] == 1:  # If grayscale, reshape to remove the last dimension for plotting
319  #             input_image = input_image.reshape(input_image.shape[0], input_image.shape[1])
320  #         ax.imshow(input_image, cmap='viridis')
321  #         ax.set_title(f'Input {i+1}')
322  #         ax.axis('off')
323  #     plt.suptitle(f'Initial Inputs for Example {example_index}')
324  #     plt.savefig(os.path.join(directory, f'initial_input_{example_index}.png'))
325  #     plt.show()
326  
327  #     # Process each feature map corresponding to each layer
328  #     for fmap, layer_name in zip(feature_maps, layer_names):
329  #         # Define the directory for this particular layer
330  #         layer_dir = os.path.join(directory, layer_name)
331  #         if not os.path.exists(layer_dir):
332  #             os.makedirs(layer_dir)
333  
334  #         num_filters = fmap.shape[-1]
335  #         fig, axes = plt.subplots(1, min(num_filters, 20), figsize=(20, 2))
336  
337  #         if num_filters == 1:
338  #             axes = [axes]  # Make it iterable
339  
340  #         # Plot and save each filter's feature map for the specified example index
341  #         for i, ax in enumerate(axes):
342  #             if i < 20:
343  #                 feature_image = fmap[example_index, :, :, i]
344  #                 ax.imshow(feature_image, cmap='viridis')
345  #                 ax.set_title(f'Filter {i}')
346  #                 ax.axis('off')
347  #                 plot_filename = os.path.join(layer_dir, f'filter_{i}_example_{example_index}.png')
348  #                 plt.savefig(plot_filename)
349  
350  #         plt.suptitle(f'Feature Maps from Layer: {layer_name} for Example {example_index}')
351  #         plt.show()
352  
353  #         # Save each feature map as NPY file for the specified example index
354  #         for i in range(num_filters):
355  #             npy_filename = os.path.join(layer_dir, f'filter_{i}_example_{example_index}.npy')
356  #             np.save(npy_filename, fmap[example_index, :, :, i])
357  
358  # # Example usage
359  # feature_maps_dir = '/home/dvalsamis/Documents/projects/Change_detection_SSL_Siamese/fmaps_3'
360  # initial_inputs = [X1, X2]
361  # example_index = 0  # Change this to visualize different examples
362  # save_and_visualize_feature_maps(feature_maps, layers_of_interest, feature_maps_dir, initial_inputs, example_index)
363  # ----------------------------------------------------------------------------------------------------------------------
364  
365  
def _scale_for_display(image):
    """Min-max rescale ``image`` to [0, 1]; constant images map to zeros."""
    image = np.asarray(image, dtype=float)
    low, high = np.nanmin(image), np.nanmax(image)
    if high > low:
        return (image - low) / (high - low)
    return np.zeros_like(image)


def _plot_inputs_and_masks(initial_inputs, directory, example_index, y_true, y_pred_conv):
    """Show and save one figure with every input image plus the GT / predicted masks."""
    fig, axes = plt.subplots(1, len(initial_inputs) + 2, figsize=(15, 5))
    input_axes, gt_ax, pred_ax = axes[:-2], axes[-2], axes[-1]

    for i, (ax, inputs) in enumerate(zip(input_axes, initial_inputs)):
        input_image = np.asarray(inputs[example_index])
        if input_image.ndim == 3 and input_image.shape[-1] == 1:
            # Single band: drop the channel axis so imshow applies the colormap.
            input_image = input_image[:, :, 0]
        elif input_image.ndim == 3:
            # imshow clips float RGB(A) data to [0, 1]; standardised inputs
            # (zero mean) would otherwise render largely clipped.
            input_image = _scale_for_display(input_image)
        ax.imshow(input_image, cmap='viridis')
        ax.set_title(f'Input {i+1}')
        ax.axis('off')

    if y_true is not None and y_pred_conv is not None:
        gt_ax.imshow(y_true[example_index], cmap='gray')
        gt_ax.set_title('Ground Truth')
        pred_ax.imshow(y_pred_conv[example_index], cmap='gray')
        pred_ax.set_title('Predicted Mask')
    # The two mask panels always exist; hide their frames whether used or not.
    gt_ax.axis('off')
    pred_ax.axis('off')

    fig.suptitle(f'Initial Inputs and Masks for Example {example_index}')
    fig.savefig(os.path.join(directory, f'input_and_masks_{example_index}.png'))
    plt.show()
    plt.close(fig)


def _plot_layer_feature_maps(fmap, layer_name, directory, example_index, cols, max_filters):
    """Plot up to ``max_filters`` channels of one layer and save PNG/NPY files per channel."""
    fmap = np.asarray(fmap)
    if fmap.ndim != 4:
        print(f"Skipping layer '{layer_name}': expected a 4-D (batch, H, W, C) "
              f"feature map, got shape {fmap.shape}")
        return

    layer_dir = os.path.join(directory, layer_name)
    os.makedirs(layer_dir, exist_ok=True)

    num_filters = fmap.shape[-1]
    n_shown = min(num_filters, max_filters)
    if n_shown > 0:
        rows = (n_shown + cols - 1) // cols  # ceil(n_shown / cols)
        fig, axes = plt.subplots(rows, cols, figsize=(2 * cols, 2 * rows), squeeze=False)
        for i, ax in enumerate(axes.ravel()):
            ax.axis('off')
            if i >= n_shown:
                continue
            feature_image = fmap[example_index, :, :, i]
            ax.imshow(feature_image, cmap='viridis')
            ax.set_title(f'Filter {i}', fontsize=10)
            # Save this single channel on its own rather than the whole grid figure.
            plt.imsave(os.path.join(layer_dir, f'filter_{i}_example_{example_index}.png'),
                       feature_image, cmap='viridis')
        fig.suptitle(f'Feature Maps from Layer: {layer_name} for Example {example_index}', fontsize=16)
        fig.tight_layout()
        fig.savefig(os.path.join(layer_dir, f'feature_maps_example_{example_index}.png'))
        plt.show()
        plt.close(fig)

    # Keep the raw values of every channel, not only the plotted ones.
    for i in range(num_filters):
        npy_filename = os.path.join(layer_dir, f'filter_{i}_example_{example_index}.npy')
        np.save(npy_filename, fmap[example_index, :, :, i])


def save_and_visualize_feature_maps(feature_maps, layer_names, directory, initial_inputs,
                                    example_index=0, y_true=None, y_pred_conv=None,
                                    cols=8, max_filters=32):
    """
    Visualise and save the inputs, masks and intermediate feature maps of one sample.

    Parameters
    ----------
    feature_maps : np.ndarray or list of np.ndarray
        Output of a Keras ``predict`` call: one (batch, H, W, C) array per layer.
        A single bare array (single-output model) is treated as one layer.
    layer_names : list of str
        Used as sub-directory names and plot titles. Paired positionally with
        ``feature_maps`` via ``zip``, so surplus entries in either are ignored.
    directory : str
        Output directory; created if missing.
    initial_inputs : list of np.ndarray
        Batched input arrays (e.g. ``[X1, X2]``), each indexed with ``example_index``.
    example_index : int, optional
        Which sample of the batch to visualise.
    y_true, y_pred_conv : np.ndarray, optional
        Batched ground-truth and predicted masks; plotted only if both are given.
    cols : int, optional
        Number of columns in each per-layer grid (default 8).
    max_filters : int, optional
        Maximum channels plotted per layer (default 32); every channel is
        still saved as ``.npy``.

    Notes
    -----
    Writes ``input_and_masks_<idx>.png`` into ``directory`` and, per layer, into
    ``directory/<layer_name>``: ``filter_<i>_example_<idx>.png`` for each plotted
    channel, ``feature_maps_example_<idx>.png`` for the grid, and
    ``filter_<i>_example_<idx>.npy`` for every channel. Each figure is shown with
    ``plt.show()`` and then closed. Layers whose maps are not 4-D are skipped
    with a printed message.
    """
    os.makedirs(directory, exist_ok=True)

    # Keras returns a bare array (not a list) for single-output models; zipping
    # that with layer_names would iterate over samples instead of layers.
    if isinstance(feature_maps, np.ndarray):
        feature_maps = [feature_maps]

    _plot_inputs_and_masks(initial_inputs, directory, example_index, y_true, y_pred_conv)

    for fmap, layer_name in zip(feature_maps, layer_names):
        _plot_layer_feature_maps(fmap, layer_name, directory, example_index, cols, max_filters)
429  
# Example usage ---------------------------------------------------------------
# Visualise one test pair: its inputs, GT / predicted masks and feature maps.
example_index = 57  # change this to visualise a different test pair
feature_maps_dir = '/home/dvalsamis/Documents/projects/Change_detection_SSL_Siamese/fmaps/fmaps_1'
initial_inputs = [X1, X2]
y_true = y

# Per-pixel label = argmax over axis 3 of the classifier output
# (assumed to be the class axis of an (N, H, W, classes) output -- TODO confirm).
y_pred_conv = np.argmax(source_model.predict([X1, X2]), axis=3)

save_and_visualize_feature_maps(
    feature_maps,
    layers_of_interest,
    feature_maps_dir,
    initial_inputs,
    example_index=example_index,
    y_true=y_true,
    y_pred_conv=y_pred_conv,
)

print("Done")