/ training / downstream_tasks / levir_cd.py
levir_cd.py
  1  #!/usr/bin/env python3
  2  # -*- coding: utf-8 -*-
  3  """
  4  Created on Thu Mar 10 09:40:56 2022
  5  
  6  @author: aleoikon
  7  """
  8  
  9  import sys
 10  sys.path.append('/home/dvalsamis/Documents/projects/Change_detection_SSL_Siamese')
 11  import time
 12  
 13  
 14  from skimage.filters import threshold_otsu, threshold_triangle
 15  from tensorflow import keras
 16  from architectures.branch import branches_nopool, branch_cva, branch_cva_aspp,two_branch_cva_with_aspp,two_branch_cva_with_aspp_fmaps
 17  #from tests import change_detection_noup, change_detection_noup_1x1convs
 18  from architectures.similarity_detection import pretext_task_one_nopool
 19  from tensorflow.keras.optimizers import Adam
 20  from tensorflow.keras.utils import plot_model
 21  import pandas as pd
 22  import tensorflow as tf
 23  import numpy as np
 24  import os 
 25  import random
 26  from numpy import expand_dims
 27  from skimage.morphology import remove_small_objects
 28  import matplotlib
 29  matplotlib.use('TkAgg')
 30  import matplotlib.pyplot as plt
 31  from utils.layer_select import feature_selector_cva, feature_selector_cva_aspp,two_feature_selector_cva_aspp
 32  from utils.log_params import log_params_cva
 33  from architectures.conv_classifier import conv_classifier, conv_classifier_two,conv_classifier_two_with_aspp
 34  from keras.models import Sequential
 35  from skimage.filters import gaussian
 36  from utils.my_metrics import recall, accuracy, specificity, precision, f_measure, get_roc
 37  import uuid
 38  from datetime import datetime
 39  #os.environ["CUDA_VISIBLE_DEVICES"]="0"
 40  
 41  
 42  #Euclidean Distance 
 43  def calculate_distancemap(f1, f2):
 44      """
 45      calcualtes pixelwise euclidean distance between images with multiple imput channels
 46  
 47      Parameters
 48      ----------
 49      f1 : np.ndarray of shape (N,M,D)
 50          image 1 with the channels in the third dimension 
 51      f2 : np.ndarray of shape (N,M,D)
 52          image 2 with the channels in the third dimension 
 53          
 54      Returns
 55      -------
 56      np.ndarray of shape(N,M)
 57          pixelwise euclidean distance between image 1 and image 2
 58  
 59      """
 60      dist_per_fmap= [(f2[i,:,:]-f1[i,:,:])**2 for i in range(f1.shape[0])]
 61      
 62      return np.sqrt(sum(dist_per_fmap))
 63  
 64  def create_rgb_onera(x,channel):
 65      if channel == 'red':
 66          r = x[:,:,2]
 67          r = np.expand_dims(r, axis=2)
 68          return r
 69      if channel == 'green':
 70          g = x[:,:,1]
 71          g = np.expand_dims(g, axis=2)
 72          return g
 73      if channel == 'blue':
 74          b = x[:,:,0]
 75          b = np.expand_dims(b, axis=2)
 76          return b
 77      if channel == 'rgb':
 78          r = x[:,:,2]
 79          g = x[:,:,1]
 80          b  = x[:,:,0]
 81          rgb = np.dstack((r,g,b))
 82          return(rgb)
 83      if channel == 'rgbvnir':
 84          r = x[:,:,2]
 85          g = x[:,:,1]
 86          b  = x[:,:,0]
 87          vnir = x[:,:,3]
 88          rgbvnir = np.stack((r,g,b,vnir),axis=2).astype('float')
 89          return(rgbvnir)
 90      else:
 91          return x
 92          print("NOT CORRECT CHANNELS")
 93  
 94  def normalize(x):
 95      img =((x - x.mean()) / x.std())
 96      return img
 97  
 98  def scaleMinMax(x):
 99      return ((x - np.nanpercentile(x,2)) / (np.nanpercentile(x,98) - np.nanpercentile(x,2)))
