/ gasnet / gas_simple.py
gas_simple.py
  1  import sys
  2  sys.path.append('/home/dvalsamis/Documents/projects/Change_detection_SSL_Siamese')
  3  
  4  import torch
  5  import torch.nn as nn
  6  import torch.optim as optim
  7  import numpy as np
  8  import pandas as pd
  9  from sklearn.model_selection import train_test_split
 10  from torch.utils.data import DataLoader, TensorDataset
 11  import uuid
 12  import matplotlib.pyplot as plt
 13  from gasnet.CDNet_L import CDNet_L
 14  from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score
 15  #from torch.utils.tensorboard import SummaryWriter
 16  from utils.log_params import log_params_sim1
 17  
 18  import os
 19  
 20  os.environ["CUDA_VISIBLE_DEVICES"] = "1"
 21  
 22  
 23  
 24  def create_rgb_onera(x, channel):
 25      if channel == 'red':
 26          r = x[:, :, 2]
 27          r = np.expand_dims(r, axis=2)
 28          return r
 29      if channel == 'green':
 30          g = x[:, :, 1]
 31          g = np.expand_dims(g, axis=2)
 32          return g
 33      if channel == 'blue':
 34          b = x[:, :, 0]
 35          b = np.expand_dims(b, axis=2)
 36          return b
 37      if channel == 'rgb':
 38          r = x[:, :, 2]
 39          g = x[:, :, 1]
 40          b = x[:, :, 0]
 41          rgb = np.dstack((r, g, b))
 42          return rgb
 43      if channel == 'rgbvnir':
 44          r = x[:, :, 2]
 45          g = x[:, :, 1]
 46          b = x[:, :, 0]
 47          vnir = x[:, :, 3]
 48          rgbvnir = np.stack((r, g, b, vnir), axis=2).astype('float')
 49          return rgbvnir
 50      else:
 51          print("NOT CORRECT CHANNELS")
 52          return x
 53  
 54  # Data Loading and Preparation
 55  
 56  onera_train_target = '/data/valsamis_data/data/CBMI/CBMI_0.3/CBMI_0.3/NPY_dataset/aug_train_data/'
 57  onera_test_target = '/data/valsamis_data/data/CBMI/CBMI_0.3/CBMI_0.3/NPY_dataset/aug_test_data/'
 58  
 59  train = pd.read_csv(onera_train_target + "dataset_train.csv")
 60  test = pd.read_csv(onera_test_target + "dataset_test.csv")
 61  
 62  train = train.sample(frac=1, random_state=1)
 63  test = test.sample(frac=1, random_state=1)
 64  print("Train Data", len(train))
 65  print("Test Data", len(test))
 66  
 67  n_ch = 3
 68  channel = 'rgb'  
 69  
 70  # Load training data
 71  X_train1 = np.ndarray(shape=(len(train), 96, 96, n_ch))
 72  X_train2 = np.ndarray(shape=(len(train), 96, 96, n_ch))
 73  y_train = np.ndarray(shape=(len(train), 96, 96))
 74  
 75  pos = 0
 76  for index in train.index:
 77      img1 = np.load(onera_train_target + train['pair1'][index])
 78      img2 = np.load(onera_train_target + train['pair2'][index])
 79      X1 = create_rgb_onera(img1, channel)
 80      X2 = create_rgb_onera(img2, channel)
 81      X1 = (X1 - X1.mean()) / X1.std()
 82      X2 = (X2 - X2.mean()) / X2.std()
 83      X_train1[pos] = X1
 84      X_train2[pos] = X2
 85      y_train[pos] = np.load(onera_train_target + train['change_mask'][index])
 86      pos += 1
 87  
 88  # Ensure target labels have the same shape as model output
 89  y_train = np.expand_dims(y_train, axis=1)
 90  
 91  # Load test data
 92  X_test1 = np.ndarray(shape=(len(test), 96, 96, n_ch))
 93  X_test2 = np.ndarray(shape=(len(test), 96, 96, n_ch))
 94  y_test = np.ndarray(shape=(len(test), 96, 96))
 95  
 96  pos = 0
 97  for index in test.index:
 98      img1 = np.load(onera_test_target + test['pair1'][index])
 99      img2 = np.load(onera_test_target + test['pair2'][index])
