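# fcsiam_simpple_k_fold_boxplot.py
# Repeated k-fold cross-validation of FCSiamConc on the CBMI set, with a box plot of the per-run metrics.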
"""Repeated 5-fold cross-validation of FCSiamConc on the CBMI change-detection data.

Loads the augmented CBMI train/test pairs. For each of ``n_runs`` runs, trains a
fresh FCSiamConc on every fold of a 5-fold split of the training set, saves each
fold's weights, and records pixel-level validation metrics. The per-run averages
are collected in ``all_avg_metrics`` and passed to ``log_params_sim1``.
"""
import sys
sys.path.append('/home/dvalsamis/Documents/projects/Change_detection_SSL_Siamese')

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
from sklearn.model_selection import KFold
from torch.utils.data import DataLoader, TensorDataset, Subset
import uuid
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score
from FCSiam.fcsiamConc import FCSiamConc
from utils.log_params import log_params_sim1

import os

# Set before any CUDA work below (torch.device / .to(device)) so that only GPU 1 is visible.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"


# Band index that create_rgb_onera reads for each named channel.
_BAND_INDEX = {'blue': 0, 'green': 1, 'red': 2, 'vnir': 3}


def create_rgb_onera(x, channel):
    """Select bands from a 3-D image array indexed as ``x[:, :, band]``.

    Args:
        x: image array; band 0 is read as blue, 1 as green, 2 as red, 3 as VNIR.
        channel: 'red' / 'green' / 'blue' -> that single band with a trailing
            axis of size 1; 'rgb' -> bands (2, 1, 0) stacked along the last axis
            (dtype kept); 'rgbvnir' -> bands (2, 1, 0, 3) stacked and cast to float.

    Returns:
        The selected bands. For any other ``channel`` a warning is printed and
        ``x`` is returned unchanged.
    """
    if channel in ('red', 'green', 'blue'):
        return np.expand_dims(x[:, :, _BAND_INDEX[channel]], axis=2)
    if channel == 'rgb':
        return np.dstack((x[:, :, 2], x[:, :, 1], x[:, :, 0]))
    if channel == 'rgbvnir':
        return np.stack((x[:, :, 2], x[:, :, 1], x[:, :, 0], x[:, :, 3]), axis=2).astype('float')
    print("NOT CORRECT CHANNELS")
    return x


def _load_split(df, root, channel, n_ch):
    """Load every (pair1, pair2, change_mask) row of *df* from directory prefix *root*.

    Each image goes through create_rgb_onera and is then standardised on its own
    (mean/std over all of its pixels and bands).

    Returns:
        (X1, X2, y): X1 and X2 have shape (len(df), 96, 96, n_ch); y has shape
        (len(df), 1, 96, 96).
    """
    x1 = np.ndarray(shape=(len(df), 96, 96, n_ch))
    x2 = np.ndarray(shape=(len(df), 96, 96, n_ch))
    y = np.ndarray(shape=(len(df), 96, 96))
    for pos, index in enumerate(df.index):
        img1 = create_rgb_onera(np.load(root + df['pair1'][index]), channel)
        img2 = create_rgb_onera(np.load(root + df['pair2'][index]), channel)
        x1[pos] = (img1 - img1.mean()) / img1.std()
        x2[pos] = (img2 - img2.mean()) / img2.std()
        y[pos] = np.load(root + df['change_mask'][index])
    return x1, x2, np.expand_dims(y, axis=1)


def _to_dataset(x1, x2, y):
    """Wrap (N, H, W, C) image arrays and (N, 1, H, W) masks as a float32 TensorDataset (images in NCHW)."""
    return TensorDataset(torch.tensor(x1).permute(0, 3, 1, 2).float(),
                         torch.tensor(x2).permute(0, 3, 1, 2).float(),
                         torch.tensor(y).float())


def _safe_div(num, den):
    """Return num / den, or 0.0 when den is 0 (the same convention as sklearn's zero_division=0)."""
    return float(num) / den if den else 0.0


def _evaluate(model, loader, device):
    """Compute pixel-level binary metrics of *model* over every batch of *loader*.

    The model is trained with BCEWithLogitsLoss, so its outputs are logits. A
    pixel is predicted as "change" when sigmoid(logit) > 0.5.

    Returns:
        dict with 'accuracy', 'recall', 'precision', 'f1', 'specificity'. Any
        ratio whose denominator is 0 is reported as 0.0.
    """
    model.eval()
    label_chunks, pred_chunks = [], []
    with torch.no_grad():
        for inputs1, inputs2, labels in loader:
            outputs = model(torch.stack((inputs1.to(device), inputs2.to(device)), dim=1))
            predicted = torch.sigmoid(outputs) > 0.5
            # Keep whole flattened arrays per batch; extending a list pixel-by-pixel is very slow.
            label_chunks.append(labels.numpy().ravel())
            pred_chunks.append(predicted.cpu().numpy().ravel())

    # Assumes change masks hold 0/1 values -- TODO confirm against the dataset.
    all_labels = np.concatenate(label_chunks).astype(np.int64)
    all_predictions = np.concatenate(pred_chunks).astype(np.int64)

    # labels=[0, 1] guarantees a 2x2 matrix even if a fold contains only one class.
    tn, fp, fn, tp = confusion_matrix(all_labels, all_predictions, labels=[0, 1]).ravel()
    return {
        'accuracy': _safe_div(tp + tn, tp + tn + fp + fn),
        'recall': _safe_div(tp, tp + fn),
        'precision': _safe_div(tp, tp + fp),
        'f1': _safe_div(2 * tp, 2 * tp + fp + fn),
        'specificity': _safe_div(tn, tn + fp),
    }


# Data Loading and Preparation
onera_train_target = '/data/valsamis_data/data/CBMI/CBMI_0.3/CBMI_0.3/NPY_dataset/aug_train_data/'
onera_test_target = '/data/valsamis_data/data/CBMI/CBMI_0.3/CBMI_0.3/NPY_dataset/aug_test_data/'

train = pd.read_csv(onera_train_target + "dataset_train.csv")
test = pd.read_csv(onera_test_target + "dataset_test.csv")

train = train.sample(frac=1, random_state=1)
test = test.sample(frac=1, random_state=1)
print("Train Data", len(train))
print("Test Data", len(test))

n_ch = 3
channel = 'rgb'  # Set the channel according to your requirement

X_train1, X_train2, y_train = _load_split(train, onera_train_target, channel, n_ch)
X_test1, X_test2, y_test = _load_split(test, onera_test_target, channel, n_ch)

# Create datasets
train_data = _to_dataset(X_train1, X_train2, y_train)
# NOTE(review): test_data is built but never used below; all metrics come from the k-fold validation folds.
test_data = _to_dataset(X_test1, X_test2, y_test)

# NOTE(review): KFold defaults to shuffle=False, so every run below uses identical folds.
# Run-to-run spread therefore reflects only weight initialisation and mini-batch order.
kf = KFold(n_splits=5)

# Training and evaluation loop
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_epochs = 30
n_runs = 64
model_path = '/home/dvalsamis/Documents/projects/Change_detection_SSL_Siamese/saved_models/'
os.makedirs(model_path, exist_ok=True)

fold_results = []  # one list of per-fold metric dicts per run

for run in range(n_runs):
    print(f"RUN {run+1}")
    run_results = []
    for fold, (train_idx, val_idx) in enumerate(kf.split(np.arange(len(train_data)))):
        print(f"FOLD {fold+1}")
        print("--------------------------------")

        # Subset selects the fold's samples; shuffling is done by the training DataLoader.
        train_loader = DataLoader(Subset(train_data, train_idx), batch_size=16, shuffle=True)
        val_loader = DataLoader(Subset(train_data, val_idx), batch_size=16, shuffle=False)

        model = FCSiamConc().to(device)
        criterion = nn.BCEWithLogitsLoss()
        optimizer = optim.Adam(model.parameters(), lr=1e-4)

        model_id = uuid.uuid4().hex[:4]
        cd_model_name = f"FCSiamConc_CBMI_{model_id}_fold{fold+1}.pth"

        for epoch in range(num_epochs):
            model.train()
            running_loss = 0.0
            for inputs1, inputs2, labels in train_loader:
                inputs1, inputs2, labels = inputs1.to(device), inputs2.to(device), labels.to(device)
                optimizer.zero_grad()
                # The two dates are stacked on dim 1: (batch, 2, C, H, W).
                outputs = model(torch.stack((inputs1, inputs2), dim=1))
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                running_loss += loss.item()

            print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}")

        # Save the model after training
        save_path = os.path.join(model_path, cd_model_name)
        torch.save(model.state_dict(), save_path)
        print(f"Model saved to {save_path}")

        # Evaluate the model on the validation fold
        metrics = _evaluate(model, val_loader, device)
        run_results.append({'fold': fold + 1, **metrics})

    fold_results.append(run_results)

# Average each metric over the folds of a run -> one value per run and metric.
all_avg_metrics = {name: [] for name in ('accuracy', 'recall', 'precision', 'f1', 'specificity')}
for run_results in fold_results:
    for name, values in all_avg_metrics.items():
        values.append(np.mean([result[name] for result in run_results]))

# NOTE(review): "Softmax" and 'weighted_categorical_crossentropy' do not match the sigmoid /
# BCEWithLogitsLoss setup above. Check which log_params_sim1 fields these occupy and correct them.
log_params_sim1("Boxplot", " ", ' ', " ", " ", "Softmax", " ", 'Adam', num_epochs, 'weighted_categorical_crossentropy', " ",
                np.mean(all_avg_metrics['recall']), np.mean(all_avg_metrics['specificity']), np.mean(all_avg_metrics['precision']), np.mean(all_avg_metrics['f1']), np.mean(all_avg_metrics['accuracy']),
                "CBMI Set", 96, " ", "none", "k-fold cross-validation")
# Report the across-run averages (one value per run and metric, collected above).
print("Average Metrics:")
print({k: np.mean(v) for k, v in all_avg_metrics.items()})

# Plot the box plot: one box per metric; each box summarises the per-run means.
plt.figure(figsize=(10, 7))
# NOTE: Matplotlib >= 3.9 renames `labels` to `tick_labels`; `labels` is kept here for older versions.
plt.boxplot([all_avg_metrics['accuracy'], all_avg_metrics['recall'], all_avg_metrics['precision'], all_avg_metrics['f1'], all_avg_metrics['specificity']],
            labels=['Accuracy', 'Recall', 'Precision', 'F1 Score', 'Specificity'])
plt.title('Box Plot of Metrics over 64 Runs')
plt.ylabel('Score')

# Save BEFORE plt.show(). Once show() returns, pyplot closes the figure, so a
# later savefig would write a new, empty figure instead of this plot.
output_path = '/home/dvalsamis/Documents/projects/Change_detection_SSL_Siamese/boxplots/Conc_boxplot_metrics_64_runs.png'
os.makedirs(os.path.dirname(output_path), exist_ok=True)
plt.savefig(output_path)
print(f"Box plot saved to {output_path}")
plt.show()