# preds.py
import os
import sys

# Make the project-local `gasnet` package importable.
sys.path.append('/home/dvalsamis/Documents/projects/Change_detection_SSL_Siamese')

# Restrict the visible GPUs before torch (or anything that imports it) is
# loaded, so the setting is in place before CUDA can be initialised.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader, TensorDataset
from gasnet.CDNet_L import CDNet_L
import matplotlib.pyplot as plt


def create_rgb_onera(x, channel):
    """Select spectral bands from an image array indexed as ``x[:, :, band]``.

    Band index 0 is used as blue, 1 as green, 2 as red and 3 as VNIR.

    Args:
        x: Array with the band axis last.
        channel: One of ``'red'``, ``'green'``, ``'blue'``, ``'rgb'`` or
            ``'rgbvnir'``.

    Returns:
        An array with the band axis last holding the requested band(s):
        one band for the single-colour options, three for ``'rgb'`` (ordered
        R, G, B) and four for ``'rgbvnir'`` (ordered R, G, B, VNIR, cast to
        float). For any other ``channel`` a message is printed and ``x`` is
        returned unchanged.
    """
    if channel == 'red':
        return np.expand_dims(x[:, :, 2], axis=2)
    if channel == 'green':
        return np.expand_dims(x[:, :, 1], axis=2)
    if channel == 'blue':
        return np.expand_dims(x[:, :, 0], axis=2)
    if channel == 'rgb':
        return np.dstack((x[:, :, 2], x[:, :, 1], x[:, :, 0]))
    if channel == 'rgbvnir':
        bands = (x[:, :, 2], x[:, :, 1], x[:, :, 0], x[:, :, 3])
        return np.stack(bands, axis=2).astype('float')

    # Unknown selector: keep the original best-effort behaviour.
    print("NOT CORRECT CHANNELS")
    return x


def _standardize(img):
    """Return ``img`` shifted to zero mean and scaled to unit variance.

    A constant image has zero standard deviation; dividing by it would fill
    the result with NaN/inf, so in that case only the mean is removed.
    """
    centered = img - img.mean()
    std = img.std()
    if std == 0:
        return centered
    return centered / std


# Data Loading and Preparation

onera_test_target = '/data/valsamis_data/data/CBMI/CBMI_0.3/CBMI_0.3/NPY_dataset/aug_test_data/'

test = pd.read_csv(onera_test_target + "dataset_test.csv")
# Deterministic shuffle, then keep at most 20 pairs.
test = test.sample(frac=1, random_state=1).head(20)
print("Test Data", len(test))

n_ch = 3
channel = 'rgb'

# Load test data (every slot is overwritten in the loop below).
X_test1 = np.empty((len(test), 96, 96, n_ch))
X_test2 = np.empty((len(test), 96, 96, n_ch))
y_test = np.empty((len(test), 96, 96))

for pos, index in enumerate(test.index):
    img1 = np.load(onera_test_target + test.loc[index, 'pair1'])
    img2 = np.load(onera_test_target + test.loc[index, 'pair2'])
    X_test1[pos] = _standardize(create_rgb_onera(img1, channel))
    X_test2[pos] = _standardize(create_rgb_onera(img2, channel))
    y_test[pos] = np.load(onera_test_target + test.loc[index, 'change_mask'])
# Ensure target labels have the same shape as model output
# (insert a channel axis at position 1).
y_test = np.expand_dims(y_test, axis=1)

# Create DataLoader; inputs are permuted from channels-last to channels-first.
test_data = TensorDataset(
    torch.tensor(X_test1).permute(0, 3, 1, 2).float(),
    torch.tensor(X_test2).permute(0, 3, 1, 2).float(),
    torch.tensor(y_test).float(),
)
test_loader = DataLoader(test_data, batch_size=16, shuffle=False)

# Model Evaluation

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CDNet_L().to(device)

# Load the trained model. map_location lets a checkpoint that was saved on a
# GPU be restored when this script ends up running on the CPU.
model_path = '/home/dvalsamis/Documents/projects/Change_detection_SSL_Siamese/saved_models/FCSiamDiff_CBMI_6acc.h5'
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()

# Directory to save predictions
predictions_path = '/data/valsamis_data/data/CBMI/CBMI_0.3/Predictions/Depth_2/fcsiamDiff'
os.makedirs(predictions_path, exist_ok=True)

all_predictions = []


def _to_displayable(chw_tensor):
    """Return a channels-last array min-max scaled to [0, 1] for ``imshow``.

    The network inputs are standardised, so they contain negative values and
    values above 1; ``imshow`` clips float RGB data to [0, 1], which would
    discard most of the image. Rescaling affects the PNG preview only.
    """
    img = chw_tensor.cpu().permute(1, 2, 0).numpy()
    lo, hi = img.min(), img.max()
    if hi > lo:
        return (img - lo) / (hi - lo)
    return np.zeros_like(img)


with torch.no_grad():
    for i, (inputs1, inputs2, true_labels) in enumerate(test_loader):
        inputs1, inputs2 = inputs1.to(device), inputs2.to(device)
        outputs = model(inputs1, inputs2)
        # Binarise the model output at 0.5.
        predicted = (outputs > 0.5).float().cpu().numpy()

        for j in range(predicted.shape[0]):
            # Global sample index; valid because the loader is not shuffled.
            sample_idx = i * test_loader.batch_size + j

            pred_filename = os.path.join(predictions_path, f'prediction_{sample_idx}.npy')
            np.save(pred_filename, predicted[j])
            all_predictions.append(
                (inputs1[j].cpu().numpy(), inputs2[j].cpu().numpy(), predicted[j])
            )

            # Visualization and saving as .png
            fig, axes = plt.subplots(1, 4, figsize=(20, 5))

            axes[0].imshow(_to_displayable(inputs1[j]))
            axes[0].set_title('Input Image 1')

            axes[1].imshow(_to_displayable(inputs2[j]))
            axes[1].set_title('Input Image 2')

            axes[2].imshow(predicted[j][0], cmap='gray')
            axes[2].set_title('Predicted Change Mask')

            axes[3].imshow(true_labels[j][0].cpu().numpy(), cmap='gray')
            axes[3].set_title('True Change Mask')

            png_filename = os.path.join(predictions_path, f'prediction_{sample_idx}.png')
            fig.savefig(png_filename)
            # Close by handle so figures do not accumulate across samples.
            plt.close(fig)

print(f"Predictions saved to {predictions_path}")