"""fcsiamConc_preds.py

Run a trained FCSiamConc change-detection model on a fixed random sample of
the CBMI test split.

For every sample the script writes two files into ``predictions_path``:

* ``prediction_<k>.npy``: the change mask after thresholding the model output.
* ``prediction_<k>.png``: a figure with both input images, the raw model
  output, the thresholded mask and the ground-truth mask.

``<k>`` is the sample's position in the (unshuffled) evaluation order.
"""
import os
import sys

# Expose only the first GPU. Set before torch is imported so CUDA sees it.
# Inference below runs on the CPU regardless.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
sys.path.append('/home/dvalsamis/Documents/projects/Change_detection_SSL_Siamese')

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn  # noqa: F401  (unused here; kept from the original import list)
from torch.utils.data import DataLoader, TensorDataset

from FCSiam.fcsiamConc import FCSiamConc

# Last-axis band indices for each supported channel selection, in output order.
# The mapping implies the .npy patches store bands as (blue, green, red, vnir).
# TODO: confirm against the dataset export.
_CHANNEL_INDICES = {
    'red': [2],
    'green': [1],
    'blue': [0],
    'rgb': [2, 1, 0],
    'rgbvnir': [2, 1, 0, 3],
}


def create_rgb_onera(x, channel):
    """Select and reorder the bands of a patch indexed as ``x[:, :, band]``.

    Args:
        x: image array whose last axis holds the bands.
        channel: one of the following selections.
            ``'red'``, ``'green'`` or ``'blue'`` gives an (H, W, 1) array.
            ``'rgb'`` gives an (H, W, 3) array in R, G, B order.
            ``'rgbvnir'`` gives an (H, W, 4) array in R, G, B, VNIR order,
            cast to float64.

    Returns:
        The selected bands. For any other ``channel`` a warning is printed
        and ``x`` is returned unchanged.
    """
    indices = _CHANNEL_INDICES.get(channel)
    if indices is None:
        print("NOT CORRECT CHANNELS")
        return x
    selected = x[:, :, indices]
    # Only the 4-band selection is cast to float64; the others keep x's dtype.
    return selected.astype('float') if channel == 'rgbvnir' else selected


def _standardize(x):
    """Return ``x`` with zero mean and unit std, pooling all pixels and bands.

    A constant patch has zero std. Dividing by it would turn the whole sample
    into NaNs, so such a patch is only mean-centred.
    """
    std = x.std()
    return (x - x.mean()) / (std if std > 0 else 1.0)


def _to_display(chw):
    """Turn a (C, H, W) standardized image tensor into an imshow-ready array.

    z-scored values fall outside [0, 1], which imshow would silently clip.
    The image is therefore min-max rescaled per image. Only the first three
    bands are kept, because imshow reads a 4th band as alpha. A single-band
    image is returned as (H, W).
    """
    img = chw.cpu().permute(1, 2, 0).numpy()[..., :3]
    lo, hi = img.min(), img.max()
    img = (img - lo) / (hi - lo) if hi > lo else np.zeros_like(img)
    return img[..., 0] if img.shape[-1] == 1 else img


# ---------------------------------------------------------------------------
# Data loading and preparation
# ---------------------------------------------------------------------------

onera_test_target = '/data/valsamis_data/data/CBMI/CBMI_0.3/CBMI_0.3/NPY_dataset/aug_test_data/'
NUM_SAMPLES = 20
PATCH_SIZE = 96

test = pd.read_csv(onera_test_target + "dataset_test.csv")
# Shuffle with a fixed seed, then keep the first NUM_SAMPLES rows.
test = test.sample(frac=1, random_state=1).head(NUM_SAMPLES)
print("Test Data", len(test))

channel = 'rgb'
# Derived from `channel` so the buffers below always match create_rgb_onera's output.
n_ch = len(_CHANNEL_INDICES[channel])

X_test1 = np.empty((len(test), PATCH_SIZE, PATCH_SIZE, n_ch), dtype=np.float32)
X_test2 = np.empty_like(X_test1)
y_test = np.empty((len(test), PATCH_SIZE, PATCH_SIZE), dtype=np.float32)

