/ FCSiam / fcsiam_simpple_k_fold_boxplot.py
fcsiam_simpple_k_fold_boxplot.py
  1  import sys
  2  sys.path.append('/home/dvalsamis/Documents/projects/Change_detection_SSL_Siamese')
  3  
  4  import torch
  5  import torch.nn as nn
  6  import torch.optim as optim
  7  import numpy as np
  8  import pandas as pd
  9  from sklearn.model_selection import KFold
 10  from torch.utils.data import DataLoader, TensorDataset, Subset
 11  import uuid
 12  import matplotlib.pyplot as plt
 13  from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score
 14  from FCSiam.fcsiamConc import FCSiamConc
 15  from utils.log_params import log_params_sim1
 16  
 17  import os
 18  
 19  os.environ["CUDA_VISIBLE_DEVICES"] = "1"
 20  
 21  # Define helper functions
 22  def create_rgb_onera(x, channel):
 23      if channel == 'red':
 24          r = x[:, :, 2]
 25          r = np.expand_dims(r, axis=2)
 26          return r
 27      if channel == 'green':
 28          g = x[:, :, 1]
 29          g = np.expand_dims(g, axis=2)
 30          return g
 31      if channel == 'blue':
 32          b = x[:, :, 0]
 33          b = np.expand_dims(b, axis=2)
 34          return b
 35      if channel == 'rgb':
 36          r = x[:, :, 2]
 37          g = x[:, :, 1]
 38          b = x[:, :, 0]
 39          rgb = np.dstack((r, g, b))
 40          return rgb
 41      if channel == 'rgbvnir':
 42          r = x[:, :, 2]
 43          g = x[:, :, 1]
 44          b = x[:, :, 0]
 45          vnir = x[:, :, 3]
 46          rgbvnir = np.stack((r, g, b, vnir), axis=2).astype('float')
 47          return rgbvnir
 48      else:
 49          print("NOT CORRECT CHANNELS")
 50          return x
 51  
 52  # Data Loading and Preparation
 53  onera_train_target = '/data/valsamis_data/data/CBMI/CBMI_0.3/CBMI_0.3/NPY_dataset/aug_train_data/'
 54  onera_test_target = '/data/valsamis_data/data/CBMI/CBMI_0.3/CBMI_0.3/NPY_dataset/aug_test_data/'
 55  
 56  train = pd.read_csv(onera_train_target + "dataset_train.csv")
 57  test = pd.read_csv(onera_test_target + "dataset_test.csv")
 58  
 59  train = train.sample(frac=1, random_state=1)
 60  test = test.sample(frac=1, random_state=1)
 61  print("Train Data", len(train))
 62  print("Test Data", len(test))
 63  
 64  n_ch = 3
 65  channel = 'rgb'  # Set the channel according to your requirement
 66  
 67  # Load training data
 68  X_train1 = np.ndarray(shape=(len(train), 96, 96, n_ch))
 69  X_train2 = np.ndarray(shape=(len(train), 96, 96, n_ch))
 70  y_train = np.ndarray(shape=(len(train), 96, 96))
 71  
 72  pos = 0
 73  for index in train.index:
 74      img1 = np.load(onera_train_target + train['pair1'][index])
 75      img2 = np.load(onera_train_target + train['pair2'][index])
 76      X1 = create_rgb_onera(img1, channel)
 77      X2 = create_rgb_onera(img2, channel)
 78      X1 = (X1 - X1.mean()) / X1.std()
 79      X2 = (X2 - X2.mean()) / X2.std()
 80      X_train1[pos] = X1
 81      X_train2[pos] = X2
 82      y_train[pos] = np.load(onera_train_target + train['change_mask'][index])
 83      pos += 1
 84  
 85  y_train = np.expand_dims(y_train, axis=1)
 86  
 87  # Load test data
 88  X_test1 = np.ndarray(shape=(len(test), 96, 96, n_ch))
 89  X_test2 = np.ndarray(shape=(len(test), 96, 96, n_ch))
 90  y_test = np.ndarray(shape=(len(test), 96, 96))
 91  
 92  pos = 0
 93  for index in test.index:
 94      img1 = np.load(onera_test_target + test['pair1'][index])
 95      img2 = np.load(onera_test_target + test['pair2'][index])
 96      X1 = create_rgb_onera(img1, channel)
 97      X2 = create_rgb_onera(img2, channel)
 98      X1 = (X1 - X1.mean()) / X1.std()
 99      X2 = (X2 - X2.mean()) / X2.std()
