/ FCSiam / fcsiamConc_preds.py
fcsiamConc_preds.py
  1  import sys
  2  sys.path.append('/home/dvalsamis/Documents/projects/Change_detection_SSL_Siamese')
  3  
  4  import torch
  5  import torch.nn as nn
  6  import numpy as np
  7  import pandas as pd
  8  from torch.utils.data import DataLoader, TensorDataset
  9  import os
 10  import matplotlib.pyplot as plt
 11  
 12  from FCSiam.fcsiamConc import FCSiamConc
 13  
 14  os.environ["CUDA_VISIBLE_DEVICES"] = "0"
 15  
 16  def create_rgb_onera(x, channel):
 17      if channel == 'red':
 18          r = x[:, :, 2]
 19          r = np.expand_dims(r, axis=2)
 20          return r
 21      if channel == 'green':
 22          g = x[:, :, 1]
 23          g = np.expand_dims(g, axis=2)
 24          return g
 25      if channel == 'blue':
 26          b = x[:, :, 0]
 27          b = np.expand_dims(b, axis=2)
 28          return b
 29      if channel == 'rgb':
 30          r = x[:, :, 2]
 31          g = x[:, :, 1]
 32          b = x[:, :, 0]
 33          rgb = np.dstack((r, g, b))
 34          return rgb
 35      if channel == 'rgbvnir':
 36          r = x[:, :, 2]
 37          g = x[:, :, 1]
 38          b = x[:, :, 0]
 39          vnir = x[:, :, 3]
 40          rgbvnir = np.stack((r, g, b, vnir), axis=2).astype('float')
 41          return rgbvnir
 42      else:
 43          print("NOT CORRECT CHANNELS")
 44          return x
 45  
 46  # Data Loading and Preparation
 47  
 48  onera_test_target = '/data/valsamis_data/data/CBMI/CBMI_0.3/CBMI_0.3/NPY_dataset/aug_test_data/'
 49  
 50  test = pd.read_csv(onera_test_target + "dataset_test.csv")
 51  test = test.sample(frac=1, random_state=1).head(20)  # Select only 20 examples
 52  print("Test Data", len(test))
 53  
 54  n_ch = 3
 55  channel = 'rgb'  
 56  
 57  # Load test data
 58  X_test1 = np.ndarray(shape=(len(test), 96, 96, n_ch))
 59  X_test2 = np.ndarray(shape=(len(test), 96, 96, n_ch))
 60  y_test = np.ndarray(shape=(len(test), 96, 96))
 61  
 62  pos = 0
 63  for index in test.index:
 64      img1 = np.load(onera_test_target + test['pair1'][index])
 65      img2 = np.load(onera_test_target + test['pair2'][index])
 66      X1 = create_rgb_onera(img1, channel)
 67      X2 = create_rgb_onera(img2, channel)
 68      X1 = (X1 - X1.mean()) / X1.std()
 69      X2 = (X2 - X2.mean()) / X2.std()
 70      X_test1[pos] = X1
 71      X_test2[pos] = X2
 72      y_test[pos] = np.load(onera_test_target + test['change_mask'][index])
 73      pos += 1
 74  
 75  # Ensure target labels have the same shape as model output
 76  y_test = np.expand_dims(y_test, axis=1)
 77  
 78  # Permute the inputs to match the expected shape
 79  X_test1 = torch.tensor(X_test1).permute(0, 3, 1, 2)  # Convert to [batch_size, channels, height, width]
 80  X_test2 = torch.tensor(X_test2).permute(0, 3, 1, 2)  # Convert to [batch_size, channels, height, width]
 81  X_test = torch.stack((X_test1, X_test2), dim=1)  # Stack along the second dimension to get [batch_size, 2, channels, height, width]
 82  print("Shape of combined inputs:", X_test.shape) 
 83  
 84  
 85  # Create DataLoader
 86  test_data = TensorDataset(X_test.float(), torch.tensor(y_test).float())
 87  test_loader = DataLoader(test_data, batch_size=16, shuffle=False)
 88  
 89  # Model Evaluation
 90  
 91  device = torch.device("cpu")
 92  model = FCSiamConc().to(device)
 93  
 94  # Load the trained model
 95  model_path = '/home/dvalsamis/Documents/projects/Change_detection_SSL_Siamese/saved_models/FCSiamConc_CBMI_ac4f.h5'
 96  model.load_state_dict(torch.load(model_path, map_location=device))
 97  model.eval()
 98  
 99  # Directory to save predictions
100  predictions_path = '/data/valsamis_data/data/CBMI/CBMI_0.3/Predictions/Depth_2/siamConc'
101  os.makedirs(predictions_path, exist_ok=True)
102  
103  all_predictions = []
104  
105  with torch.no_grad():
106      for i, (inputs, true_labels) in enumerate(test_loader):
107          inputs = inputs.to(device)
108          outputs = model(inputs)
109  
110          # Assuming outputs are already squeezed to remove the batch and channel dimensions if singular
111          output_mask = outputs.squeeze().cpu().numpy()
112  
113          plt.figure(figsize=(6, 6))
114          plt.imshow(output_mask, cmap='gray')  # Use an appropriate colormap if needed
115          plt.title('Output Mask')
116          plt.colorbar()
117          plt.show()
118  
119          predicted = (outputs > 0.5).float().cpu().numpy()
120  
121          for j in range(predicted.shape[0]):
122              pred_filename = os.path.join(predictions_path, f'prediction_{i*test_loader.batch_size + j}.npy')
123              np.save(pred_filename, predicted[j])
124              all_predictions.append((inputs[j, 0].cpu().numpy(), inputs[j, 1].cpu().numpy(), predicted[j]))
125  
126              # Visualization and saving as .png
127              plt.figure(figsize=(20, 5))
128  
129              plt.subplot(1, 4, 1)
130              plt.imshow(inputs[j, 0].cpu().permute(1, 2, 0).numpy())
131              plt.title('Input Image 1')
132  
133              plt.subplot(1, 4, 2)
134              plt.imshow(inputs[j, 1].cpu().permute(1, 2, 0).numpy())
135              plt.title('Input Image 2')
136  
137              plt.subplot(1, 4, 3)
138              plt.imshow(predicted[j][0], cmap='gray')
139              plt.title('Predicted Change Mask')
140  
141              plt.subplot(1, 4, 4)
142              plt.imshow(true_labels[j][0].cpu().numpy(), cmap='gray')
143              plt.title('True Change Mask')
144  
145              png_filename = os.path.join(predictions_path, f'prediction_{i*test_loader.batch_size + j}.png')
146              plt.savefig(png_filename)
147              plt.close()
148  
149  print(f"Predictions saved to {predictions_path}")