# FCSiam/fcsiamDiff_simpple.py
import sys
# Make the project root importable so the FCSiam / utils packages resolve.
sys.path.append('/home/dvalsamis/Documents/projects/Change_detection_SSL_Siamese')

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset
import uuid
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score
from FCSiam.fcsiaDiff import FCSiamDiff
from utils.log_params import log_params_sim1

import os

# Restrict the process to the second GPU; must be set before CUDA initialises.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
 20  
 21  # Define helper functions
 22  def create_rgb_onera(x, channel):
 23      if channel == 'red':
 24          r = x[:, :, 2]
 25          r = np.expand_dims(r, axis=2)
 26          return r
 27      if channel == 'green':
 28          g = x[:, :, 1]
 29          g = np.expand_dims(g, axis=2)
 30          return g
 31      if channel == 'blue':
 32          b = x[:, :, 0]
 33          b = np.expand_dims(b, axis=2)
 34          return b
 35      if channel == 'rgb':
 36          r = x[:, :, 2]
 37          g = x[:, :, 1]
 38          b = x[:, :, 0]
 39          rgb = np.dstack((r, g, b))
 40          return rgb
 41      if channel == 'rgbvnir':
 42          r = x[:, :, 2]
 43          g = x[:, :, 1]
 44          b = x[:, :, 0]
 45          vnir = x[:, :, 3]
 46          rgbvnir = np.stack((r, g, b, vnir), axis=2).astype('float')
 47          return rgbvnir
 48      else:
 49          print("NOT CORRECT CHANNELS")
 50          return x
 51  
# Data Loading and Preparation
onera_train_target = '/data/valsamis_data/data/CBMI/CBMI_0.3/CBMI_0.3/NPY_dataset/aug_train_data/'
onera_test_target = '/data/valsamis_data/data/CBMI/CBMI_0.3/CBMI_0.3/NPY_dataset/aug_test_data/'

# The CSVs list per-sample .npy filenames: 'pair1'/'pair2' (co-registered
# image patches) and 'change_mask' (the binary target), as used below.
train = pd.read_csv(onera_train_target + "dataset_train.csv")
test = pd.read_csv(onera_test_target + "dataset_test.csv")

# Shuffle with a fixed seed so runs are reproducible.
train = train.sample(frac=1, random_state=1)
test = test.sample(frac=1, random_state=1)
print("Train Data", len(train))
print("Test Data", len(test))

n_ch = 3
channel = 'rgb'  # Set the channel according to your requirement
 66  
 67  # Load training data
 68  X_train1 = np.ndarray(shape=(len(train), 96, 96, n_ch))
 69  X_train2 = np.ndarray(shape=(len(train), 96, 96, n_ch))
 70  y_train = np.ndarray(shape=(len(train), 96, 96))
 71  
 72  pos = 0
 73  for index in train.index:
 74      img1 = np.load(onera_train_target + train['pair1'][index])
 75      img2 = np.load(onera_train_target + train['pair2'][index])
 76      X1 = create_rgb_onera(img1, channel)
 77      X2 = create_rgb_onera(img2, channel)
 78      X1 = (X1 - X1.mean()) / X1.std()
 79      X2 = (X2 - X2.mean()) / X2.std()
 80      X_train1[pos] = X1
 81      X_train2[pos] = X2
 82      y_train[pos] = np.load(onera_train_target + train['change_mask'][index])
 83      pos += 1
 84  
 85  y_train = np.expand_dims(y_train, axis=1)
 86  
 87  # Load test data
 88  X_test1 = np.ndarray(shape=(len(test), 96, 96, n_ch))
 89  X_test2 = np.ndarray(shape=(len(test), 96, 96, n_ch))
 90  y_test = np.ndarray(shape=(len(test), 96, 96))
 91  
 92  pos = 0
 93  for index in test.index:
 94      img1 = np.load(onera_test_target + test['pair1'][index])
 95      img2 = np.load(onera_test_target + test['pair2'][index])
 96      X1 = create_rgb_onera(img1, channel)
 97      X2 = create_rgb_onera(img2, channel)
 98      X1 = (X1 - X1.mean()) / X1.std()
 99      X2 = (X2 - X2.mean()) / X2.std()
100      X_test1[pos] = X1
101      X_test2[pos] = X2
102      y_test[pos] = np.load(onera_test_target + test['change_mask'][index])
103      pos += 1
104  
105  y_test = np.expand_dims(y_test, axis=1)
106  
# Create DataLoaders over (image1, image2, mask) triples in NCHW layout.
def _as_nchw_float(arr):
    # (N, H, W, C) numpy array -> (N, C, H, W) float32 tensor.
    return torch.tensor(arr).permute(0, 3, 1, 2).float()