100      X_test1[pos] = X1
101      X_test2[pos] = X2
102      y_test[pos] = np.load(onera_test_target + test['change_mask'][index])
103      pos += 1
104  
105  y_test = np.expand_dims(y_test, axis=1)
106  
# Wrap the NumPy arrays as TensorDatasets (DataLoaders are built per fold later).
def _to_nchw(arr):
    """Convert a channels-last NumPy image batch into a channels-first float32 tensor."""
    return torch.tensor(arr).permute(0, 3, 1, 2).float()


train_data = TensorDataset(
    _to_nchw(X_train1),
    _to_nchw(X_train2),
    torch.tensor(y_train).float(),
)
# NOTE(review): test_data is built but not used anywhere in this script;
# evaluation below only uses the k-fold validation splits of train_data.
test_data = TensorDataset(
    _to_nchw(X_test1),
    _to_nchw(X_test2),
    torch.tensor(y_test).float(),
)
114  
# Initialize k-fold cross-validation settings
n_splits = 5

# Training and evaluation loop
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_epochs = 30
num_runs = 64
model_path = '/home/dvalsamis/Documents/projects/Change_detection_SSL_Siamese/saved_models/'
os.makedirs(model_path, exist_ok=True)  # torch.save fails if the directory is missing


def _safe_div(num, den):
    """Return ``num / den``, or 0.0 when ``den`` is 0.

    Follows scikit-learn's ``zero_division=0`` convention. For example, a fold
    with no predicted positives reports precision 0 instead of crashing or
    producing NaN.
    """
    return num / den if den else 0.0


def _binary_outcome_counts(predicted, actual):
    """Return ``(tn, fp, fn, tp)`` element counts for two boolean tensors of equal shape."""
    return (
        int((~predicted & ~actual).sum()),
        int((predicted & ~actual).sum()),
        int((~predicted & actual).sum()),
        int((predicted & actual).sum()),
    )


fold_results = []  # one list of per-fold metric dicts per run