100      X1 = create_rgb_onera(img1, channel)
101      X2 = create_rgb_onera(img2, channel)
102      X1 = (X1 - X1.mean()) / X1.std()
103      X2 = (X2 - X2.mean()) / X2.std()
104      X_test1[pos] = X1
105      X_test2[pos] = X2
106      y_test[pos] = np.load(onera_test_target + test['change_mask'][index])
107      pos += 1
108  
109  # Ensure target labels have the same shape as model output
110  y_test = np.expand_dims(y_test, axis=1)
111  
# Create DataLoaders. Images go from NHWC numpy arrays to NCHW float32
# tensors, the channel-first layout PyTorch modules conventionally expect.
# Masks keep their (N, 1, H, W) layout and are only cast to float32.


def _nchw(arr):
    """Convert an (N, H, W, C) array into an (N, C, H, W) float32 tensor."""
    return torch.tensor(arr).permute(0, 3, 1, 2).float()


train_data = TensorDataset(
    _nchw(X_train1),
    _nchw(X_train2),
    torch.tensor(y_train).float(),
)
test_data = TensorDataset(
    _nchw(X_test1),
    _nchw(X_test2),
    torch.tensor(y_test).float(),
)

train_loader = DataLoader(train_data, batch_size=16, shuffle=True)
test_loader = DataLoader(test_data, batch_size=16, shuffle=False)
122  
# Define hook function
def print_layer_shape(module, input, output):
    """Forward hook that prints a module's class name and its I/O shapes.

    Intended for ``nn.Module.register_forward_hook``. PyTorch calls it with
    the module, the tuple of positional arguments given to ``forward``, and
    the value ``forward`` returned.

    The input shape is that of the first positional argument. The output
    shape is that of the returned tensor itself; previously ``output[0]``
    was used, which for a tensor is the first sample and dropped the batch
    dimension. If ``forward`` returned a tuple/list, the first element's
    shape is reported. ``None`` is printed when there is nothing with a
    ``.shape`` to report.
    """
    def _first_shape(value):
        # Unwrap one level of tuple/list (hook inputs are a tuple; outputs
        # may or may not be) and return ``.shape`` if the object has one.
        if isinstance(value, (tuple, list)):
            value = value[0] if value else None
        return getattr(value, 'shape', None)

    print(f'{module.__class__.__name__}:')
    print(f'    Input shape: {_first_shape(input)}')
    print(f'    Output shape: {_first_shape(output)}')
128  
# Register hooks
def register_hooks(model):
    """Attach ``print_layer_shape`` as a forward hook on every submodule.

    Walks ``model``'s children depth-first (pre-order) and recurses into
    each one, so containers and leaves at every nesting level are hooked;
    ``model`` itself is not. The hook handles returned by PyTorch are
    discarded, so this function offers no way to remove the hooks again.
    """
    for submodule in model.children():
        submodule.register_forward_hook(print_layer_shape)
        # Recursing into a leaf just iterates zero children, so no
        # "has grandchildren" check is needed before descending.
        register_hooks(submodule)
135  
136  
137  
# Model Training and Evaluation

# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

model = CDNet_L().to(device)
# nn.BCELoss expects the model to output probabilities in [0, 1].
# NOTE(review): confirm CDNet_L's final layer produces that range.
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Training configuration and checkpoint naming
num_epochs = 30

model_path = '/home/dvalsamis/Documents/projects/Change_detection_SSL_Siamese/saved_models/'

# Short random suffix (4 hex chars) to tell runs' checkpoint files apart;
# collisions are unlikely but possible.
model_id = uuid.uuid4().hex[:4]
cd_model_name = f"Gas_Net_CBMI_{model_id}.h5"
154  
# Train for num_epochs full passes over train_loader and report, per epoch,
# the mean of the per-batch losses.
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for inputs1, inputs2, labels in train_loader:
        inputs1, inputs2, labels = inputs1.to(device), inputs2.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs1, inputs2)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    # Every batch counts equally, including a smaller final batch.
    # max() avoids a ZeroDivisionError on an empty loader.
    epoch_loss = running_loss / max(len(train_loader), 1)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")
