/ gasnet / preds.py
preds.py
  1  import sys
  2  sys.path.append('/home/dvalsamis/Documents/projects/Change_detection_SSL_Siamese')
  3  
  4  import torch
  5  import torch.nn as nn
  6  import numpy as np
  7  import pandas as pd
  8  from torch.utils.data import DataLoader, TensorDataset
  9  import os
 10  from gasnet.CDNet_L import CDNet_L
 11  import matplotlib.pyplot as plt
 12  
 13  
 14  
# Expose only GPU index 1 to this process.
# NOTE(review): this is assigned after `import torch`; it only takes effect if
# CUDA has not been initialized yet, so keep it ahead of any torch.cuda call.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
 16  
 17  def create_rgb_onera(x, channel):
 18      if channel == 'red':
 19          r = x[:, :, 2]
 20          r = np.expand_dims(r, axis=2)
 21          return r
 22      if channel == 'green':
 23          g = x[:, :, 1]
 24          g = np.expand_dims(g, axis=2)
 25          return g
 26      if channel == 'blue':
 27          b = x[:, :, 0]
 28          b = np.expand_dims(b, axis=2)
 29          return b
 30      if channel == 'rgb':
 31          r = x[:, :, 2]
 32          g = x[:, :, 1]
 33          b = x[:, :, 0]
 34          rgb = np.dstack((r, g, b))
 35          return rgb
 36      if channel == 'rgbvnir':
 37          r = x[:, :, 2]
 38          g = x[:, :, 1]
 39          b = x[:, :, 0]
 40          vnir = x[:, :, 3]
 41          rgbvnir = np.stack((r, g, b, vnir), axis=2).astype('float')
 42          return rgbvnir
 43      else:
 44          print("NOT CORRECT CHANNELS")
 45          return x
 46  
 47  # Data Loading and Preparation
 48  
 49  onera_test_target = '/data/valsamis_data/data/CBMI/CBMI_0.3/CBMI_0.3/NPY_dataset/aug_test_data/'
 50  
 51  test = pd.read_csv(onera_test_target + "dataset_test.csv")
 52  test = test.sample(frac=1, random_state=1).head(20)  
 53  print("Test Data", len(test))
 54  
 55  n_ch = 3
 56  channel = 'rgb'  
 57  
 58  # Load test data
 59  X_test1 = np.ndarray(shape=(len(test), 96, 96, n_ch))
 60  X_test2 = np.ndarray(shape=(len(test), 96, 96, n_ch))
 61  y_test = np.ndarray(shape=(len(test), 96, 96))
 62  
 63  pos = 0
 64  for index in test.index:
 65      img1 = np.load(onera_test_target + test['pair1'][index])
 66      img2 = np.load(onera_test_target + test['pair2'][index])
 67      X1 = create_rgb_onera(img1, channel)
 68      X2 = create_rgb_onera(img2, channel)
 69      X1 = (X1 - X1.mean()) / X1.std()
 70      X2 = (X2 - X2.mean()) / X2.std()
 71      X_test1[pos] = X1
 72      X_test2[pos] = X2
 73      y_test[pos] = np.load(onera_test_target + test['change_mask'][index])
 74      pos += 1
 75  
 76  # Ensure target labels have the same shape as model output
 77  y_test = np.expand_dims(y_test, axis=1)
 78  
 79  # Create DataLoader
 80  test_data = TensorDataset(torch.tensor(X_test1).permute(0, 3, 1, 2).float(),
 81                            torch.tensor(X_test2).permute(0, 3, 1, 2).float(),
 82                            torch.tensor(y_test).float())
 83  test_loader = DataLoader(test_data, batch_size=16, shuffle=False)
 84  
 85  # Model Evaluation
 86  
 87  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 88  model = CDNet_L().to(device)
 89  
 90  # Load the trained model
 91  model_path = '/home/dvalsamis/Documents/projects/Change_detection_SSL_Siamese/saved_models/FCSiamDiff_CBMI_6acc.h5'
 92  model.load_state_dict(torch.load(model_path))
 93  model.eval()
 94  
 95  # Directory to save predictions
 96  predictions_path = '/data/valsamis_data/data/CBMI/CBMI_0.3/Predictions/Depth_2/fcsiamDiff'
 97  os.makedirs(predictions_path, exist_ok=True)
 98  
 99  all_predictions = []
100  
101  with torch.no_grad():
102      for i, (inputs1, inputs2, true_labels) in enumerate(test_loader):
103          inputs1, inputs2 = inputs1.to(device), inputs2.to(device)
104          outputs = model(inputs1, inputs2)
105          predicted = (outputs > 0.5).float().cpu().numpy()
106  
107          for j in range(predicted.shape[0]):
108              pred_filename = os.path.join(predictions_path, f'prediction_{i*test_loader.batch_size + j}.npy')
109              np.save(pred_filename, predicted[j])
110              all_predictions.append((inputs1[j].cpu().numpy(), inputs2[j].cpu().numpy(), predicted[j]))
111  
112              # Visualization and saving as .png
113              plt.figure(figsize=(20, 5))
114  
115              plt.subplot(1, 4, 1)
116              plt.imshow(inputs1[j].cpu().permute(1, 2, 0).numpy())
117              plt.title('Input Image 1')
118  
119              plt.subplot(1, 4, 2)
120              plt.imshow(inputs2[j].cpu().permute(1, 2, 0).numpy())
121              plt.title('Input Image 2')
122  
123              plt.subplot(1, 4, 3)
124              plt.imshow(predicted[j][0], cmap='gray')
125              plt.title('Predicted Change Mask')
126  
127              plt.subplot(1, 4, 4)
128              plt.imshow(true_labels[j][0].cpu().numpy(), cmap='gray')
129              plt.title('True Change Mask')
130  
131              png_filename = os.path.join(predictions_path, f'prediction_{i*test_loader.batch_size + j}.png')
132              plt.savefig(png_filename)
133              plt.close()
134  
135  print(f"Predictions saved to {predictions_path}")