/ FCSiam / fcsiam_simpple_k_fold.py
fcsiam_simpple_k_fold.py
  1  import sys
  2  sys.path.append('/home/dvalsamis/Documents/projects/Change_detection_SSL_Siamese')
  3  
  4  import torch
  5  import torch.nn as nn
  6  import torch.optim as optim
  7  import numpy as np
  8  import pandas as pd
  9  from sklearn.model_selection import KFold
 10  from torch.utils.data import DataLoader, TensorDataset, Subset
 11  import uuid
 12  import matplotlib.pyplot as plt
 13  from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score
 14  from FCSiam.fcsiamConc import FCSiamConc
 15  from utils.log_params import log_params_sim1
 16  
 17  import os
 18  
# Expose only the GPU with physical index 1 to this process.
# NOTE(review): this assignment runs after `import torch` above. It presumably
# still takes effect because CUDA reads the variable when it is first
# initialised, not at import time - confirm, or move it above the torch import.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
 20  
 21  # Define helper functions
 22  def create_rgb_onera(x, channel):
 23      if channel == 'red':
 24          r = x[:, :, 2]
 25          r = np.expand_dims(r, axis=2)
 26          return r
 27      if channel == 'green':
 28          g = x[:, :, 1]
 29          g = np.expand_dims(g, axis=2)
 30          return g
 31      if channel == 'blue':
 32          b = x[:, :, 0]
 33          b = np.expand_dims(b, axis=2)
 34          return b
 35      if channel == 'rgb':
 36          r = x[:, :, 2]
 37          g = x[:, :, 1]
 38          b = x[:, :, 0]
 39          rgb = np.dstack((r, g, b))
 40          return rgb
 41      if channel == 'rgbvnir':
 42          r = x[:, :, 2]
 43          g = x[:, :, 1]
 44          b = x[:, :, 0]
 45          vnir = x[:, :, 3]
 46          rgbvnir = np.stack((r, g, b, vnir), axis=2).astype('float')
 47          return rgbvnir
 48      else:
 49          print("NOT CORRECT CHANNELS")
 50          return x
 51  
 52  # Data Loading and Preparation
 53  onera_train_target = '/data/valsamis_data/data/CBMI/CBMI_0.3/CBMI_0.3/NPY_dataset/aug_train_data/'
 54  onera_test_target = '/data/valsamis_data/data/CBMI/CBMI_0.3/CBMI_0.3/NPY_dataset/aug_test_data/'
 55  
 56  train = pd.read_csv(onera_train_target + "dataset_train.csv")
 57  test = pd.read_csv(onera_test_target + "dataset_test.csv")
 58  
 59  train = train.sample(frac=1, random_state=1)
 60  test = test.sample(frac=1, random_state=1)
 61  print("Train Data", len(train))
 62  print("Test Data", len(test))
 63  
 64  n_ch = 3
 65  channel = 'rgb'  # Set the channel according to your requirement
 66  
 67  # Load training data
 68  X_train1 = np.ndarray(shape=(len(train), 96, 96, n_ch))
 69  X_train2 = np.ndarray(shape=(len(train), 96, 96, n_ch))
 70  y_train = np.ndarray(shape=(len(train), 96, 96))
 71  
 72  pos = 0
 73  for index in train.index:
 74      img1 = np.load(onera_train_target + train['pair1'][index])
 75      img2 = np.load(onera_train_target + train['pair2'][index])
 76      X1 = create_rgb_onera(img1, channel)
 77      X2 = create_rgb_onera(img2, channel)
 78      X1 = (X1 - X1.mean()) / X1.std()
 79      X2 = (X2 - X2.mean()) / X2.std()
 80      X_train1[pos] = X1
 81      X_train2[pos] = X2
 82      y_train[pos] = np.load(onera_train_target + train['change_mask'][index])
 83      pos += 1
 84  
 85  y_train = np.expand_dims(y_train, axis=1)
 86  
 87  # Load test data
 88  X_test1 = np.ndarray(shape=(len(test), 96, 96, n_ch))
 89  X_test2 = np.ndarray(shape=(len(test), 96, 96, n_ch))
 90  y_test = np.ndarray(shape=(len(test), 96, 96))
 91  
 92  pos = 0
 93  for index in test.index:
 94      img1 = np.load(onera_test_target + test['pair1'][index])
 95      img2 = np.load(onera_test_target + test['pair2'][index])
 96      X1 = create_rgb_onera(img1, channel)
 97      X2 = create_rgb_onera(img2, channel)
 98      X1 = (X1 - X1.mean()) / X1.std()
 99      X2 = (X2 - X2.mean()) / X2.std()
100      X_test1[pos] = X1
101      X_test2[pos] = X2
102      y_test[pos] = np.load(onera_test_target + test['change_mask'][index])
103      pos += 1
104  
105  y_test = np.expand_dims(y_test, axis=1)
106  
107  # Create DataLoaders
108  train_data = TensorDataset(torch.tensor(X_train1).permute(0, 3, 1, 2).float(),
109                             torch.tensor(X_train2).permute(0, 3, 1, 2).float(),
110                             torch.tensor(y_train).float())
111  test_data = TensorDataset(torch.tensor(X_test1).permute(0, 3, 1, 2).float(),
112                            torch.tensor(X_test2).permute(0, 3, 1, 2).float(),
113                            torch.tensor(y_test).float())
114  
# Initialize k-fold cross-validation
# No shuffle is requested here; the rows of `train` were already shuffled
# with a fixed seed before the tensors were built.
kf = KFold(n_splits=10)

