# training/downstream_tasks/simple_cd_k_fold.py
  1  #!/usr/bin/env python3
  2  # -*- coding: utf-8 -*-
  3  """
  4  Created on Wed Mar  9 20:15:48 2022
  5  
  6  @author: aleoikon
  7  """
  8  
  9  import sys
 10  sys.path.append('/home/dvalsamis/Documents/projects/Change_detection_SSL_Siamese')
 11  import time
 12  
 13  from architectures.similarity_detection import pretext_task_one_nopool, pretext_task_one_aspp
 14  import tensorflow
 15  from tensorflow.keras import layers, Model
 16  from architectures.branch import branches_triplet
 17  from tensorflow.keras.optimizers import Adam
 18  from tensorflow.keras.utils import plot_model
 19  
 20  import pandas as pd
 21  import numpy as np
 22  import os
 23  from tensorflow import keras
 24  from architectures.conv_classifier import conv_classifier_two, conv_classifier_two_with_nspp, conv_classifier_two_with_aspp
 25  from utils.layer_select import feature_selector, feature_selector_simple, transfer_learning_model
 26  from utils.my_metrics import recall, accuracy, specificity, precision, f_measure, get_confusion_matrix
 27  from utils.log_params import log_params_sim1
 28  
 29  import matplotlib
 30  matplotlib.use('TkAgg')
 31  import matplotlib.pyplot as plt
 32  import uuid
 33  import random
 34  from utils.weighted_cross_entropy import weighted_categorical_crossentropy
 35  from sklearn.model_selection import KFold
 36  
# Pin this job to GPU index 1 (must be set before TensorFlow initializes CUDA).
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# Band selection for the input patches; accepted values are handled by
# create_rgb_onera() below ('red'/'green'/'blue'/'rgb'/'rgbvnir').
channel = 'rgb'

# Set random seed for TensorFlow
tensorflow.random.set_seed(1234)

# Set random seed for NumPy
np.random.seed(1234)
 46  
 47  
 48  def generate_short_id():
 49      # Generate a UUID
 50      unique_id = uuid.uuid4()
 51  
 52      # Convert UUID to a hex string and take the first 4 characters
 53      short_id = str(unique_id.hex)[:4]
 54  
 55      return short_id
 56  
 57  def create_rgb_onera(x, channel):
 58      if channel == 'red':
 59          r = x[:,:,2]
 60          r = np.expand_dims(r, axis=2)
 61          return r
 62      if channel == 'green':
 63          g = x[:,:,1]
 64          g = np.expand_dims(g, axis=2)
 65          return g
 66      if channel == 'blue':
 67          b = x[:,:,0]
 68          b = np.expand_dims(b, axis=2)
 69          return b
 70      if channel == 'rgb':
 71          r = x[:,:,2]
 72          g = x[:,:,1]
 73          b = x[:,:,0]
 74          rgb = np.dstack((r, g, b))
 75          return rgb
 76      if channel == 'rgbvnir':
 77          r = x[:,:,2]
 78          g = x[:,:,1]
 79          b = x[:,:,0]
 80          vnir = x[:,:,3]
 81          rgbvnir = np.stack((r, g, b, vnir), axis=2).astype('float')
 82          return rgbvnir
 83      else:
 84          return x
 85          print("NOT CORRECT CHANNELS")
 86  
 87  def generate_short_id():
 88      # Generate a UUID
 89      unique_id = uuid.uuid4()
 90  
 91      # Convert UUID to a hex string and take the first 4 characters
 92      short_id = str(unique_id.hex)[:4]
 93  
 94      return short_id
 95  
# Data Loading and Preparation

# Directories holding the augmented .npy patches plus the index CSVs that
# name each before/after pair and its change mask.
onera_train_target =  '/data/valsamis_data/data/CBMI/CBMI_0.3/CBMI_0.3/NPY_dataset/aug_train_data/'
onera_test_target = '/data/valsamis_data/data/CBMI/CBMI_0.3/CBMI_0.3/NPY_dataset/aug_test_data/'

train = pd.read_csv(onera_train_target + "dataset_train.csv")
test = pd.read_csv(onera_test_target + "dataset_test.csv")