# Repeat the k-fold procedure num_runs times
for run in range(num_runs):
    print(f"RUN {run+1}")
    # Reshuffle the fold assignment every run, seeded by the run index so it
    # stays reproducible. Without shuffle=True every run reused exactly the
    # same five folds, and the repetitions only varied the initialisation.
    kf = KFold(n_splits=n_splits, shuffle=True, random_state=run)
    run_results = []
    for fold, (train_idx, val_idx) in enumerate(kf.split(train_data)):
        print(f"FOLD {fold+1}")
        print("--------------------------------")

        # Restrict train_data to this fold's training / validation indices.
        train_subsampler = Subset(train_data, train_idx)
        val_subsampler = Subset(train_data, val_idx)

        train_loader = DataLoader(train_subsampler, batch_size=16, shuffle=True)
        val_loader = DataLoader(val_subsampler, batch_size=16, shuffle=False)

        # Fresh model and optimiser for every fold.
        model = FCSiamConc().to(device)
        criterion = nn.BCEWithLogitsLoss()
        optimizer = optim.Adam(model.parameters(), lr=1e-4)

        model_id = uuid.uuid4().hex[:4]
        cd_model_name = f"FCSiamConc_CBMI_{model_id}_fold{fold+1}.pth"

        for epoch in range(num_epochs):
            model.train()
            running_loss = 0.0
            for inputs1, inputs2, labels in train_loader:
                inputs1, inputs2, labels = inputs1.to(device), inputs2.to(device), labels.to(device)
                optimizer.zero_grad()
                # The two acquisition dates are stacked on a new dim 1;
                # presumably FCSiamConc expects (B, 2, C, H, W) -- TODO confirm.
                outputs = model(torch.stack((inputs1, inputs2), dim=1))
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                running_loss += loss.item()

            print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}")

        # Save the model after training
        save_path = os.path.join(model_path, cd_model_name)
        torch.save(model.state_dict(), save_path)
        print(f"Model saved to {save_path}")

        # Evaluate on the validation fold. Pixel-level confusion counts are
        # accumulated per batch instead of materialising every pixel in Python
        # lists; this also copes with folds that contain a single class.
        model.eval()
        counts = np.zeros(4, dtype=np.int64)  # tn, fp, fn, tp
        with torch.no_grad():
            for inputs1, inputs2, labels in val_loader:
                inputs1, inputs2, labels = inputs1.to(device), inputs2.to(device), labels.to(device)
                logits = model(torch.stack((inputs1, inputs2), dim=1))
                # Training uses BCEWithLogitsLoss, so the model emits logits.
                # Convert them to probabilities before the 0.5 decision threshold.
                predicted = torch.sigmoid(logits) > 0.5
                actual = labels > 0.5  # change masks are expected to be 0/1
                counts += _binary_outcome_counts(predicted, actual)

        tn, fp, fn, tp = counts.tolist()

        # Metrics from the confusion counts (same definitions as sklearn's
        # recall/precision/f1 with zero_division=0).
        accuracy = _safe_div(tp + tn, tp + tn + fp + fn)
        recall = _safe_div(tp, tp + fn)
        precision = _safe_div(tp, tp + fp)
        f1 = _safe_div(2 * tp, 2 * tp + fp + fn)
        specificity = _safe_div(tn, tn + fp)

        # Store results for this fold
        run_results.append({
            'fold': fold + 1,
            'accuracy': accuracy,
            'recall': recall,
            'precision': precision,
            'f1': f1,
            'specificity': specificity,
        })

    # Store the run results
    fold_results.append(run_results)
207  
# Calculate and log average metrics for all runs
metric_names = ('accuracy', 'recall', 'precision', 'f1', 'specificity')

# all_avg_metrics[name] holds, for every run, that metric averaged over the
# run's folds, i.e. one sample per run for the box plot.
all_avg_metrics = {
    name: [np.mean([fold_metrics[name] for fold_metrics in run_metrics])
           for run_metrics in fold_results]
    for name in metric_names
}

# Grand mean of each metric across all runs.
overall_means = {name: np.mean(per_run) for name, per_run in all_avg_metrics.items()}

# NOTE(review): the activation ("Softmax") and loss
# ('weighted_categorical_crossentropy') passed here do not match the training
# loop, which uses nn.BCEWithLogitsLoss. Confirm log_params_sim1's parameter
# order before correcting these logged values.
log_params_sim1("Boxplot", " ", ' ', " ", " ", "Softmax", " ", 'Adam', num_epochs, 'weighted_categorical_crossentropy', " ",
                overall_means['recall'], overall_means['specificity'], overall_means['precision'], overall_means['f1'], overall_means['accuracy'],
                "CBMI Set", 96, " ", "none", "k-fold cross-validation")

print("Average Metrics:")
print(overall_means)
234  
# Plot the box plot: one box per metric, one sample per run.
metric_keys = ['accuracy', 'recall', 'precision', 'f1', 'specificity']
metric_labels = ['Accuracy', 'Recall', 'Precision', 'F1 Score', 'Specificity']
n_plotted_runs = len(all_avg_metrics['accuracy'])

plt.figure(figsize=(10, 7))
plt.boxplot([all_avg_metrics[key] for key in metric_keys], labels=metric_labels)
plt.title(f'Box Plot of Metrics over {n_plotted_runs} Runs')
plt.ylabel('Score')

# Save the plot BEFORE showing it. plt.show() blocks until the window is
# closed, after which the figure is typically released, so a savefig() call
# placed after show() writes an empty image.
output_path = '/home/dvalsamis/Documents/projects/Change_detection_SSL_Siamese/boxplots/Conc_boxplot_metrics_64_runs.png'
os.makedirs(os.path.dirname(output_path), exist_ok=True)
plt.savefig(output_path)
print(f"Box plot saved to {output_path}")
plt.show()
246  print(f"Box plot saved to {output_path}")