# Training and evaluation loop
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_epochs = 30
model_path = '/home/dvalsamis/Documents/projects/Change_detection_SSL_Siamese/saved_models/'
# torch.save does not create missing directories.
os.makedirs(model_path, exist_ok=True)

fold_results = []

for fold, (train_idx, val_idx) in enumerate(kf.split(train_data)):
    print(f"FOLD {fold+1}")
    print("--------------------------------")

    # Restrict the full training set to this fold's train / validation indices.
    train_subsampler = Subset(train_data, train_idx)
    val_subsampler = Subset(train_data, val_idx)

    # Define data loaders for training and validation
    train_loader = DataLoader(train_subsampler, batch_size=16, shuffle=True)
    val_loader = DataLoader(val_subsampler, batch_size=16, shuffle=False)

    # A fresh model and optimizer per fold so no weights leak between folds.
    model = FCSiamConc().to(device)
    print("Model on GPU:", next(model.parameters()).is_cuda)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    model_id = uuid.uuid4().hex[:4]
    cd_model_name = f"FCSiamConc_CBMI_{model_id}_fold{fold+1}.pth"

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        # Device placement is reported once per fold above; nothing is
        # printed per batch.
        for inputs1, inputs2, labels in train_loader:
            inputs1, inputs2, labels = inputs1.to(device), inputs2.to(device), labels.to(device)
            optimizer.zero_grad()
            # The two acquisition dates are stacked on a new axis 1: (B, 2, C, H, W).
            outputs = model(torch.stack((inputs1, inputs2), dim=1))
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}")

    # Save the model after training
    save_path = os.path.join(model_path, cd_model_name)
    torch.save(model.state_dict(), save_path)
    print(f"Model saved to {save_path}")

    # Evaluate the model on the validation set
    model.eval()
    label_batches = []
    prediction_batches = []

    with torch.no_grad():
        for inputs1, inputs2, labels in val_loader:
            inputs1, inputs2 = inputs1.to(device), inputs2.to(device)
            outputs = model(torch.stack((inputs1, inputs2), dim=1))
            # Training uses BCEWithLogitsLoss, so `outputs` are treated as raw
            # logits: threshold the sigmoid probability, not the logit itself.
            # NOTE(review): assumes FCSiamConc applies no final sigmoid - confirm.
            predicted = (torch.sigmoid(outputs) > 0.5).float()

            # Keep one flat array per batch; joined once after the loop.
            label_batches.append(labels.cpu().numpy().ravel())
            prediction_batches.append(predicted.cpu().numpy().ravel())

    # Convert to numpy arrays for metric calculations
    all_labels = np.concatenate(label_batches)
    all_predictions = np.concatenate(prediction_batches)

    # Calculate confusion matrix. `labels=[0, 1]` forces a 2x2 matrix even if
    # a fold contains (or predicts) only one class, so the unpacking is safe.
    tn, fp, fn, tp = confusion_matrix(all_labels, all_predictions, labels=[0, 1]).ravel()

    # Calculate metrics
    accuracy = (tp + tn) / (tp + tn + fp + fn)
    recall = recall_score(all_labels, all_predictions)
    precision = precision_score(all_labels, all_predictions)
    f1 = f1_score(all_labels, all_predictions)
    # Guard the fold that has no negative pixels at all.
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0

    # Store results for this fold
    fold_results.append({
        'fold': fold+1,
        'accuracy': accuracy,
        'recall': recall,
        'precision': precision,
        'f1': f1,
        'specificity': specificity
    })
203  
# Print fold results
for fold_summary in fold_results:
    print(fold_summary)

# Calculate and log average metrics (unweighted mean over the folds).
metric_names = ('accuracy', 'recall', 'precision', 'f1', 'specificity')
avg_metrics = {
    name: np.mean([fold_summary[name] for fold_summary in fold_results])
    for name in metric_names
}

# NOTE(review): the "Softmax" / 'weighted_categorical_crossentropy' values do
# not match the BCEWithLogitsLoss used in training - confirm against the
# log_params_sim1 signature before relying on the logged record.
# `cd_model_name` holds the name of the last fold's checkpoint only.
log_params_sim1(
    "k-fold", " ", ' ', " ", " ", "Softmax", " ", 'Adam', num_epochs,
    'weighted_categorical_crossentropy', " ",
    avg_metrics['recall'], avg_metrics['specificity'], avg_metrics['precision'],
    avg_metrics['f1'], avg_metrics['accuracy'],
    "CBMI Set", 96, " ", "none", cd_model_name,
)

print("Average Metrics:")
print(avg_metrics)
221  print("Average Metrics:")
222  print(avg_metrics)