# Shuffle row order reproducibly (frac=1 keeps all rows).
train = train.sample(frac=1, random_state=1)
test = test.sample(frac=1, random_state=1)
print("Train Data", len(train))
print("Test Data", len(test))

# NORM is only passed through to log_params_sim1 for bookkeeping; the
# actual per-patch standardisation below is unconditional.
NORM = True
# Number of input channels; must match `channel` above (3 for 'rgb').
n_ch = 3
111  
# Materialise every training pair (and its change mask) in memory up front.
X_train1 = np.ndarray(shape=(len(train), 96, 96, n_ch))
X_train2 = np.ndarray(shape=(len(train), 96, 96, n_ch))
y_train = np.ndarray(shape=(len(train), 96, 96))

for pos, index in enumerate(train.index):
    before = np.load(onera_train_target + train['pair1'][index])
    after = np.load(onera_train_target + train['pair2'][index])
    before = create_rgb_onera(before, channel)
    after = create_rgb_onera(after, channel)
    # Per-patch standardisation: zero mean, unit variance.
    X_train1[pos] = (before - before.mean()) / before.std()
    X_train2[pos] = (after - after.mean()) / after.std()
    y_train[pos] = np.load(onera_train_target + train['change_mask'][index])

# Class balance check: ratio of unchanged (0) to changed (1) pixels.
labels, counts = np.unique(y_train.flatten(), return_counts=True)
print(counts[0] / counts[1])

# One-hot encode the training labels (2 classes: no-change / change).
y_hot_train = keras.utils.to_categorical(y_train, num_classes=2)
139  
# Materialise every test pair (and its change mask) in memory up front.
X_test1 = np.ndarray(shape=(len(test), 96, 96, n_ch))
X_test2 = np.ndarray(shape=(len(test), 96, 96, n_ch))
y_test = np.ndarray(shape=(len(test), 96, 96))

for pos, index in enumerate(test.index):
    before = np.load(onera_test_target + test['pair1'][index])
    after = np.load(onera_test_target + test['pair2'][index])
    before = create_rgb_onera(before, channel)
    after = create_rgb_onera(after, channel)
    # Per-patch standardisation: zero mean, unit variance.
    X_test1[pos] = (before - before.mean()) / before.std()
    X_test2[pos] = (after - after.mean()) / after.std()
    y_test[pos] = np.load(onera_test_target + test['change_mask'][index])
156  
# Combine X and y for k-fold
# NOTE(review): pooling the original train AND test sets into one CV pool
# means the held-out test set is no longer independent of training —
# confirm this is intended for the reported cross-validation figures.
X1 = np.concatenate((X_train1, X_test1), axis=0)
X2 = np.concatenate((X_train2, X_test2), axis=0)
y = np.concatenate((y_train, y_test), axis=0)

# One-hot encode the combined labels
y_hot = keras.utils.to_categorical(y, num_classes=2)

# Initialize KFold
kf = KFold(n_splits=5, shuffle=True, random_state=1)

# Hyperparameters for the downstream change-detection classifier.
depth = 2
dropout = 0.1
decay = 0.0001
LEARNING_RATE = 0.001
EPOCHS = 55
BATCH_SIZE=5