100  
101  
def create_rgb(x, channels):
    """
    Build a display-ready image from a channels-last array.

    Every returned band goes through ``scaleMinMax`` (2nd/98th percentile
    stretch).

    Parameters
    ----------
    x : np.ndarray of shape (N, M, C)
    channels : str
        'red' / 'green' / 'blue' -> band 0 stretched, shape (N, M).
            NOTE(review): all three single-band modes read band 0. This looks
            intended for inputs already reduced to one band by
            create_rgb_onera(); confirm.
        'rgb' -> bands 2, 1, 0 stretched and stacked, shape (N, M, 3).
        anything else -> None.
    """
    if channels in ('red', 'green', 'blue'):
        return scaleMinMax(x[:, :, 0])
    if channels != 'rgb':
        return None
    stretched_bands = [scaleMinMax(x[:, :, band]) for band in (2, 1, 0)]
    return np.dstack(stretched_bands)
124      
125  
126  
127  #----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
128  
# Experiment configuration. These values are passed to the model constructors
# below, to the data-loading code, and to log_params_cva().
depth = 2
dropout = 0.1
decay = 0.0001
NORM = True  # when True each test image is passed through normalize() before inference
ImageSize = 96  # used as both height and width of every input patch
n_ch = 3  # number of bands per input image
channel = 'rgb'  # band selector passed to create_rgb_onera() / create_rgb()

# Models --------------------------------------------------------------------
# Supervised change-detection classifier; trained weights are loaded into it below.
# NOTE(review): this constructor is called with (depth, dropout, decay, ...)
# while two_branch_cva_with_aspp below gets (dropout, decay, depth, ...);
# confirm each matches its own signature.
source_model = conv_classifier_two_with_aspp(depth, dropout, decay, ImageSize, ImageSize, n_ch)
# Model-type label forwarded to log_params_cva().
# NOTE(review): the model built above is conv_classifier_two_with_aspp; confirm
# this label is the one intended for the logs.
mtype = 'conv_classifier_two'

# NOTE(review): the weights file is "..._Onera_0a8d.h5" but the name written
# to the results CSV is "..._CBMI_0a8d"; confirm which dataset these weights
# were trained on so the results are labelled correctly.
cd_model_path = '/home/dvalsamis/Documents/projects/Change_detection_SSL_Siamese/saved_models/CD_Simple_Onera_0a8d.h5'
cd_model_name = 'CD_Simple_CBMI_0a8d'
# Short run id: last '_'-separated token with any '.ext' stripped
# ('CD_Simple_CBMI_0a8d' -> '0a8d').
model_id = cd_model_name.split('_')[-1].split('.')[0]


source_model.load_weights(cd_model_path)


# Two-branch feature extractor used for the CVA maps. two_feature_selector_cva_aspp
# receives the trained source_model and this fresh branch model and returns the
# model used for inference -- presumably copying weights from source_model
# (TODO confirm in utils.layer_select). Keras auto-generates layer names in
# creation order, so do not reorder these constructions without checking.
branch_model = two_branch_cva_with_aspp(dropout, decay, depth, ImageSize, ImageSize, n_ch)
branch_model = two_feature_selector_cva_aspp(depth, source_model, branch_model)
151  #plot_model(branch_model, to_file='/home/dvalsamis/Documents/projects/Change_detection_SSL_Siamese/graphs_cva/'+model_id+'_model_plot.png', show_shapes=True, show_layer_names=True)
152  
153      
154  # Data ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
155  
156  
onera_train_target = '/home/dvalsamis/Documents/data/CBMI/CBMI_0.3/CBMI_0.3/NPY_dataset/aug_train_data/'
onera_test_target = '/home/dvalsamis/Documents/data/CBMI/CBMI_0.3/CBMI_0.3/NPY_dataset/aug_test_data/'

# Test-set index: each row names (relative to onera_test_target) the .npy
# files of the two acquisitions ('pair1', 'pair2') and of the ground-truth
# change mask ('change_mask').
onera_test_df = pd.read_csv(onera_test_target + "dataset_test.csv")

dsize = len(onera_test_df)  # number of test samples
# Split sizes. Only the test split is loaded by this script, so the training
# size is unknown here (these were previously the character length of the
# directory path strings, not sample counts).
trsize = None
tesize = dsize

# Uninitialised buffers, one slot per test sample; X1/X2/nonNorm*/y are fully
# filled by the loop below.
X1 = np.empty((dsize, ImageSize, ImageSize, n_ch))
X2 = np.empty((dsize, ImageSize, ImageSize, n_ch))

nonNorm1 = np.empty((dsize, ImageSize, ImageSize, n_ch))  # band-selected, not normalised
nonNorm2 = np.empty((dsize, ImageSize, ImageSize, n_ch))

