#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
simple_cd_k_fold.py

Created on Wed Mar  9 20:15:48 2022

@author: aleoikon

K-fold cross-validation of a two-input change-detection classifier.

The script loads image pairs and change masks from ``.npy`` files listed in
two CSV files, pools the train and test samples, and runs a 5-fold split.
For every fold it builds a fresh ``conv_classifier_two_with_aspp`` model,
passes it together with a pretext model (weights loaded from disk) through
``feature_selector_simple``, trains it with a weighted categorical
cross-entropy loss, evaluates it on the held-out fold and logs the metrics.
The weights of the model trained on the last fold are saved at the end.
"""

import sys
sys.path.append('/home/dvalsamis/Documents/projects/Change_detection_SSL_Siamese')
import time

from architectures.similarity_detection import pretext_task_one_nopool, pretext_task_one_aspp
import tensorflow
from tensorflow.keras import layers, Model
from architectures.branch import branches_triplet
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import plot_model

import pandas as pd
import numpy as np
import os
from tensorflow import keras
from architectures.conv_classifier import conv_classifier_two, conv_classifier_two_with_nspp, conv_classifier_two_with_aspp
from utils.layer_select import feature_selector, feature_selector_simple, transfer_learning_model
from utils.my_metrics import recall, accuracy, specificity, precision, f_measure, get_confusion_matrix
from utils.log_params import log_params_sim1

import matplotlib
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
import uuid
import random
from utils.weighted_cross_entropy import weighted_categorical_crossentropy
from sklearn.model_selection import KFold

os.environ["CUDA_VISIBLE_DEVICES"] = "1"

channel = 'rgb'

# Set random seed for TensorFlow
tensorflow.random.set_seed(1234)

# Set random seed for NumPy
np.random.seed(1234)

# Spatial size (height and width) of every patch fed to the models.
PATCH_SIZE = 96


def generate_short_id():
    """Return the first four hex characters of a freshly generated UUID4."""
    return str(uuid.uuid4().hex)[:4]


def create_rgb_onera(x, channel):
    """Select and reorder bands of ``x`` according to ``channel``.

    ``x`` is indexed as ``x[:, :, k]``, so bands live on axis 2.  Band 2 is
    used as red, band 1 as green, band 0 as blue and band 3 as vnir.

    Parameters
    ----------
    x : array indexable as ``x[:, :, k]``
    channel : str
        One of ``'red'``, ``'green'``, ``'blue'``, ``'rgb'``, ``'rgbvnir'``.

    Returns
    -------
    The selected band(s) stacked on axis 2.  For any other ``channel`` value
    a warning is printed and ``x`` is returned unchanged.
    """
    if channel == 'red':
        return np.expand_dims(x[:, :, 2], axis=2)
    if channel == 'green':
        return np.expand_dims(x[:, :, 1], axis=2)
    if channel == 'blue':
        return np.expand_dims(x[:, :, 0], axis=2)
    if channel == 'rgb':
        return np.dstack((x[:, :, 2], x[:, :, 1], x[:, :, 0]))
    if channel == 'rgbvnir':
        bands = (x[:, :, 2], x[:, :, 1], x[:, :, 0], x[:, :, 3])
        return np.stack(bands, axis=2).astype('float')
    # Unknown selector: warn first (the message used to sit after the return
    # and was never printed), then fall back to the untouched input.
    print("NOT CORRECT CHANNELS")
    return x


def _standardize(img):
    """Return ``img`` shifted to zero mean and scaled to unit standard deviation."""
    centered = img - img.mean()
    spread = img.std()
    # A constant image has zero std; dividing would turn it into NaNs that
    # then propagate through training, so only centre it in that case.
    return centered / spread if spread > 0 else centered


def _load_pairs(table, root, n_channels, channel_mode, normalize):
    """Load every image pair and change mask listed in ``table`` into memory.

    ``table`` must provide the columns ``pair1``, ``pair2`` and
    ``change_mask`` holding file names relative to ``root``.  Rows are read
    in the table's current order.

    Returns three arrays: first images, second images and masks.
    """
    n_samples = len(table)
    first = np.empty((n_samples, PATCH_SIZE, PATCH_SIZE, n_channels))
    second = np.empty((n_samples, PATCH_SIZE, PATCH_SIZE, n_channels))
    masks = np.empty((n_samples, PATCH_SIZE, PATCH_SIZE))

    rows = zip(table['pair1'], table['pair2'], table['change_mask'])
    for pos, (name1, name2, mask_name) in enumerate(rows):
        img1 = create_rgb_onera(np.load(root + name1), channel_mode)
        img2 = create_rgb_onera(np.load(root + name2), channel_mode)
        if normalize:
            img1 = _standardize(img1)
            img2 = _standardize(img2)
        first[pos] = img1
        second[pos] = img2
        masks[pos] = np.load(root + mask_name)

    return first, second, masks


# Data Loading and Preparation

onera_train_target = '/data/valsamis_data/data/CBMI/CBMI_0.3/CBMI_0.3/NPY_dataset/aug_train_data/'
onera_test_target = '/data/valsamis_data/data/CBMI/CBMI_0.3/CBMI_0.3/NPY_dataset/aug_test_data/'

train = pd.read_csv(onera_train_target + "dataset_train.csv")
test = pd.read_csv(onera_test_target + "dataset_test.csv")

train = train.sample(frac=1, random_state=1)
test = test.sample(frac=1, random_state=1)
print("Train Data", len(train))
print("Test Data", len(test))

NORM = True
n_ch = 3

# Load everything in memory
X_train1, X_train2, y_train = _load_pairs(train, onera_train_target, n_ch, channel, NORM)

# Class balance of the training masks: count of the lowest label value
# divided by the count of the next one (0s over 1s for a binary mask).
_, label_counts = np.unique(y_train, return_counts=True)
if len(label_counts) > 1:
    print(label_counts[0] / label_counts[1])
else:
    print("Only one label value present in the training masks")

X_test1, X_test2, y_test = _load_pairs(test, onera_test_target, n_ch, channel, NORM)

# Combine X and y for k-fold
# NOTE(review): the training directory is named "aug_train_data"; if it holds
# augmented copies of the same scenes, a random KFold over the pooled samples
# can put variants of one scene in both training and validation - confirm.
X1 = np.concatenate((X_train1, X_test1), axis=0)
X2 = np.concatenate((X_train2, X_test2), axis=0)
y = np.concatenate((y_train, y_test), axis=0)

# One-hot encode the combined labels
y_hot = keras.utils.to_categorical(y, num_classes=2)

# Initialize KFold
kf = KFold(n_splits=5, shuffle=True, random_state=1)

depth = 2
dropout = 0.1
decay = 0.0001
LEARNING_RATE = 0.001
EPOCHS = 55
BATCH_SIZE = 5

# Settings that do not change between folds.
pretext_model_name = '/home/dvalsamis/Documents/projects/Change_detection_SSL_Siamese/saved_models/model_pretext1_unclouded_results.h5'
pretext_model = 'model_pretext1_unclouded_results'

wx = 0.1
wy = 0.2
weights = np.array([wx, wy])
weight_par = '[' + str(wx) + ',' + str(wy) + ']'

results = []

for fold_no, (train_index, val_index) in enumerate(kf.split(X1), start=1):
    print(f'Training fold {fold_no}...')

    # Split the data
    X_train1_fold, X_val1_fold = X1[train_index], X1[val_index]
    X_train2_fold, X_val2_fold = X2[train_index], X2[val_index]
    y_train_fold, y_val_fold = y_hot[train_index], y_hot[val_index]

    # Create a new instance of the model
    cd_model = conv_classifier_two_with_aspp(depth, dropout, decay, PATCH_SIZE, PATCH_SIZE, n_ch)

    # Load pretext model weights
    sim_model = pretext_task_one_nopool(dropout, decay, PATCH_SIZE, PATCH_SIZE, n_ch)
    sim_model.load_weights(pretext_model_name)

    # Feature selection(task1)
    cd_model = feature_selector_simple(depth, sim_model, cd_model)

    optimizer = Adam(learning_rate=LEARNING_RATE)
    cd_model.compile(
        optimizer=optimizer,
        loss=weighted_categorical_crossentropy(weights),
        metrics=['accuracy'],
    )

    # Train the model
    history = cd_model.fit(
        [X_train1_fold, X_train2_fold],
        y_train_fold,
        validation_data=([X_val1_fold, X_val2_fold], y_val_fold),
        batch_size=BATCH_SIZE,
        epochs=EPOCHS
    )

    # Evaluate the model: collapse the class axis to per-pixel labels
    predictions = cd_model.predict([X_val1_fold, X_val2_fold])
    y_pred = np.argmax(predictions, axis=3)
    y_true = np.argmax(y_val_fold, axis=3)

    acc = accuracy(y_true, y_pred)
    spec = specificity(y_true, y_pred)
    rec = recall(y_true, y_pred)
    prec = precision(y_true, y_pred)
    f_m = f_measure(y_true, y_pred)

    results.append({
        'fold': fold_no,
        'accuracy': acc,
        'specificity': spec,
        'recall': rec,
        'precision': prec,
        'f1_score': f_m
    })

    # One log entry per fold.
    log_params_sim1("Task 1 (T)", "Linear", 'ASPP', weight_par, depth, "Softmax", LEARNING_RATE, 'Adam', EPOCHS, 'weighted_categorical_crossentropy', BATCH_SIZE, rec, spec, prec, f_m, acc, "CBMI Set", PATCH_SIZE, NORM, pretext_model, "CD_Simple_Sysu_d8d2")

# Print results
for result in results:
    print(f"Fold {result['fold']}:")
    print(f" Accuracy: {result['accuracy']:.4f}")
    print(f" Specificity: {result['specificity']:.4f}")
    print(f" Recall: {result['recall']:.4f}")
    print(f" Precision: {result['precision']:.4f}")
    print(f" F1 Score: {result['f1_score']:.4f}")

# Save the weights of the model trained on the last fold; it is the only
# model instance still in scope once the loop has finished.
model_id = generate_short_id()
cd_model_name = "CD_Simple_" + "CBMI_" + model_id + ".h5"
model_path = '/home/dvalsamis/Documents/projects/Change_detection_SSL_Siamese/saved_models/'

cd_model.save_weights(model_path + cd_model_name)
# The last fold's metrics were already logged inside the loop, so they are
# not logged a second time here.

print("Saved model to disk")

# Cross-validation summary: mean of each metric over all folds.
print("Mean over folds:")
for key, label in (
    ('accuracy', 'Accuracy'),
    ('specificity', 'Specificity'),
    ('recall', 'Recall'),
    ('precision', 'Precision'),
    ('f1_score', 'F1 Score'),
):
    fold_mean = np.mean([float(result[key]) for result in results])
    print(f" {label}: {fold_mean:.4f}")

print("Done")