/ training / fmaps / Feature_pred.py
Feature_pred.py
  1  #!/usr/bin/env python3
  2  # -*- coding: utf-8 -*-
  3  """
  4  Created on Sun Jun 11 12:02:38 2023
  5  
  6  @author: aleoikon
  7  """
  8  import sys
  9  sys.path.append('/home/dvalsamis/Documents/projects/Change_detection_SSL_Siamese')
 10  
 11  import numpy as np
 12  from architectures.conv_classifier import conv_classifier_two, conv_classifier_two_with_nspp,conv_classifier_two_with_aspp
 13  from architectures.branch import branch_cva, branch_cva_with_nspp, two_branch_cva_with_aspp,two_branch_cva_with_aspp_fmaps
 14  from changedetection.utils.layer_select import feature_selector_cva, feature_selector_cva_with_nspp, two_feature_selector_cva_aspp
 15  from changedetection.utils.fusion_maria import fusion
 16  import matplotlib.pyplot as plt
 17  from changedetection.utils.visualize import create_rgb
 18  from numpy import expand_dims
 19  from skimage.filters import threshold_otsu, threshold_triangle
 20  from skimage.morphology import remove_small_objects
 21  import glob
 22  import os
 23  from skimage import io
 24  
 25  #hide the gpus for timing the application
 26  os.environ['CUDA_VISIBLE_DEVICES'] = '1'
 27  
 28  def calculate_distancemap(f1, f2):
 29  
 30      dist_per_fmap= [(f2[i,:,:]-f1[i,:,:])**2 for i in range(f1.shape[0])]
 31      
 32      return np.sqrt(sum(dist_per_fmap))
 33  
 34  def read_rasters(path):
 35      im_names = glob.glob(os.path.join(path,'*B04.tif'))  # search for files with names containing 'B04.tif'
 36      print(im_names)
 37      r = io.imread(im_names[0])
 38      print(im_names)
 39      im_names = glob.glob(os.path.join(path,'*B03.tif'))  # search for files with names containing 'B03.tif'
 40      g = io.imread(im_names[0])
 41      print(im_names)
 42      im_names = glob.glob(os.path.join(path,'*B02.tif'))  # search for files with names containing 'B02.tif'
 43      b = io.imread(im_names[0])
 44      print(im_names)
 45      I = np.stack((r,g,b),axis=2).astype('float')
 46      return I
 47  
 48  def save_and_visualize_feature_maps(feature_maps, feature_model, source_model, layers_of_interest, img1, img2, feature_maps_dir):
 49      # Predict feature maps and classification/prediction masks
 50      feature_maps = feature_model.predict([img1, img2])
 51      y_pred_conv = source_model.predict([img1, img2])
 52      y_pred_conv = np.argmax(y_pred_conv, axis=3)  # Assuming a classification task
 53  
 54      # Directory handling
 55      if not os.path.exists(feature_maps_dir):
 56          os.makedirs(feature_maps_dir)
 57  
 58      # Setting up the plot for initial inputs and prediction results
 59      fig, axes = plt.subplots(1, 4, figsize=(20, 5))  # Added more plots to visualize all necessary components
 60  
 61      # Showing the first image in the batch
 62      example_index = 0  # assuming batch size or single image
 63  
 64      # Displaying the original images
 65      axes[0].imshow(img1[example_index])  # Assuming img1 is properly preprocessed if necessary
 66      axes[0].set_title('Input Image 1')
 67      axes[0].axis('off')
 68  
 69      axes[1].imshow(img2[example_index])  # Assuming img2 is properly preprocessed if necessary
 70      axes[1].set_title('Input Image 2')
 71      axes[1].axis('off')
 72  
 73      # Displaying the predicted mask
 74      axes[2].imshow(y_pred_conv[example_index], cmap='gray')
 75      axes[2].set_title('Predicted Mask')
 76      axes[2].axis('off')
 77  
 78      # Saving the initial inputs visualization
 79      plt.suptitle('Initial Inputs and Predicted Mask')
 80      plt.savefig(os.path.join(feature_maps_dir, 'inputs_and_predictions.png'))
 81      plt.show()
 82  
 83      # Visualize and save feature maps
 84      for fmap, layer_name in zip(feature_maps, layers_of_interest):
 85          layer_dir = os.path.join(feature_maps_dir, layer_name)
 86          if not os.path.exists(layer_dir):
 87              os.makedirs(layer_dir)
 88  
 89          num_filters = fmap.shape[-1]
 90          cols = 8  # Define columns
 91          rows = min(4, (num_filters + cols - 1) // cols)  # Define rows
 92          fig, axes = plt.subplots(rows, cols, figsize=(2 * cols, 2 * rows))
 93          axes = axes.flatten()
 94  
 95          # Plot each feature map
 96          for i in range(min(num_filters, 32)):
 97              if i < len(axes):
 98                  ax = axes[i]
 99                  feature_image = fmap[example_index, :, :, i]
100                  ax.imshow(feature_image, cmap='viridis')
101                  ax.set_title(f'Filter {i}', fontsize=10)
102                  ax.axis('off')
103                  plt.savefig(os.path.join(layer_dir, f'filter_{i}_example_{example_index}.png'))
104  
105          # Turn off unused axes
106          for j in range(i + 1, len(axes)):
107              axes[j].axis('off')
108  
109          plt.tight_layout()
110          plt.suptitle(f'Feature Maps from Layer: {layer_name}', fontsize=16)
111          plt.show()
112  
113          # Save each feature map as an NPY file
114          for i in range(num_filters):
115              npy_filename = os.path.join(layer_dir, f'filter_{i}_example_{example_index}.npy')
116              np.save(npy_filename, fmap[example_index, :, :, i])
117  
118  
def change_detection(dataset_name, input_rasters):
    """Build the ASPP change-detection model for one dataset and dump its feature maps.

    Parameters
    ----------
    dataset_name : str
        Dataset identifier, used to name the per-dataset output directories.
    input_rasters : sequence of str
        Directories of the two acquisitions, each read with ``read_rasters``;
        only the first two entries are used.

    Raises
    ------
    ValueError
        If fewer than two rasters are given or the two images differ in shape.
    """
    if len(input_rasters) < 2:
        raise ValueError(f'Expected two input rasters, got {len(input_rasters)}')

    # Load images
    imgs = []
    for timeframe in input_rasters[:2]:
        print(timeframe)
        imgs.append(read_rasters(timeframe))
    img1_or, img2_or = imgs
    if img1_or.shape != img2_or.shape:
        raise ValueError(f'Image shapes differ: {img1_or.shape} vs {img2_or.shape}')

    # NOTE(review): nothing is written to this directory by this function; it
    # is only created, as in the original script.
    predictions_dir = f'/home/dvalsamis/Documents/data/Onera/Predictions/Depth_2/Trainable_Onera/{dataset_name}'
    os.makedirs(predictions_dir, exist_ok=True)

    depth = 2
    dropout = 0.1
    decay = 0.0001
    size_x, size_y, n_ch = img1_or.shape

    # Load the trained change-detection classifier.
    # NOTE(review): this factory takes (depth, dropout, decay, ...) while the
    # branch factory below takes (dropout, decay, depth, ...) — confirm both
    # argument orders against architectures/.
    source_model = conv_classifier_two_with_aspp(depth, dropout, decay, size_x, size_y, n_ch)
    cd_model_path = '/home/dvalsamis/Documents/projects/Change_detection_SSL_Siamese/saved_models/CD_Simple_CBMI_c141.h5'
    source_model.load_weights(cd_model_path)

    branch_model, feature_model = two_branch_cva_with_aspp_fmaps(dropout, decay, depth, size_x, size_y, n_ch)
    # Presumably transfers the trained source weights into the branch model
    # that feature_model is built from; the call is kept for that effect even
    # though its return value is unused — TODO confirm in layer_select.
    two_feature_selector_cva_aspp(depth, source_model, branch_model)

    def standardize(image):
        # Add a batch axis and z-score with the global mean/std; a constant
        # image would otherwise divide by zero.
        batch = expand_dims(image, axis=0)
        std = batch.std()
        return (batch - batch.mean()) / (std if std > 0 else 1.0)

    img1, img2 = standardize(img1_or), standardize(img2_or)

    layers_of_interest = ['aspp_reduced_relu', 'abs_diff_2', 'reduced_aspp_output', 'output']
    # One sub-directory per dataset so successive datasets do not overwrite
    # each other's feature maps.
    feature_maps_dir = os.path.join(
        '/home/dvalsamis/Documents/projects/Change_detection_SSL_Siamese/fmaps/fmaps_total_0',
        dataset_name)

    save_and_visualize_feature_maps(None, feature_model, source_model, layers_of_interest, img1, img2, feature_maps_dir)
176  
177  
# Read the dataset names from the all.txt file (one line of comma-separated names).
dataset_names_file = '/home/dvalsamis/Documents/data/CBMI/CBMI_0.3/CBMI 0.3_initial/all.txt'
# Onera alternative:
# dataset_names_file = '/home/aleoikon/Documents/data/onera/Onera Satellite Change Detection dataset - Images/all.txt'

with open(dataset_names_file, 'r', encoding='utf-8') as file:
    dataset_names_line = file.readline().strip()
# Strip surrounding spaces and drop empty entries (e.g. from a trailing comma),
# which would otherwise produce bogus dataset paths.
dataset_names = [name.strip() for name in dataset_names_line.split(',') if name.strip()]

dataset_root = '/home/dvalsamis/Documents/data/CBMI/CBMI_0.3/CBMI 0.3_initial'

for dataset_name in dataset_names:
    # Input rasters for the current dataset (trailing '' keeps the final slash).
    input_rasters = [
        os.path.join(dataset_root, dataset_name, 'img1_cropped', ''),
        os.path.join(dataset_root, dataset_name, 'img2_cropped', ''),
    ]
    # Onera alternative:
    # input_rasters = [
    #     f'/home/aleoikon/Documents/data/onera/Onera Satellite Change Detection dataset - Images//{dataset_name}/imgs_1/',
    #     f'/home/aleoikon/Documents/data/onera/Onera Satellite Change Detection dataset - Images//{dataset_name}/imgs_2/'
    # ]
    try:
        change_detection(dataset_name, input_rasters)
    except Exception as e:
        # Top-level boundary: report the failure and continue with the next dataset.
        print(f"Error occurred in dataset: {dataset_name}")
        print(f"Error message: {e}")

print("Done")