# SiamConc_box.py
"""Repeatedly train FCSiamConc on the CBMI change-detection data and summarise the runs.

Each run trains a fresh model, saves its weights, evaluates it on the test split
and logs recall/specificity/precision/F1/accuracy; a box plot over all runs is
saved at the end.
"""
import os
import sys
import uuid

# Expose only GPU 1 to this process. Set before torch is imported so the
# setting is in place before any CUDA initialisation can happen.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# Make the project packages (FCSiam, utils) importable.
sys.path.append('/home/dvalsamis/Documents/projects/Change_detection_SSL_Siamese')

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import confusion_matrix, f1_score, precision_score, recall_score
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset

from FCSiam.fcsiamConc import FCSiamConc
from utils.log_params import log_params_sim1

# Band indices (last axis of the stored image) selected for each `channel`
# option; per this mapping index 0 is blue, 1 green, 2 red and 3 VNIR.
_ONERA_BANDS = {
    'red': [2],
    'green': [1],
    'blue': [0],
    'rgb': [2, 1, 0],
    'rgbvnir': [2, 1, 0, 3],
}


def create_rgb_onera(x, channel):
    """Select and reorder the bands of an image indexed as ``x[:, :, band]``.

    Args:
        x: image array whose last axis holds the bands.
        channel: one of 'red', 'green', 'blue' (single band, kept as a
            trailing axis of size 1), 'rgb' (bands 2, 1, 0) or 'rgbvnir'
            (bands 2, 1, 0, 3, cast to float).

    Returns:
        The selected bands stacked along the last axis. For an unknown
        `channel` a message is printed and `x` is returned unchanged.
    """
    bands = _ONERA_BANDS.get(channel)
    if bands is None:
        print("NOT CORRECT CHANNELS")
        return x
    selected = x[:, :, bands]
    if channel == 'rgbvnir':
        # Only the 4-band selection is promoted to float.
        selected = selected.astype('float')
    return selected


# Data Loading and Preparation
onera_train_target = '/data/valsamis_data/data/CBMI/CBMI_0.3/CBMI_0.3/NPY_dataset/aug_train_data/'
onera_test_target = '/data/valsamis_data/data/CBMI/CBMI_0.3/CBMI_0.3/NPY_dataset/aug_test_data/'

train = pd.read_csv(os.path.join(onera_train_target, "dataset_train.csv"))
test = pd.read_csv(os.path.join(onera_test_target, "dataset_test.csv"))

# Shuffle the rows reproducibly (fixed seed).
train = train.sample(frac=1, random_state=1)
test = test.sample(frac=1, random_state=1)
print("Train Data", len(train))
print("Test Data", len(test))

channel = 'rgb'  # any key of _ONERA_BANDS
# Derived from the band selection so the channel count cannot drift from `channel`.
n_ch = len(_ONERA_BANDS[channel])
# Load training and test data


def _zscore(img):
    """Standardise `img` to zero mean and unit variance over all pixels and bands.

    A constant image (std == 0) is only mean-centred, instead of producing
    NaNs that would propagate into the loss and the model weights.
    """
    std = img.std()
    return (img - img.mean()) / (std if std > 0 else 1.0)


def _load_split(df, root, size=96):
    """Load every image pair and change mask listed in `df` (paths relative to `root`).

    Returns:
        x1, x2: float32 arrays of shape (len(df), size, size, n_ch), each image z-scored.
        y: float32 array of shape (len(df), 1, size, size) with the change masks.
    """
    n = len(df)
    x1 = np.empty((n, size, size, n_ch), dtype=np.float32)
    x2 = np.empty((n, size, size, n_ch), dtype=np.float32)
    y = np.empty((n, size, size), dtype=np.float32)
    rows = zip(df['pair1'], df['pair2'], df['change_mask'])
    for pos, (pair1, pair2, mask) in enumerate(rows):
        img1 = np.load(os.path.join(root, pair1))
        img2 = np.load(os.path.join(root, pair2))
        x1[pos] = _zscore(create_rgb_onera(img1, channel))
        x2[pos] = _zscore(create_rgb_onera(img2, channel))
        y[pos] = np.load(os.path.join(root, mask))
    # Add the singleton channel axis expected by the loss: (N, 1, H, W).
    return x1, x2, np.expand_dims(y, axis=1)


X_train1, X_train2, y_train = _load_split(train, onera_train_target)
X_test1, X_test2, y_test = _load_split(test, onera_test_target)


def _to_dataset(x1, x2, y):
    """Wrap NHWC image arrays (converted to NCHW) and their masks in a TensorDataset."""
    return TensorDataset(
        torch.from_numpy(x1).permute(0, 3, 1, 2).contiguous().float(),
        torch.from_numpy(x2).permute(0, 3, 1, 2).contiguous().float(),
        torch.from_numpy(y).float(),
    )


# Create DataLoaders
train_data = _to_dataset(X_train1, X_train2, y_train)
test_data = _to_dataset(X_test1, X_test2, y_test)

train_loader = DataLoader(train_data, batch_size=16, shuffle=True)
test_loader = DataLoader(test_data, batch_size=16, shuffle=False)

log_dir = '/home/dvalsamis/Documents/projects/Change_detection_SSL_Siamese/events'
# Model Training and Evaluation
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# BCEWithLogitsLoss applies the sigmoid itself, so the model's outputs are logits.
criterion = nn.BCEWithLogitsLoss()

