# gasnet/gas_Box.py
  1  import sys
  2  sys.path.append('/home/dvalsamis/Documents/projects/Change_detection_SSL_Siamese')
  3  
  4  import torch
  5  import torch.nn as nn
  6  import torch.optim as optim
  7  import numpy as np
  8  import pandas as pd
  9  from sklearn.model_selection import train_test_split
 10  from torch.utils.data import DataLoader, TensorDataset
 11  import uuid
 12  import matplotlib.pyplot as plt
 13  import seaborn as sns
 14  from gasnet.CDNet_L import CDNet_L
 15  from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score
 16  #from torch.utils.tensorboard import SummaryWriter
 17  from utils.log_params import log_params_sim1
 18  
 19  import os
 20  
 21  os.environ["CUDA_VISIBLE_DEVICES"] = "1"
 22  
 23  # Define helper functions
 24  def create_rgb_onera(x, channel):
 25      if channel == 'red':
 26          r = x[:, :, 2]
 27          r = np.expand_dims(r, axis=2)
 28          return r
 29      if channel == 'green':
 30          g = x[:, :, 1]
 31          g = np.expand_dims(g, axis=2)
 32          return g
 33      if channel == 'blue':
 34          b = x[:, :, 0]
 35          b = np.expand_dims(b, axis=2)
 36          return b
 37      if channel == 'rgb':
 38          r = x[:, :, 2]
 39          g = x[:, :, 1]
 40          b = x[:, :, 0]
 41          rgb = np.dstack((r, g, b))
 42          return rgb
 43      if channel == 'rgbvnir':
 44          r = x[:, :, 2]
 45          g = x[:, :, 1]
 46          b = x[:, :, 0]
 47          vnir = x[:, :, 3]
 48          rgbvnir = np.stack((r, g, b, vnir), axis=2).astype('float')
 49          return rgbvnir
 50      else:
 51          print("NOT CORRECT CHANNELS")
 52          return x
 53  
 54  # Data Loading and Preparation
 55  
 56  onera_train_target = '/data/valsamis_data/data/CBMI/CBMI_0.3/CBMI_0.3/NPY_dataset/aug_train_data/'
 57  onera_test_target = '/data/valsamis_data/data/CBMI/CBMI_0.3/CBMI_0.3/NPY_dataset/aug_test_data/'
 58  
 59  train = pd.read_csv(onera_train_target + "dataset_train.csv")
 60  test = pd.read_csv(onera_test_target + "dataset_test.csv")
 61  
 62  train = train.sample(frac=1, random_state=1)
 63  test = test.sample(frac=1, random_state=1)
 64  print("Train Data", len(train))
 65  print("Test Data", len(test))
 66  
 67  n_ch = 3
 68  channel = 'rgb'  # Set the channel according to your requirement
 69  
 70  # Load training data
 71  X_train1 = np.ndarray(shape=(len(train), 96, 96, n_ch))
 72  X_train2 = np.ndarray(shape=(len(train), 96, 96, n_ch))
 73  y_train = np.ndarray(shape=(len(train), 96, 96))
 74  
 75  pos = 0
 76  for index in train.index:
 77      img1 = np.load(onera_train_target + train['pair1'][index])
 78      img2 = np.load(onera_train_target + train['pair2'][index])
 79      X1 = create_rgb_onera(img1, channel)
 80      X2 = create_rgb_onera(img2, channel)
 81      X1 = (X1 - X1.mean()) / X1.std()
 82      X2 = (X2 - X2.mean()) / X2.std()
 83      X_train1[pos] = X1
 84      X_train2[pos] = X2
 85      y_train[pos] = np.load(onera_train_target + train['change_mask'][index])
 86      pos += 1
 87  
 88  # Ensure target labels have the same shape as model output
 89  y_train = np.expand_dims(y_train, axis=1)
 90  
 91  # Load test data
 92  X_test1 = np.ndarray(shape=(len(test), 96, 96, n_ch))
 93  X_test2 = np.ndarray(shape=(len(test), 96, 96, n_ch))
 94  y_test = np.ndarray(shape=(len(test), 96, 96))
 95  
 96  pos = 0
 97  for index in test.index:
 98      img1 = np.load(onera_test_target + test['pair1'][index])
 99      img2 = np.load(onera_test_target + test['pair2'][index])
100      X1 = create_rgb_onera(img1, channel)
101      X2 = create_rgb_onera(img2, channel)
102      X1 = (X1 - X1.mean()) / X1.std()
103      X2 = (X2 - X2.mean()) / X2.std()
104      X_test1[pos] = X1
105      X_test2[pos] = X2
106      y_test[pos] = np.load(onera_test_target + test['change_mask'][index])
107      pos += 1
108  
109  # Ensure target labels have the same shape as model output
110  y_test = np.expand_dims(y_test, axis=1)
111  
# Create DataLoaders
def _as_dataset(first, second, masks):
    """Wrap one split in a TensorDataset of float tensors.

    The two image arrays have their axes reordered with
    ``permute(0, 3, 1, 2)`` (last axis moved to position 1); the masks are
    converted as they are.
    """
    return TensorDataset(
        torch.tensor(first).permute(0, 3, 1, 2).float(),
        torch.tensor(second).permute(0, 3, 1, 2).float(),
        torch.tensor(masks).float(),
    )


