#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Feature_pred.py

Created on Sun Jun 11 12:02:38 2023

@author: aleoikon

Run a trained Siamese change-detection model (with ASPP) on bi-temporal image
pairs, then save and visualize the intermediate feature maps of selected layers
together with the predicted change mask.
"""
import glob
import os
import sys
import traceback

PROJECT_ROOT = '/home/dvalsamis/Documents/projects/Change_detection_SSL_Siamese'
sys.path.append(PROJECT_ROOT)

# NOTE(review): the original comment said this hides the GPUs for timing, but
# '1' exposes only the GPU with index 1; use '-1' (or '') to hide all GPUs.
# It is set before the framework-backed project imports below so that it is in
# place before any GPU initialization can happen.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

# The imports below intentionally follow the sys.path / environment setup above.
import numpy as np
from architectures.conv_classifier import conv_classifier_two, conv_classifier_two_with_nspp, conv_classifier_two_with_aspp
from architectures.branch import branch_cva, branch_cva_with_nspp, two_branch_cva_with_aspp, two_branch_cva_with_aspp_fmaps
from changedetection.utils.layer_select import feature_selector_cva, feature_selector_cva_with_nspp, two_feature_selector_cva_aspp
from changedetection.utils.fusion_maria import fusion
import matplotlib.pyplot as plt
from changedetection.utils.visualize import create_rgb
from numpy import expand_dims
from skimage.filters import threshold_otsu, threshold_triangle
from skimage.morphology import remove_small_objects
from skimage import io

# Band file suffixes read as the red, green and blue channels, in that order.
_RGB_BANDS = ('B04', 'B03', 'B02')

# Feature-map grid layout: at most 32 channels shown, 8 per row.
_MAX_FILTERS_SHOWN = 32
_GRID_COLS = 8

# Model hyper-parameters used to rebuild the architecture before loading weights.
DEPTH = 2
DROPOUT = 0.1
DECAY = 0.0001

CD_MODEL_PATH = os.path.join(PROJECT_ROOT, 'saved_models', 'CD_Simple_CBMI_c141.h5')
PREDICTIONS_ROOT = '/home/dvalsamis/Documents/data/Onera/Predictions/Depth_2/Trainable_Onera'
FEATURE_MAPS_ROOT = os.path.join(PROJECT_ROOT, 'fmaps', 'fmaps_total_0')
LAYERS_OF_INTEREST = ('aspp_reduced_relu', 'abs_diff_2', 'reduced_aspp_output', 'output')

DATA_ROOT = '/home/dvalsamis/Documents/data/CBMI/CBMI_0.3/CBMI 0.3_initial'
# Alternative (Onera) layout used previously:
#   names file: /home/aleoikon/Documents/data/onera/Onera Satellite Change Detection dataset - Images/all.txt
#   rasters:    <that root>/{dataset_name}/imgs_1/ and <that root>/{dataset_name}/imgs_2/


def calculate_distancemap(f1, f2):
    """Return the per-pixel Euclidean distance between two stacks of feature maps.

    The first axis enumerates feature maps; the result is
    ``sqrt(sum_i (f2[i] - f1[i]) ** 2)``.

    Args:
        f1: array with at least three dimensions, first axis = feature maps.
        f2: array with the same shape as ``f1``.

    Returns:
        Array of shape ``f1.shape[1:]`` holding the distance map.
    """
    f1 = np.asarray(f1)
    f2 = np.asarray(f2)
    # Vectorized form of accumulating the squared per-map differences in a loop.
    return np.sqrt(np.sum((f2 - f1) ** 2, axis=0))


def _read_band(path, band):
    """Read the raster in *path* whose file name ends with ``<band>.tif``.

    Raises:
        FileNotFoundError: if no matching file exists in *path*.
    """
    # Sorted so the chosen file is deterministic when several files match.
    matches = sorted(glob.glob(os.path.join(path, f'*{band}.tif')))
    print(matches)
    if not matches:
        raise FileNotFoundError(f"No '*{band}.tif' raster found in {path!r}")
    if len(matches) > 1:
        print(f"Warning: {len(matches)} '*{band}.tif' rasters in {path!r}; using {matches[0]}")
    return io.imread(matches[0])


def read_rasters(path):
    """Load the B04/B03/B02 rasters in *path* and stack them as an RGB image.

    Args:
        path: folder containing files named ``*B04.tif``, ``*B03.tif`` and
            ``*B02.tif``.

    Returns:
        Float array of the three bands stacked on a new last axis, in the order
        B04, B03, B02 (read as red, green, blue).

    Raises:
        FileNotFoundError: if one of the bands is missing.
    """
    r, g, b = (_read_band(path, band) for band in _RGB_BANDS)
    return np.stack((r, g, b), axis=2).astype('float')


def _to_display(img):
    """Min-max rescale *img* to [0, 1] so imshow does not clip standardized data."""
    img = np.asarray(img, dtype=float)
    lo, hi = img.min(), img.max()
    if hi > lo:
        return (img - lo) / (hi - lo)
    # A constant image carries no contrast; show it as black.
    return np.zeros_like(img)


def _plot_inputs_and_prediction(img1, img2, y_pred, example_index, out_dir):
    """Plot both inputs and the predicted mask side by side and save the figure."""
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    panels = (
        (_to_display(img1[example_index]), 'Input Image 1', None),
        (_to_display(img2[example_index]), 'Input Image 2', None),
        (y_pred[example_index], 'Predicted Mask', 'gray'),
    )
    for ax, (image, title, cmap) in zip(axes, panels):
        ax.imshow(image, cmap=cmap)
        ax.set_title(title)
        ax.axis('off')
    fig.suptitle('Initial Inputs and Predicted Mask')
    fig.savefig(os.path.join(out_dir, 'inputs_and_predictions.png'))
    plt.show()
    # Release the figure; otherwise figures accumulate across datasets.
    plt.close(fig)


def _plot_layer_feature_maps(fmap, layer_name, example_index, layer_dir):
    """Save a grid of up to 32 channels of *fmap*, one PNG and one .npy per channel.

    Assumes *fmap* is 4-D and indexed as ``fmap[batch, :, :, channel]``.
    """
    num_filters = fmap.shape[-1]
    if num_filters == 0:
        print(f"Warning: layer {layer_name!r} produced no channels; skipping")
        return

    shown = min(num_filters, _MAX_FILTERS_SHOWN)
    rows = -(-shown // _GRID_COLS)  # ceiling division
    fig, axes = plt.subplots(rows, _GRID_COLS, figsize=(2 * _GRID_COLS, 2 * rows))
    axes = np.atleast_1d(axes).ravel()

    for i, ax in enumerate(axes):
        ax.axis('off')
        if i >= shown:
            continue  # unused grid cell
        feature_image = fmap[example_index, :, :, i]
        ax.imshow(feature_image, cmap='viridis')
        ax.set_title(f'Filter {i}', fontsize=10)
        # Save this channel as its own image; the full grid is saved once below.
        plt.imsave(os.path.join(layer_dir, f'filter_{i}_example_{example_index}.png'),
                   feature_image, cmap='viridis')

    fig.suptitle(f'Feature Maps from Layer: {layer_name}', fontsize=16)
    # Leave room at the top for the suptitle.
    fig.tight_layout(rect=(0, 0, 1, 0.95))
    fig.savefig(os.path.join(layer_dir, f'feature_maps_grid_example_{example_index}.png'))
    plt.show()
    plt.close(fig)

    # Save every channel (not only the plotted ones) as a raw array.
    for i in range(num_filters):
        np.save(os.path.join(layer_dir, f'filter_{i}_example_{example_index}.npy'),
                fmap[example_index, :, :, i])


def save_and_visualize_feature_maps(feature_maps, feature_model, source_model, layers_of_interest, img1, img2, feature_maps_dir):
    """Predict, plot and save the change mask and per-layer feature maps.

    Args:
        feature_maps: precomputed outputs of ``feature_model.predict([img1, img2])``,
            or ``None`` to compute them here.
        feature_model: model whose outputs are the feature maps, one per entry of
            ``layers_of_interest`` (paired positionally).
        source_model: classifier whose prediction is argmax-ed over axis 3 to give
            the mask.
        layers_of_interest: names used as per-layer output sub-folders.
        img1, img2: batched input images; only batch index 0 is visualized.
        feature_maps_dir: output folder (created if missing).

    Side effects:
        Writes ``inputs_and_predictions.png`` to ``feature_maps_dir``. For each
        layer it writes a grid PNG, a per-channel PNG for up to 32 channels, and
        a per-channel ``.npy`` file for every channel. Figures are also shown
        with ``plt.show()``.
    """
    if feature_maps is None:
        feature_maps = feature_model.predict([img1, img2])
    # predict() of a single-output model returns a bare array, not a list.
    if not isinstance(feature_maps, (list, tuple)):
        feature_maps = [feature_maps]

    y_pred_conv = np.argmax(source_model.predict([img1, img2]), axis=3)

    os.makedirs(feature_maps_dir, exist_ok=True)
    example_index = 0  # only the first element of the batch is visualized

    _plot_inputs_and_prediction(img1, img2, y_pred_conv, example_index, feature_maps_dir)

    if len(feature_maps) != len(layers_of_interest):
        print(f"Warning: {len(feature_maps)} feature maps for {len(layers_of_interest)} "
              f"layer names; unmatched entries are ignored")

    for fmap, layer_name in zip(feature_maps, layers_of_interest):
        layer_dir = os.path.join(feature_maps_dir, layer_name)
        os.makedirs(layer_dir, exist_ok=True)
        _plot_layer_feature_maps(np.asarray(fmap), layer_name, example_index, layer_dir)


def _standardize(img):
    """Add a leading batch axis and z-score *img* over all pixels and bands."""
    batch = expand_dims(img, axis=0)
    std = batch.std()
    # Guard against a constant image, which would otherwise divide by zero.
    return (batch - batch.mean()) / (std if std > 0 else 1.0)


def change_detection(dataset_name, input_rasters):
    """Run the change-detection model on one dataset and export its feature maps.

    Args:
        dataset_name: dataset identifier, used to name the output folders.
        input_rasters: folders for the two time frames (before, after), each
            readable by :func:`read_rasters`.

    Returns:
        None. Results are written to ``FEATURE_MAPS_ROOT/<dataset_name>``.

    Raises:
        ValueError: if fewer than two rasters are given or their shapes differ.
    """
    imgs = []
    for timeframe in input_rasters:
        print(timeframe)
        imgs.append(read_rasters(timeframe))
    if len(imgs) < 2:
        raise ValueError(f"Expected two raster folders (before/after), got {len(imgs)}")
    img1_or, img2_or = imgs[0], imgs[1]
    if img1_or.shape != img2_or.shape:
        raise ValueError(f"Image shapes differ: {img1_or.shape} vs {img2_or.shape}")

    # NOTE(review): this predictions folder is created but nothing is written to it here.
    os.makedirs(os.path.join(PREDICTIONS_ROOT, dataset_name), exist_ok=True)

    image_size_x, image_size_y, n_ch = img1_or.shape

    source_model = conv_classifier_two_with_aspp(DEPTH, DROPOUT, DECAY, image_size_x, image_size_y, n_ch)
    source_model.load_weights(CD_MODEL_PATH)

    # NOTE(review): argument order (dropout, decay, depth) differs from the
    # classifier call above; kept as-is, verify against the function signature.
    branch_model, feature_model = two_branch_cva_with_aspp_fmaps(DROPOUT, DECAY, DEPTH, image_size_x, image_size_y, n_ch)
    # Called for its effect on the branch/feature models (presumably transferring
    # the trained source weights — TODO confirm); its return value is not used here.
    two_feature_selector_cva_aspp(DEPTH, source_model, branch_model)

    img1 = _standardize(img1_or)
    img2 = _standardize(img2_or)

    # One output folder per dataset so successive datasets do not overwrite each other.
    feature_maps_dir = os.path.join(FEATURE_MAPS_ROOT, dataset_name)
    save_and_visualize_feature_maps(None, feature_model, source_model, list(LAYERS_OF_INTEREST),
                                    img1, img2, feature_maps_dir)


def _read_dataset_names(names_file):
    """Return the comma-separated dataset names on the first line of *names_file*."""
    with open(names_file, 'r', encoding='utf-8') as fh:
        first_line = fh.readline().strip()
    # Strip stray spaces and drop empty entries (e.g. from a trailing comma).
    return [name.strip() for name in first_line.split(',') if name.strip()]


def main():
    """Run change detection and feature-map export for every dataset in all.txt."""
    dataset_names = _read_dataset_names(os.path.join(DATA_ROOT, 'all.txt'))
    for dataset_name in dataset_names:
        input_rasters = [
            os.path.join(DATA_ROOT, dataset_name, 'img1_cropped', ''),
            os.path.join(DATA_ROOT, dataset_name, 'img2_cropped', ''),
        ]
        try:
            change_detection(dataset_name, input_rasters)
        except Exception as e:
            # Continue with the next dataset, but keep the full traceback visible.
            print(f"Error occurred in dataset: {dataset_name}")
            print(f"Error message: {str(e)}")
            traceback.print_exc()
    print("Done")


if __name__ == "__main__":
    main()