# The CSV columns pair1 / pair2 / change_mask hold file names relative to onera_test_target.
for pos, row in enumerate(test.itertuples(index=False)):
    X_test1[pos] = _standardize(create_rgb_onera(np.load(onera_test_target + row.pair1), channel))
    X_test2[pos] = _standardize(create_rgb_onera(np.load(onera_test_target + row.pair2), channel))
    y_test[pos] = np.load(onera_test_target + row.change_mask)

# Add a channel axis: (N, H, W) -> (N, 1, H, W).
# This assumes the model output is (N, 1, H, W). TODO: confirm.
y_test = np.expand_dims(y_test, axis=1)

# (N, H, W, C) -> (N, C, H, W), then pair the two dates as (N, 2, C, H, W).
X_test1 = torch.from_numpy(X_test1).permute(0, 3, 1, 2)
X_test2 = torch.from_numpy(X_test2).permute(0, 3, 1, 2)
X_test = torch.stack((X_test1, X_test2), dim=1)
print("Shape of combined inputs:", X_test.shape)

BATCH_SIZE = 16
test_data = TensorDataset(X_test.float(), torch.from_numpy(y_test).float())
# shuffle=False keeps the loader order equal to the row order of `test`.
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

# ---------------------------------------------------------------------------
# Model evaluation
# ---------------------------------------------------------------------------

device = torch.device("cpu")
model = FCSiamConc().to(device)

model_path = '/home/dvalsamis/Documents/projects/Change_detection_SSL_Siamese/saved_models/FCSiamConc_CBMI_ac4f.h5'
# Despite the .h5 suffix, the file is loaded with torch.load as a state_dict.
# torch.load unpickles its input, so only use checkpoints from a trusted source.
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()

predictions_path = '/data/valsamis_data/data/CBMI/CBMI_0.3/Predictions/Depth_2/siamConc'
os.makedirs(predictions_path, exist_ok=True)

# Assumes FCSiamConc ends in a sigmoid, so its outputs are probabilities.
# If it returns logits, the threshold should be 0. TODO: confirm.
THRESHOLD = 0.5

all_predictions = []
sample_idx = 0

with torch.no_grad():
    for inputs, true_labels in test_loader:
        inputs = inputs.to(device)
        raw = model(inputs).cpu().numpy()
        predicted = (raw > THRESHOLD).astype(np.float32)

        for j in range(predicted.shape[0]):
            stem = os.path.join(predictions_path, f'prediction_{sample_idx}')
            np.save(stem + '.npy', predicted[j])
            all_predictions.append((inputs[j, 0].cpu().numpy(), inputs[j, 1].cpu().numpy(), predicted[j]))

            fig, axes = plt.subplots(1, 5, figsize=(25, 5))

            axes[0].imshow(_to_display(inputs[j, 0]))
            axes[0].set_title('Input Image 1')

            axes[1].imshow(_to_display(inputs[j, 1]))
            axes[1].set_title('Input Image 2')

            # Raw output for this sample only. imshow cannot draw a whole
            # (batch, H, W) stack, and a blocking plt.show() per batch would
            # stall an unattended run, so it goes into the saved figure.
            raw_im = axes[2].imshow(raw[j][0], cmap='gray')
            axes[2].set_title('Output Mask')
            fig.colorbar(raw_im, ax=axes[2])

            # Fixed vmin/vmax so an all-zero mask is not auto-scaled.
            axes[3].imshow(predicted[j][0], cmap='gray', vmin=0, vmax=1)
            axes[3].set_title('Predicted Change Mask')

            axes[4].imshow(true_labels[j][0].cpu().numpy(), cmap='gray', vmin=0, vmax=1)
            axes[4].set_title('True Change Mask')

            fig.savefig(stem + '.png')
            plt.close(fig)  # release the figure; this loop creates one per sample
            sample_idx += 1

print(f"Predictions saved to {predictions_path}")