y = np.empty((dsize, ImageSize, ImageSize))  # ground-truth change masks
y_pred = np.empty((dsize, ImageSize, ImageSize))  # NOTE(review): not used below
y_pred_otsu = np.empty((dsize, ImageSize, ImageSize))  # filled by the CVA prediction loop
y_pred_triangle = np.empty((dsize, ImageSize, ImageSize))

sample_files = zip(onera_test_df['pair1'], onera_test_df['pair2'], onera_test_df['change_mask'])
for i, (pair1_file, pair2_file, mask_file) in enumerate(sample_files):
    img1 = create_rgb_onera(np.load(onera_test_target + pair1_file), channel)
    img2 = create_rgb_onera(np.load(onera_test_target + pair2_file), channel)

    nonNorm1[i] = img1
    nonNorm2[i] = img2

    # Per-image standardisation (statistics over the whole patch) if enabled.
    X1[i] = normalize(img1) if NORM else img1
    X2[i] = normalize(img2) if NORM else img2
    y[i] = np.load(onera_test_target + mask_file)


branch_model.summary()
200  
201   
202  # Setting Predictions----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
203  
204  
205  # Final two Branch Model
206  
def _binary_change_mask(score_map, threshold_fn, min_size=100):
    """Binarise a 2-D score map at ``threshold_fn(score_map)`` and remove
    connected foreground regions smaller than ``min_size`` pixels."""
    mask = score_map > threshold_fn(score_map)
    return remove_small_objects(mask, min_size=min_size)


for index, (left, right) in enumerate(zip(X1, X2)):
    # Prepend a batch axis of size 1 so each pair is fed as a batch of one.
    img1 = left[np.newaxis, ...]
    img2 = right[np.newaxis, ...]

    # The original author notes the output shape as (1, 96, 96, 32); averaging
    # over the last axis and taking batch element 0 gives one 2-D score map.
    combined_features = branch_model.predict([img1, img2])
    mean_feature_map = np.mean(combined_features, axis=-1)[0]

    y_pred_otsu[index] = _binary_change_mask(mean_feature_map, threshold_otsu)
    y_pred_triangle[index] = _binary_change_mask(mean_feature_map, threshold_triangle)

# Supervised classifier baseline: argmax over axis 3 turns per-pixel scores
# into a label map.
y_pred_conv = np.argmax(source_model.predict([X1, X2]), axis=3)

for metric_label, metric_fn in (('Recall', recall),
                                ('Specificity', specificity),
                                ('Precision', precision),
                                ('F1', f_measure),
                                ('Accuracy', accuracy)):
    print(metric_label, metric_fn(y, y_pred_conv))
237  
238  
239  
240  
241  
242  #----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
243  
# random.randint is inclusive of BOTH ends, so randint(0, len(y)) could return
# len(y) and raise IndexError below; randrange(len(y)) always yields a valid index.
pos = random.randrange(len(y))
print(pos)
print(len(y))
#pos = 157
#pos = 712

fig, ax = plt.subplots(3, 3, figsize=(10,10), constrained_layout=True)

# Every row shows: left image | change map | right image.
middle_panels = (
    (y[pos], 'Ground Truth'),
    (y_pred_otsu[pos], 'CVA(with Otsu)'),
    (y_pred_triangle[pos], 'CVA(with Triangle)'),
)
left_rgb = create_rgb(X1[pos], channel)
right_rgb = create_rgb(X2[pos], channel)
for row, (change_map, title) in enumerate(middle_panels):
    ax[row, 0].imshow(left_rgb)
    ax[row, 0].set_title('Left Image', fontsize=20)
    ax[row, 0].axis('off')

    ax[row, 1].imshow(change_map, cmap='gray')
    ax[row, 1].set_title(title, fontsize=20)
    ax[row, 1].axis('off')

    ax[row, 2].imshow(right_rgb)
    ax[row, 2].set_title('Right Image', fontsize=20)
    ax[row, 2].axis('off')
287  
288  
289  
def _cva_metrics_row(downstream, y_pred_masks, norm_label):
    """Build one results-CSV row for ``y_pred_masks`` scored against ``y``.

    Key order defines the CSV column order (and the header written to a new file).
    """
    return {'Pretext': 'Task 1',
            'Model ID': model_id,
            'Downstream': downstream,
            'Sensitivity/Recall': recall(y, y_pred_masks),
            'Specificity': specificity(y, y_pred_masks),
            'Precision': precision(y, y_pred_masks),
            'F1': f_measure(y, y_pred_masks),
            'Accuracy': accuracy(y, y_pred_masks),
            'Set': 'CBMI Test',
            'ImageSize': ImageSize,
            'Norm': norm_label,
            'Pretext Model': '-',
            'CD model': cd_model_name}


