# FCSiam/fcsiam_simpple.py
import sys
# Make the project root importable so the FCSiam/ and utils/ packages resolve
# when this script is run directly (hard-coded developer path).
sys.path.append('/home/dvalsamis/Documents/projects/Change_detection_SSL_Siamese')

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset
import uuid
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score
from FCSiam.fcsiamConc import FCSiamConc
from utils.log_params import log_params_sim1

import os

# Pin this run to the second GPU; must be set before any CUDA initialization.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
 20  
 21  
 22  
 23  # Define helper functions
 24  def create_rgb_onera(x, channel):
 25      if channel == 'red':
 26          r = x[:, :, 2]
 27          r = np.expand_dims(r, axis=2)
 28          return r
 29      if channel == 'green':
 30          g = x[:, :, 1]
 31          g = np.expand_dims(g, axis=2)
 32          return g
 33      if channel == 'blue':
 34          b = x[:, :, 0]
 35          b = np.expand_dims(b, axis=2)
 36          return b
 37      if channel == 'rgb':
 38          r = x[:, :, 2]
 39          g = x[:, :, 1]
 40          b = x[:, :, 0]
 41          rgb = np.dstack((r, g, b))
 42          return rgb
 43      if channel == 'rgbvnir':
 44          r = x[:, :, 2]
 45          g = x[:, :, 1]
 46          b = x[:, :, 0]
 47          vnir = x[:, :, 3]
 48          rgbvnir = np.stack((r, g, b, vnir), axis=2).astype('float')
 49          return rgbvnir
 50      else:
 51          print("NOT CORRECT CHANNELS")
 52          return x
 53  
 54  # Data Loading and Preparation
 55  onera_train_target = '/data/valsamis_data/data/CBMI/CBMI_0.3/CBMI_0.3/NPY_dataset/aug_train_data/'
 56  onera_test_target = '/data/valsamis_data/data/CBMI/CBMI_0.3/CBMI_0.3/NPY_dataset/aug_test_data/'
 57  
 58  train = pd.read_csv(onera_train_target + "dataset_train.csv")
 59  test = pd.read_csv(onera_test_target + "dataset_test.csv")
 60  
 61  train = train.sample(frac=1, random_state=1)
 62  test = test.sample(frac=1, random_state=1)
 63  print("Train Data", len(train))
 64  print("Test Data", len(test))
 65  
 66  n_ch = 3
 67  channel = 'rgb'  
 68  
 69  # Load training data
 70  X_train1 = np.ndarray(shape=(len(train), 96, 96, n_ch))
 71  X_train2 = np.ndarray(shape=(len(train), 96, 96, n_ch))
 72  y_train = np.ndarray(shape=(len(train), 96, 96))
 73  
 74  pos = 0
 75  for index in train.index:
 76      img1 = np.load(onera_train_target + train['pair1'][index])
 77      img2 = np.load(onera_train_target + train['pair2'][index])
 78      X1 = create_rgb_onera(img1, channel)
 79      X2 = create_rgb_onera(img2, channel)
 80      X1 = (X1 - X1.mean()) / X1.std()
 81      X2 = (X2 - X2.mean()) / X2.std()
 82      X_train1[pos] = X1
 83      X_train2[pos] = X2
 84      y_train[pos] = np.load(onera_train_target + train['change_mask'][index])
 85      pos += 1
 86  
 87  y_train = np.expand_dims(y_train, axis=1)
 88  
 89  # Load test data
 90  X_test1 = np.ndarray(shape=(len(test), 96, 96, n_ch))
 91  X_test2 = np.ndarray(shape=(len(test), 96, 96, n_ch))
 92  y_test = np.ndarray(shape=(len(test), 96, 96))
 93  
 94  pos = 0
 95  for index in test.index:
 96      img1 = np.load(onera_test_target + test['pair1'][index])
 97      img2 = np.load(onera_test_target + test['pair2'][index])
 98      X1 = create_rgb_onera(img1, channel)
 99      X2 = create_rgb_onera(img2, channel)
100      X1 = (X1 - X1.mean()) / X1.std()
101      X2 = (X2 - X2.mean()) / X2.std()
102      X_test1[pos] = X1
103      X_test2[pos] = X2
104      y_test[pos] = np.load(onera_test_target + test['change_mask'][index])
105      pos += 1
106  
107  y_test = np.expand_dims(y_test, axis=1)
108  
109  # Create DataLoaders
110  train_data = TensorDataset(torch.tensor(X_train1).permute(0, 3, 1, 2).float(),
111                             torch.tensor(X_train2).permute(0, 3, 1, 2).float(),
112                             torch.tensor(y_train).float())
113  test_data = TensorDataset(torch.tensor(X_test1).permute(0, 3, 1, 2).float(),
114                            torch.tensor(X_test2).permute(0, 3, 1, 2).float(),
115                            torch.tensor(y_test).float())
116  
117  train_loader = DataLoader(train_data, batch_size=16, shuffle=True)
118  test_loader = DataLoader(test_data, batch_size=16, shuffle=False)
119  
# Forward hook that reports tensor shapes as data flows through the network.
def print_layer_shape(module, input, output):
    """Print a hooked layer's class name and its input/output shapes."""
    layer_name = module.__class__.__name__
    print(f'{layer_name}:')
    print(f'    Input shape: {input[0].shape}')
    print(f'    Output shape: {output[0].shape}')