171  
# Save the model after training (only the state_dict, i.e. the parameters).
# Create the target directory first, so a missing folder does not make the
# save fail after the whole training run has already finished.
# NOTE(review): torch.save writes PyTorch's own serialization format, not
# HDF5, despite the ".h5" suffix in cd_model_name.
os.makedirs(model_path, exist_ok=True)
save_path = os.path.join(model_path, cd_model_name)
torch.save(model.state_dict(), save_path)
print(f"Model saved to {save_path}")
176  
# Evaluate the model on the test set: threshold the outputs at 0.5 and
# collect flattened per-pixel labels and predictions as 1-D arrays.
model.eval()
_label_batches = []
_prediction_batches = []

with torch.no_grad():
    for inputs1, inputs2, labels in test_loader:
        # Labels are only needed on the CPU for scoring, so they are not
        # moved to the device and back.
        outputs = model(inputs1.to(device), inputs2.to(device))
        predicted = (outputs > 0.5).float()
        _label_batches.append(labels.cpu().numpy().ravel())
        _prediction_batches.append(predicted.cpu().numpy().ravel())

# Concatenate once at the end rather than extending Python lists element by
# element, which would allocate one Python object per pixel.
all_labels = (np.concatenate(_label_batches) if _label_batches
              else np.empty(0, dtype=np.float32))
all_predictions = (np.concatenate(_prediction_batches) if _prediction_batches
                   else np.empty(0, dtype=np.float32))
190  
# Convert to numpy arrays for metric calculations
all_labels = np.asarray(all_labels)
all_predictions = np.asarray(all_predictions)

# Binary confusion matrix (positive class 1 = change). Passing labels pins it
# to 2x2: without it, a split where only one class appears in both arrays
# (e.g. no changed pixels and none predicted) gives a 1x1 matrix, and the
# 4-way unpack raises ValueError. Float labels match the float values
# stored in the label/prediction arrays.
tn, fp, fn, tp = confusion_matrix(all_labels, all_predictions, labels=[0.0, 1.0]).ravel()


def _safe_div(numerator, denominator):
    """Return numerator / denominator as a float, or 0.0 if the denominator is 0.

    0.0 is the value scikit-learn's precision/recall/F1 fall back to for an
    undefined ratio, so the results agree with those functions.
    """
    return float(numerator) / float(denominator) if denominator else 0.0


# Calculate every metric from the same four counts, instead of re-scanning
# all pixels in recall_score / precision_score / f1_score.
accuracy = _safe_div(tp + tn, tp + tn + fp + fn)
recall = _safe_div(tp, tp + fn)
precision = _safe_div(tp, tp + fp)
f1 = _safe_div(2 * tp, 2 * tp + fp + fn)
specificity = _safe_div(tn, tn + fp)


print(f"Recall: {recall:.4f}")
print(f"Specificity: {specificity:.4f}")
print(f"Precision: {precision:.4f}")
print(f"F1 Score: {f1:.4f}")
print(f"Accuracy: {accuracy:.4f}")
211  
# Record the run's configuration and test metrics. The positional layout
# follows log_params_sim1 in utils.log_params (not visible here). The loss
# and optimizer names are taken from the objects actually used for
# training. Previously the loss was logged as a hard-coded
# 'weighted_categorical_crossentropy', although criterion is nn.BCELoss.
# NOTE(review): the "Softmax" activation entry is not verified against
# CDNet_L's output layer; confirm.
log_params_sim1("Task 1", " ", ' ', " ", " ", "Softmax", " ", type(optimizer).__name__, num_epochs, type(criterion).__name__, " ", recall, specificity, precision, f1, accuracy, "CBMI Set", 96, " ", "none", cd_model_name)
213