num_epochs = 30
num_runs = 64
threshold = 0.5  # probability above which a pixel is predicted as "changed"
model_path = '/home/dvalsamis/Documents/projects/Change_detection_SSL_Siamese/saved_models/'
os.makedirs(model_path, exist_ok=True)


def _train_one_model(model, optimizer, loader, run):
    """Train `model` in place for `num_epochs` epochs, printing the mean batch loss per epoch."""
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for inputs1, inputs2, labels in loader:
            inputs1, inputs2, labels = inputs1.to(device), inputs2.to(device), labels.to(device)
            optimizer.zero_grad()
            # Both acquisitions are stacked on a new dim 1 before being fed to the model.
            outputs = model(torch.stack((inputs1, inputs2), dim=1))
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        mean_loss = running_loss / max(len(loader), 1)
        print(f"Run {run + 1}, epoch {epoch + 1}/{num_epochs}: train loss {mean_loss:.4f}")


def _predict(model, loader):
    """Return flattened ground-truth and predicted masks (uint8, 0/1) for every pixel in `loader`."""
    model.eval()
    label_chunks, pred_chunks = [], []
    with torch.no_grad():
        for inputs1, inputs2, labels in loader:
            logits = model(torch.stack((inputs1.to(device), inputs2.to(device)), dim=1))
            # Threshold the probability, not the raw logit: comparing logits with
            # 0.5 would mean sigmoid(output) > ~0.62 and under-predict change.
            predicted = torch.sigmoid(logits) > threshold
            # Masks are assumed to be binary {0, 1} (they are the BCE targets).
            label_chunks.append(labels.flatten().to(torch.uint8).numpy())
            pred_chunks.append(predicted.flatten().to(torch.uint8).cpu().numpy())
    # Concatenate numpy chunks once instead of extending a Python list pixel by
    # pixel, which is slow and stores every pixel as a separate Python object.
    return np.concatenate(label_chunks), np.concatenate(pred_chunks)


def _compute_metrics(y_true, y_pred):
    """Return (recall, specificity, precision, f1, accuracy), with class 1 = change."""
    # labels=[0, 1] keeps the matrix 2x2 even when a class is absent from both
    # arrays; otherwise the four-way unpacking below would raise.
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    accuracy = (tp + tn) / (tp + tn + fp + fn)
    # zero_division=0: undefined scores become 0 without UndefinedMetricWarning.
    recall = recall_score(y_true, y_pred, zero_division=0)
    precision = precision_score(y_true, y_pred, zero_division=0)
    f1 = f1_score(y_true, y_pred, zero_division=0)
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0
    return recall, specificity, precision, f1, accuracy


# Initialize list to store metrics
metrics_list = []

for run in range(num_runs):
    # 8 hex chars: with only 4 (65,536 ids) a 64-run sweep has a ~3% chance of two
    # runs sharing an id, so one saved model would silently overwrite another.
    model_id = uuid.uuid4().hex[:8]
    # NOTE: despite the ".h5" suffix, this file is a PyTorch state_dict (torch.save).
    cd_model_name = f"FCSiamConc_CBMI_{model_id}.h5"

    model = FCSiamConc().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    _train_one_model(model, optimizer, train_loader, run)

    # Save the model after training
    save_path = os.path.join(model_path, cd_model_name)
    torch.save(model.state_dict(), save_path)
    print(f"Model {run + 1} saved to {save_path}")

    # Evaluate the model
    all_labels, all_predictions = _predict(model, test_loader)
    recall, specificity, precision, f1, accuracy = _compute_metrics(all_labels, all_predictions)

    metrics_list.append({
        'Model': run + 1,
        'Recall': recall,
        'Specificity': specificity,
        'Precision': precision,
        'F1': f1,
        'Accuracy': accuracy,
    })

    # Activation/loss fields describe what is actually used: sigmoid + BCEWithLogitsLoss.
    log_params_sim1("Task 1", " ", ' ', " ", " ", "Sigmoid", " ", 'Adam', num_epochs,
                    'BCEWithLogitsLoss', " ", recall, specificity, precision, f1, accuracy,
                    "CBMI Set", 96, " ", "none", cd_model_name)

    print(f"Run {run + 1}: Recall: {recall:.4f}, Specificity: {specificity:.4f}, "
          f"Precision: {precision:.4f}, F1: {f1:.4f}, Accuracy: {accuracy:.4f}")

# Convert metrics list to DataFrame and save it as CSV
metrics_df = pd.DataFrame(metrics_list)
metrics_df.to_csv(f'metrics_{num_runs}_models.csv', index=False)

# Box plot of every metric across all runs
plt.figure(figsize=(10, 6))
sns.boxplot(data=metrics_df.drop(columns=['Model']), palette="Set2")
plt.title(f'Boxplot of Model Metrics for {num_runs} Runs')
plt.ylabel('Values')

# Save BEFORE plt.show(): with interactive backends show() may close the
# figure, after which savefig would write an empty image.
output_path = ('/home/dvalsamis/Documents/projects/Change_detection_SSL_Siamese/FCSiam/'
               f'Box_plots/Conc_boxplot_metrics_{num_runs}_runs.png')
os.makedirs(os.path.dirname(output_path), exist_ok=True)
plt.savefig(output_path)
print(f"Box plot saved to {output_path}")
plt.show()