# Recursively attach the shape-printing hook to every submodule.
def register_hooks(model):
    """Walk the module tree, hooking each child (containers included)."""
    for child in model.children():
        child.register_forward_hook(print_layer_shape)
        if list(child.children()):
            register_hooks(child)
132  
# TensorBoard setup
# NOTE(review): log_dir is currently unused — no SummaryWriter is created
# anywhere in this script; kept for when logging is wired back in.
log_dir = '/home/dvalsamis/Documents/projects/Change_detection_SSL_Siamese/events'


# Model Training and Evaluation
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = FCSiamConc().to(device)
# BCEWithLogitsLoss expects raw logits from the model, not probabilities.
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# (Removed two unused dummy_input tensors: they were created but never
# referenced — leftover debugging/summary code.)

# Training loop configuration
num_epochs = 3

model_path = '/home/dvalsamis/Documents/projects/Change_detection_SSL_Siamese/saved_models/'

# Short random suffix keeps checkpoint names unique across runs.
# NOTE(review): '.h5' is a misleading extension for a torch state_dict
# ('.pt' is conventional); kept as-is so existing tooling/paths still match.
model_id = uuid.uuid4().hex[:4]
cd_model_name = "FCSiamConc_"+"CBMI_"+model_id+".h5"
155  
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for i, (inputs1, inputs2, labels) in enumerate(train_loader):
        inputs1, inputs2, labels = inputs1.to(device), inputs2.to(device), labels.to(device)
        optimizer.zero_grad()
        # The pair is stacked on a new leading "time" dimension: (B, 2, C, H, W).
        # assumes FCSiamConc.forward consumes this layout — TODO confirm.
        outputs = model(torch.stack((inputs1, inputs2), dim=1))
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    # (Removed a stale "Logged Loss/train ..." print: no logger/SummaryWriter
    # is wired up, so the message falsely claimed the loss was logged.)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}")

# Save the model after training (state_dict only, not the full module).
save_path = os.path.join(model_path, cd_model_name)
torch.save(model.state_dict(), save_path)
print(f"Model saved to {save_path}")
178  
# Evaluate the model
model.eval()
all_labels = []
all_predictions = []

with torch.no_grad():
    for inputs1, inputs2, labels in test_loader:
        inputs1, inputs2, labels = inputs1.to(device), inputs2.to(device), labels.to(device)
        outputs = model(torch.stack((inputs1, inputs2), dim=1))
        # BUG FIX: the model emits raw logits (training uses
        # BCEWithLogitsLoss), so apply sigmoid before thresholding at 0.5.
        # Thresholding logits at 0.5 silently moved the decision boundary
        # to a probability of ~0.62.
        predicted = (torch.sigmoid(outputs) > 0.5).float()

        # Collect one flat array per batch; concatenating afterwards is far
        # cheaper than extending a Python list pixel by pixel.
        all_labels.append(labels.cpu().numpy().ravel())
        all_predictions.append(predicted.cpu().numpy().ravel())

# Convert to numpy arrays for metric calculations
all_labels = np.concatenate(all_labels)
all_predictions = np.concatenate(all_predictions)

# Calculate confusion matrix (binary layout: tn, fp, fn, tp)
tn, fp, fn, tp = confusion_matrix(all_labels, all_predictions).ravel()

# Calculate metrics
accuracy = (tp + tn) / (tp + tn + fp + fn)
recall = recall_score(all_labels, all_predictions)
precision = precision_score(all_labels, all_predictions)
f1 = f1_score(all_labels, all_predictions)
# Specificity = true-negative rate; sklearn has no direct scorer for it.
specificity = tn / (tn + fp)


print(f"Recall: {recall:.4f}")
print(f"Specificity: {specificity:.4f}")
print(f"Precision: {precision:.4f}")
print(f"F1 Score: {f1:.4f}")
print(f"Accuracy: {accuracy:.4f}")
214  
215  
# Persist run hyper-parameters and evaluation metrics via the project logger.
# NOTE(review): several placeholder " " fields and the 'Softmax' /
# 'weighted_categorical_crossentropy' labels do not match this script's
# actual setup (BCEWithLogitsLoss + Adam) — verify against log_params_sim1's
# signature before trusting the logged metadata.
log_params_sim1("Task 1", " ", ' ', " ", " ", "Softmax", " ", 'Adam', num_epochs, 'weighted_categorical_crossentropy', " ", recall, specificity, precision, f1, accuracy, "CBMI Set", 96, " ", "none", cd_model_name)