train_data = TensorDataset(_as_nchw_float(X_train1),
                           _as_nchw_float(X_train2),
                           torch.tensor(y_train).float())
test_data = TensorDataset(_as_nchw_float(X_test1),
                          _as_nchw_float(X_test2),
                          torch.tensor(y_test).float())

train_loader = DataLoader(train_data, batch_size=16, shuffle=True)
test_loader = DataLoader(test_data, batch_size=16, shuffle=False)
117  
# Define hook function
def print_layer_shape(module, inputs, output):
    """Forward hook that prints a module's input and output tensor shapes.

    `inputs` is the tuple of positional tensors the module received, so the
    first entry's shape is reported.  `output` is usually a single tensor;
    the original code printed `output[0].shape`, which for a batched tensor
    indexes the first sample and drops the batch dimension.  Print the full
    tensor shape instead, keeping the old behaviour for tuple outputs.
    """
    print(f'{module.__class__.__name__}:')
    print(f'    Input shape: {inputs[0].shape}')
    if isinstance(output, torch.Tensor):
        print(f'    Output shape: {output.shape}')
    else:
        # Tuple/list outputs: report the first element's shape, as before.
        print(f'    Output shape: {output[0].shape}')
123  
# Register hooks
def register_hooks(model):
    """Recursively attach the shape-printing forward hook to every submodule."""
    for child in model.children():
        child.register_forward_hook(print_layer_shape)
        # Recursing into a leaf is a no-op (its children() is empty), so no
        # explicit guard is needed.
        register_hooks(child)
130  
# TensorBoard setup (NOTE: log_dir is currently unused — no writer is created).
log_dir = '/home/dvalsamis/Documents/projects/Change_detection_SSL_Siamese/events'

# Model Training and Evaluation
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = FCSiamDiff().to(device)
# BCEWithLogitsLoss applies the sigmoid internally, so the model is expected
# to emit raw logits.
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Training loop configuration
num_epochs = 55

model_path = '/home/dvalsamis/Documents/projects/Change_detection_SSL_Siamese/saved_models/'

# Short random suffix so repeated runs do not overwrite each other's weights.
model_id = uuid.uuid4().hex[:4]
cd_model_name = f"FCSiamDiff_CBMI_{model_id}.h5"
150  
# Train: each step feeds the image pair stacked on a time axis,
# (B, 2, C, H, W), which is the input layout FCSiamDiff consumes above.
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for inputs1, inputs2, labels in train_loader:
        inputs1, inputs2, labels = inputs1.to(device), inputs2.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(torch.stack((inputs1, inputs2), dim=1))
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    # The original "Logged Loss/train" message was removed: no TensorBoard
    # writer is ever created, so nothing was actually logged.
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}")

# Save the model after training
save_path = os.path.join(model_path, cd_model_name)
torch.save(model.state_dict(), save_path)
print(f"Model saved to {save_path}")
172  
# Evaluate the model on the held-out set, pixel-wise.
model.eval()
all_labels = []
all_predictions = []

with torch.no_grad():
    for inputs1, inputs2, labels in test_loader:
        inputs1, inputs2, labels = inputs1.to(device), inputs2.to(device), labels.to(device)
        outputs = model(torch.stack((inputs1, inputs2), dim=1))
        # The model is trained with BCEWithLogitsLoss, so `outputs` are raw
        # logits.  The original code thresholded the logits at 0.5, which is
        # a probability cut-off of ~0.62; apply the sigmoid first so that
        # 0.5 means probability 0.5.
        predicted = (torch.sigmoid(outputs) > 0.5).float()

        all_labels.extend(labels.cpu().numpy().flatten())
        all_predictions.extend(predicted.cpu().numpy().flatten())

# Convert to numpy arrays for metric calculations
all_labels = np.array(all_labels)
all_predictions = np.array(all_predictions)

# Pixel-wise confusion matrix over the whole test set.
tn, fp, fn, tp = confusion_matrix(all_labels, all_predictions).ravel()

# Calculate metrics (specificity is not provided by sklearn directly).
accuracy = (tp + tn) / (tp + tn + fp + fn)
recall = recall_score(all_labels, all_predictions)
precision = precision_score(all_labels, all_predictions)
f1 = f1_score(all_labels, all_predictions)
specificity = tn / (tn + fp)

print(f"Recall: {recall:.4f}")
print(f"Specificity: {specificity:.4f}")
print(f"Precision: {precision:.4f}")
print(f"F1 Score: {f1:.4f}")
print(f"Accuracy: {accuracy:.4f}")