def _append_result_row(path, row):
    """Append ``row`` as one CSV line to ``path``, writing a header line first
    when the file is empty."""
    # pandas expects file handles passed to to_csv to be opened with newline=''.
    with open(path, 'a', newline='') as f:
        if f.tell() == 0:
            pd.DataFrame([list(row.keys())]).to_csv(f, header=False, index=False)
        pd.DataFrame([list(row.values())]).to_csv(f, header=False, index=False)


# The recall value used to be computed and discarded; show it.
print('Recall (Otsu)', recall(y, y_pred_otsu))
# Project helper from utils.my_metrics -- presumably plots/prints a ROC summary (TODO confirm).
get_roc(y, y_pred_otsu)


log_params_cva('CBMI_Test', model_id, cd_model_name, mtype, depth, dropout, decay, ImageSize, n_ch, channel, NORM)

#####Metrics#####
results_path = '/home/dvalsamis/Documents/projects/Change_detection_SSL_Siamese/logs/cva_results_1.csv'

for downstream, y_pred_thresholded in (('CVA+Otsu(min size = 100)', y_pred_otsu),
                                       ('CVA+Triangle(min size = 100)', y_pred_triangle)):
    data_dict = _cva_metrics_row(downstream, y_pred_thresholded, NORM)
    _append_result_row(results_path, data_dict)
    print("Data logged successfully.")


cd_results_df = pd.read_csv(results_path)

################fusion####################
from architectures.fusion_maria import fusion

y_pred_fusion = np.empty((len(onera_test_df), ImageSize, ImageSize))

# Sweep the fusion parameter over 0..12 and log one row per value; the
# 'Norm' column is reused to record the parameter for these rows.
for param in range(0, 13):
    for cm_pos in range(len(y_pred_fusion)):
        y_pred_fusion[cm_pos] = fusion(y_pred_triangle[cm_pos], y_pred_otsu[cm_pos], y_pred_conv[cm_pos], param)

    data_dict = _cva_metrics_row('Fusion', y_pred_fusion, "param=" + str(param))
    _append_result_row(results_path, data_dict)
    print("Data logged successfully.", param)
396      
397  #cd_results_df = pd.read_csv('cd_results.csv')
398  
399  #plot fusion
# Recompute the fused masks with a fixed parameter for visualisation (after the
# sweep above, y_pred_fusion holds the last swept value).
FUSION_PLOT_PARAM = 10
for cm_pos in range(len(y_pred_fusion)):
    fused_mask = fusion(y_pred_triangle[cm_pos], y_pred_otsu[cm_pos], y_pred_conv[cm_pos], FUSION_PLOT_PARAM)
    y_pred_fusion[cm_pos] = fused_mask

# randint is inclusive of both ends and could return len(y) (IndexError);
# randrange(len(y)) always yields a valid index.
pos = random.randrange(len(y))
print(pos)
print(len(y))
#pos = 157
#pos = 712
#pos = 577
fig, ax = plt.subplots(5, 1, figsize=(10,40), constrained_layout=True)

mask_panels = (
    (y[pos], 'Ground Truth'),
    (y_pred_conv[pos], 'Conv classifier'),
    (y_pred_otsu[pos], 'CVA(with Otsu)'),
    (y_pred_triangle[pos], 'CVA(with Triangle)'),
    (y_pred_fusion[pos], 'Fusion'),
)
for panel_ax, (mask, title) in zip(ax, mask_panels):
    panel_ax.imshow(mask, cmap='gray')
    panel_ax.set_title(title, fontsize=30)
    panel_ax.axis('off')

# NOTE(review): figsize (10, 40) for a 1x2 grid looks copied from the 5x1
# figure above; consider a wider-than-tall size.
fig, ax = plt.subplots(1, 2, figsize=(10,40), constrained_layout=True)
for panel_ax, image, title in zip(ax, (X1[pos], X2[pos]), ('Left Image', 'Right Image')):
    panel_ax.imshow(create_rgb(image, 'rgb'))
    panel_ax.set_title(title, fontsize=20)
    panel_ax.axis('off')


print("End of first execution")
448