train_data = _as_dataset(X_train1, X_train2, y_train)
test_data = _as_dataset(X_test1, X_test2, y_test)

# Training batches are reshuffled every epoch; test batches keep a fixed order.
train_loader = DataLoader(train_data, shuffle=True, batch_size=16)
test_loader = DataLoader(test_data, shuffle=False, batch_size=16)
122  
123  
# Model Training and Evaluation
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
criterion = nn.BCELoss()


num_epochs = 30
model_path = '/home/dvalsamis/Documents/projects/Change_detection_SSL_Siamese/saved_models/'
# Create the checkpoint directory up front so that a missing folder cannot
# make torch.save fail after a complete training run.
os.makedirs(model_path, exist_ok=True)


# One dict of test-set metrics per trained model
metrics_list = []

for run in range(64):
    # Short random tag that makes up the checkpoint file name.
    model_id = uuid.uuid4().hex[:4]
    cd_model_name = f"Gas_Net_CBMI_{model_id}.h5"

    # Every run starts from a freshly constructed model and optimizer.
    model = CDNet_L().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for inputs1, inputs2, labels in train_loader:
            inputs1, inputs2, labels = inputs1.to(device), inputs2.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs1, inputs2)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        # Report the mean batch loss so training progress is visible.
        mean_loss = running_loss / max(len(train_loader), 1)
        print(f"Run {run + 1}, epoch {epoch + 1}/{num_epochs}: loss {mean_loss:.4f}")

    # Save the model after training
    save_path = os.path.join(model_path, cd_model_name)
    torch.save(model.state_dict(), save_path)
    print(f"Model {run + 1} saved to {save_path}")

    # Evaluate the model
    model.eval()
    label_batches = []
    prediction_batches = []

    with torch.no_grad():
        for inputs1, inputs2, labels in test_loader:
            inputs1, inputs2 = inputs1.to(device), inputs2.to(device)
            outputs = model(inputs1, inputs2)
            # Keep one integer array per batch instead of extending a Python
            # list pixel by pixel, which is slow and memory hungry.
            # Masks are binarized with the same 0.5 threshold as the outputs;
            # assumes the masks hold 0/1 values -- TODO confirm.
            label_batches.append((labels > 0.5).cpu().numpy().ravel().astype(np.int64))
            prediction_batches.append((outputs > 0.5).cpu().numpy().ravel().astype(np.int64))

    all_labels = np.concatenate(label_batches)
    all_predictions = np.concatenate(prediction_batches)
    # labels=[0, 1] forces a 2x2 matrix even when one class never occurs, so
    # the four-way unpacking below cannot fail.
    tn, fp, fn, tp = confusion_matrix(all_labels, all_predictions, labels=[0, 1]).ravel()

    accuracy = (tp + tn) / (tp + tn + fp + fn)
    recall = recall_score(all_labels, all_predictions)
    precision = precision_score(all_labels, all_predictions)
    f1 = f1_score(all_labels, all_predictions)
    # Without any negative pixel the ratio is undefined; report 0 instead.
    specificity = tn / (tn + fp) if (tn + fp) else 0.0

    metrics_list.append({
        'Model': run + 1,
        'Recall': recall,
        'Specificity': specificity,
        'Precision': precision,
        'F1': f1,
        'Accuracy': accuracy
    })

    # NOTE(review): the logged strings say "Softmax" and
    # 'weighted_categorical_crossentropy', while this script trains with
    # nn.BCELoss -- confirm what the log is expected to record.
    log_params_sim1("Task 1", " ", ' ', " ", " ", "Softmax", " ", 'Adam', num_epochs, 'weighted_categorical_crossentropy', " ", recall, specificity, precision, f1, accuracy, "CBMI Set", 96, " ", "none", cd_model_name)


    print(f"Run {run + 1}: Recall: {recall:.4f}, Specificity: {specificity:.4f}, Precision: {precision:.4f}, F1: {f1:.4f}, Accuracy: {accuracy:.4f}")
196  
# Convert metrics list to DataFrame
metrics_df = pd.DataFrame(metrics_list)

# Save the metrics to a CSV file; the folder is created first so that the
# results of all runs are not lost to a missing directory.
metrics_csv_path = '/home/dvalsamis/Documents/projects/Change_detection_SSL_Siamese/FCSiam/metrics_64_models.csv'
os.makedirs(os.path.dirname(metrics_csv_path), exist_ok=True)
metrics_df.to_csv(metrics_csv_path, index=False)

# Plotting the box plot for all metrics (the 'Model' run counter is not a metric)
plt.figure(figsize=(10, 6))
sns.boxplot(data=metrics_df.drop(columns=['Model']), palette="Set2")
plt.title('Boxplot of Model Metrics for 64 Runs')
plt.ylabel('Values')

# Save the plot
output_path = '/home/dvalsamis/Documents/projects/Change_detection_SSL_Siamese/FCSiam/Box_plots/boxplot_metrics_64_runs.png'
os.makedirs(os.path.dirname(output_path), exist_ok=True)
plt.savefig(output_path)
# Release the figure once it is on disk.
plt.close()
print(f"Box plot saved to {output_path}")