# Persist the run's hyper-parameters and metrics to the project log.
log_params_sim1("Task 1", " ", ' ', " ", " ", "Softmax", " ", 'Adam', num_epochs, 'weighted_categorical_crossentropy', " ", recall, specificity, precision, f1, accuracy, "CBMI Set", 96, " ", "none", cd_model_name)
210  
211  
212  #--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
213  
214  # Prediction
# NOTE: `create_rgb_onera` was redefined here as a byte-identical copy of the
# definition at the top of this file.  The dead duplicate has been removed;
# the prediction code below reuses the original definition.
244  
# Data Loading and Preparation for the qualitative prediction pass.

onera_test_target = '/data/valsamis_data/data/CBMI/CBMI_0.3/CBMI_0.3/NPY_dataset/aug_test_data/'

test = pd.read_csv(onera_test_target + "dataset_test.csv")
test = test.sample(frac=1, random_state=1).head(20)  # Select only 20 examples
print("Test Data", len(test))

n_ch = 3
channel = 'rgb'  # Set the channel according to your requirement

# Load test data
X_test1 = np.empty((len(test), 96, 96, n_ch))
X_test2 = np.empty((len(test), 96, 96, n_ch))
y_test = np.empty((len(test), 96, 96))

for pos, index in enumerate(test.index):
    pre = create_rgb_onera(np.load(onera_test_target + test['pair1'][index]), channel)
    post = create_rgb_onera(np.load(onera_test_target + test['pair2'][index]), channel)
    # Per-patch standardisation, as during training.
    X_test1[pos] = (pre - pre.mean()) / pre.std()
    X_test2[pos] = (post - post.mean()) / post.std()
    y_test[pos] = np.load(onera_test_target + test['change_mask'][index])

# Ensure target labels have the same shape as model output
y_test = np.expand_dims(y_test, axis=1)

# Convert each image stack to (N, C, H, W), then stack the pair on a time
# axis to (N, 2, C, H, W) — the layout the training loop fed the model.
X_test1 = torch.tensor(X_test1).permute(0, 3, 1, 2)
X_test2 = torch.tensor(X_test2).permute(0, 3, 1, 2)
X_test = torch.stack((X_test1, X_test2), dim=1)

# Create DataLoader
test_data = TensorDataset(X_test.float(), torch.tensor(y_test).float())
test_loader = DataLoader(test_data, batch_size=16, shuffle=False)
285  
# Directory to save predictions
predictions_path = '/data/valsamis_data/data/CBMI/CBMI_0.3/Predictions/Depth_2/fcsiamDiff'
os.makedirs(predictions_path, exist_ok=True)

all_predictions = []

with torch.no_grad():
    for i, (inputs, true_labels) in enumerate(test_loader):
        inputs = inputs.to(device)
        outputs = model(inputs)
        # Outputs are logits (the model was trained with BCEWithLogitsLoss);
        # apply the sigmoid before thresholding so 0.5 is a probability
        # cut-off, consistent with the evaluation pass above.  The original
        # code thresholded the raw logits.
        predicted = (torch.sigmoid(outputs) > 0.5).float().cpu().numpy()

        for j in range(predicted.shape[0]):
            # Global sample index across batches, used for both file names.
            sample_idx = i * test_loader.batch_size + j
            pred_filename = os.path.join(predictions_path, f'prediction_{sample_idx}.npy')
            np.save(pred_filename, predicted[j])
            all_predictions.append((inputs[j, 0].cpu().numpy(), inputs[j, 1].cpu().numpy(), predicted[j]))

            # Visualization and saving as .png
            # NOTE(review): the inputs are standardised (zero mean / unit
            # variance), so imshow clips them; the panels are only a rough
            # visual check, not a faithful RGB rendering.
            plt.figure(figsize=(20, 5))

            plt.subplot(1, 4, 1)
            plt.imshow(inputs[j, 0].cpu().permute(1, 2, 0).numpy())
            plt.title('Input Image 1')

            plt.subplot(1, 4, 2)
            plt.imshow(inputs[j, 1].cpu().permute(1, 2, 0).numpy())
            plt.title('Input Image 2')

            plt.subplot(1, 4, 3)
            plt.imshow(predicted[j][0], cmap='gray')
            plt.title('Predicted Change Mask')

            plt.subplot(1, 4, 4)
            plt.imshow(true_labels[j][0].cpu().numpy(), cmap='gray')
            plt.title('True Change Mask')

            png_filename = os.path.join(predictions_path, f'prediction_{sample_idx}.png')
            plt.savefig(png_filename)
            plt.close()

print(f"Predictions saved to {predictions_path}")