# Running fold counter (1-based) and per-fold metric records.
fold_no = 1
results = []
178  
# 5-fold cross-validation: each fold builds a fresh classifier, transfers the
# frozen/selected SSL pretext features into it, trains, and records metrics.
for train_index, val_index in kf.split(X1):
    print(f'Training fold {fold_no}...')

    # Split the data
    X_train1_fold, X_val1_fold = X1[train_index], X1[val_index]
    X_train2_fold, X_val2_fold = X2[train_index], X2[val_index]
    y_train_fold, y_val_fold = y_hot[train_index], y_hot[val_index]

    # Create a new instance of the model
    cd_model = conv_classifier_two_with_aspp(depth, dropout, decay, 96, 96, n_ch)

    # Load pretext model weights
    sim_model = pretext_task_one_nopool(dropout, decay, 96, 96, n_ch)
    pretext_model_name = '/home/dvalsamis/Documents/projects/Change_detection_SSL_Siamese/saved_models/model_pretext1_unclouded_results.h5'
    pretext_model = 'model_pretext1_unclouded_results'
    sim_model.load_weights(pretext_model_name)

    # Feature selection(task1)
    # Copies the first `depth` pretext layers into cd_model
    # (presumably — see utils.layer_select.feature_selector_simple; verify).
    cd_model = feature_selector_simple(depth, sim_model, cd_model)

    # Class weights for the loss: wx for class 0 (no-change), wy for class 1
    # (change); change pixels are weighted 2x to counter class imbalance.
    wx = 0.1
    wy = 0.2
    weights = np.array([wx, wy])

    optimizer = Adam(learning_rate=LEARNING_RATE)
    cd_model.compile(optimizer=optimizer, loss=weighted_categorical_crossentropy(weights), metrics=['accuracy'])

    # Train the model
    history = cd_model.fit(
        [X_train1_fold, X_train2_fold],
        y_train_fold,
        validation_data=([X_val1_fold, X_val2_fold], y_val_fold),
        batch_size=BATCH_SIZE,
        epochs=EPOCHS
    )

    # Evaluate the model
    predictions = cd_model.predict([X_val1_fold, X_val2_fold])
    # Collapse the 2-class softmax axis (N, 96, 96, 2) -> (N, 96, 96) label maps.
    y_pred = np.argmax(predictions, axis=3)
    y_true = np.argmax(y_val_fold, axis=3)

    # Pixel-wise metrics from utils.my_metrics (assumed to return scalars
    # over the whole fold — TODO confirm against their implementation).
    acc = accuracy(y_true, y_pred)
    spec = specificity(y_true, y_pred)
    rec = recall(y_true, y_pred)
    prec = precision(y_true, y_pred)
    f_m = f_measure(y_true, y_pred)

    results.append({
        'fold': fold_no,
        'accuracy': acc,
        'specificity': spec,
        'recall': rec,
        'precision': prec,
        'f1_score': f_m
    })

    # Persist this fold's run configuration + metrics to the experiment log.
    weight_par = '[' + str(wx) + ',' + str(wy) + ']'
    log_params_sim1("Task 1 (T)", "Linear", 'ASPP', weight_par, depth, "Softmax", LEARNING_RATE, 'Adam', EPOCHS, 'weighted_categorical_crossentropy', BATCH_SIZE, rec, spec, prec, f_m, acc, "CBMI Set", 96, NORM, pretext_model, "CD_Simple_Sysu_d8d2")

    fold_no += 1
239  
# Print per-fold results collected during cross-validation.
for result in results:
    print(f"Fold {result['fold']}:")
    print(f"  Accuracy: {result['accuracy']:.4f}")
    print(f"  Specificity: {result['specificity']:.4f}")
    print(f"  Recall: {result['recall']:.4f}")
    print(f"  Precision: {result['precision']:.4f}")
    print(f"  F1 Score: {result['f1_score']:.4f}")

# Save the model after training on all folds
# NOTE(review): `cd_model` at this point is the model from the LAST fold
# only (each fold rebuilt it), despite the comment above — confirm whether
# saving only the final fold's weights is intended.
model_id = generate_short_id()
cd_model_name = "CD_Simple_" + "CBMI_" + model_id + ".h5"
model_path = '/home/dvalsamis/Documents/projects/Change_detection_SSL_Siamese/saved_models/'



cd_model.save_weights(model_path + cd_model_name)
# NOTE(review): this re-logs the LAST fold's metrics (rec/spec/prec/f_m/acc
# still hold fold-5 values), duplicating the in-loop log_params_sim1 call.
weight_par = '[' + str(wx) + ',' + str(wy) + ']'
log_params_sim1("Task 1 (T)", "Linear", 'ASPP', weight_par, depth, "Softmax", LEARNING_RATE, 'Adam', EPOCHS, 'weighted_categorical_crossentropy', BATCH_SIZE, rec, spec, prec, f_m, acc, "CBMI Set", 96, NORM, pretext_model, "CD_Simple_Sysu_d8d2")

print("Saved model to disk")


print("Done")