# training/pretext_tasks/pretext_task1_Grad.py
import sys
sys.path.append('/home/dvalsamis/Documents/projects/Change_detection_SSL_Siamese')

import os
import random
import time
import uuid
from datetime import date, datetime

import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from scipy.ndimage import zoom
from sklearn.metrics import confusion_matrix

from architectures.similarity_detection import pretext_task_one_nopool, pretext_one, pretext_task_one_aspp
from utils.log_params import log_params

os.environ["CUDA_VISIBLE_DEVICES"] = "1"

def feature_scaling(img, method):
    """Scale an image with standardization, min-max, or mean normalization."""
    I = img
    if method == "STAND":
        return (I - I.mean()) / I.std()
    if method == "MINMAX":
        return (I - np.nanmin(I)) / (np.nanmax(I) - np.nanmin(I))
    if method == "MEAN":
        return (I - I.mean()) / (np.nanmax(I) - np.nanmin(I))
    return I
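
# Quick illustrative check of the scaling modes on a toy array (the values
# here are made up for demonstration and are not part of the pipeline):
_demo = np.array([[1.0, 2.0], [3.0, 5.0]])
assert np.isclose(feature_scaling(_demo, "MINMAX").min(), 0.0)
assert np.isclose(feature_scaling(_demo, "MINMAX").max(), 1.0)
assert np.isclose(feature_scaling(_demo, "STAND").mean(), 0.0)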

def create_rgb(x, channel):
    """Select the requested bands from a multispectral patch."""
    if channel == 'red':
        return np.expand_dims(x[:, :, 1], axis=2)
    if channel == 'green':
        return np.expand_dims(x[:, :, 2], axis=2)
    if channel == 'blue':
        return np.expand_dims(x[:, :, 3], axis=2)
    if channel == 'rgb':
        return np.dstack((x[:, :, 1], x[:, :, 2], x[:, :, 3]))
    if channel == 'rgbvnir':
        return np.stack((x[:, :, 1], x[:, :, 2], x[:, :, 3], x[:, :, 8]), axis=2).astype('float')
    if channel == 'eq20':
        r = x[:, :, 1]
        s = r.shape
        ir1 = adjust_shape(zoom(x[:, :, 4], 2), s)
        ir2 = adjust_shape(zoom(x[:, :, 5], 2), s)
        ir3 = adjust_shape(zoom(x[:, :, 6], 2), s)
        nir2 = adjust_shape(zoom(x[:, :, 8], 2), s)
        swir2 = adjust_shape(zoom(x[:, :, 11], 2), s)
        swir3 = adjust_shape(zoom(x[:, :, 12], 2), s)
        return np.stack((ir1, ir2, ir3, nir2, swir2, swir3), axis=2).astype('float')
    print("NOT CORRECT CHANNELS")
    return x
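
# Note on 'eq20': the zoom-by-2 appears to resample the 20 m bands onto the
# 10 m grid of the reference band, with adjust_shape absorbing rounding
# differences (an interpretation of the band layout, not confirmed upstream).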

# A callback that stops training once validation loss drops below 0.080
class myCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        val_loss = (logs or {}).get('val_loss')
        if val_loss is not None and val_loss < 0.080:
            print("\nReached 0.080 validation loss so cancelling training!")
            self.model.stop_training = True
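
# A built-in alternative (sketch only, not wired into model.fit below): Keras'
# EarlyStopping stops once the monitored metric stops improving, which is a
# slightly different rule than the fixed threshold above.
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss', patience=2, restore_best_weights=True)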

def adjust_shape(I, s):
    """Adjust shape of grayscale image I to s."""
    # crop if necessary
    I = I[:s[0], :s[1]]
    si = I.shape

    # pad if necessary
    p0 = max(0, s[0] - si[0])
    p1 = max(0, s[1] - si[1])

    return np.pad(I, ((0, p0), (0, p1)), 'edge')
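
# Illustrative check (toy input, not part of the pipeline): a 3x3 array
# adjusted to (2, 4) is cropped to 2 rows and edge-padded to 4 columns.
assert adjust_shape(np.arange(9).reshape(3, 3), (2, 4)).shape == (2, 4)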


def generate_short_id():
    """Return the first 4 hex characters of a random UUID."""
    return uuid.uuid4().hex[:4]
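
# Note: 4 hex characters give only 16**4 = 65536 distinct ids, so collisions
# across many runs are possible; a longer slice would be safer if that matters.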

s2mtcp_target = '/data/aleoikon_data/change_detection/ssl/s2mtcp/patches/task1/'

#s2mtcp_target = '/home/aleoikon/Documents/data/ssl/s2mtcp/patches_colorshifted/task1/'

df = pd.read_csv(s2mtcp_target + 'dataset_unclouded.csv', dtype=str)
# 85% of the data is used for training; a third of the remaining 15%
# (roughly 5% of the total) becomes the test set, the rest (~10%) validation.
train = df.sample(frac=0.85, random_state=1)
validation = df.drop(train.index)
test = validation.sample(frac=0.33, random_state=1)
validation = validation.drop(test.index)

print("Data", len(df))
print("85% of Data = Train", len(train))
print("10% of Data = Validation", len(validation))
print("5% of Data = Test", len(test))

# Class balance of the validation split (overlap vs. non-overlap pairs)
val_balance = validation['overlap']
(unique, counts) = np.unique(val_balance, return_counts=True)
frequencies = np.asarray((unique, counts)).T
print(frequencies)

n_ch = 3
channel = 'rgb'
method = 'STAND'

def load_pairs(split):
    """Load all image pairs of a dataframe split, select bands, and scale."""
    X1 = np.ndarray(shape=(len(split), 96, 96, n_ch))
    X2 = np.ndarray(shape=(len(split), 96, 96, n_ch))
    y = np.ndarray(shape=(len(split), 1))
    for pos, index in enumerate(split.index):
        img1 = np.load(s2mtcp_target + split['pair1'][index])
        img2 = np.load(s2mtcp_target + split['pair2'][index])
        X1[pos] = feature_scaling(create_rgb(img1, channel), method)
        X2[pos] = feature_scaling(create_rgb(img2, channel), method)
        y[pos] = split['overlap'][index]
    return X1, X2, y

X_train1, X_train2, y_train = load_pairs(train)
X_val1, X_val2, y_val = load_pairs(validation)
X_test1, X_test2, y_test = load_pairs(test)
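
# Note: all three splits are held in RAM as float64 arrays; for larger datasets
# a generator or a tf.data pipeline would be the streaming alternative.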

NORM = method
SHUFFLE = False
BATCH_SIZE = 5
dropout = 0.1
decay = 0.0001
model = pretext_task_one_nopool(dropout, decay, 96, 96, n_ch)
model.summary()

# Load saved model
#model_name = '/saved_models/pretext_tasks/model_pretext1_unclouded_results.h5'
#model.load_weights(model_name)

## Callbacks
callbacks = myCallback()

LEARNING_RATE = 0.001
EPOCHS = 6
optimizer = Adam(learning_rate=LEARNING_RATE)
model.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=['accuracy'])
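# binary_crossentropy expects a single sigmoid output per pair; the
# pretext_task_one_nopool head is assumed to end in such a unit.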

# Record start time
start_time = time.time()

history = model.fit(
    [X_train1, X_train2],
    y_train,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    validation_data=([X_val1, X_val2], y_val),
    callbacks=[callbacks]
)

# Record end time
end_time = time.time()

elapsed_time = end_time - start_time
elapsed_time_minutes = elapsed_time / 60

print(f"Training time: {elapsed_time_minutes:.2f} minutes")

# summarize history for accuracy
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.title('model accuracy')
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend(['train', 'val'], loc='upper left')
plt.show()

# summarize history for loss
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'val'], loc='upper left')
plt.show()
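
# When running unattended, plt.savefig('history.png') before plt.show() (or in
# its place) would persist these curves; the file name is only a suggestion.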

print("Evaluate on val data")
results_val = model.evaluate([X_val1, X_val2], y_val)
print("val loss, val acc:", results_val)
print("Evaluate on test data")
results_test = model.evaluate([X_test1, X_test2], y_test)
print("test loss, test acc:", results_test)
print("Evaluate on train data")
results_train = model.evaluate([X_train1, X_train2], y_train, batch_size=5)
print("train loss, train acc:", results_train)

# Threshold sigmoid outputs at 0.5 and report the confusion matrix
yv_pred = (model.predict([X_test1, X_test2]) >= 0.5).astype(int)
print(confusion_matrix(y_test, yv_pred, labels=[0, 1]))
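
# Optional richer summary (illustrative): per-class precision/recall/F1.
from sklearn.metrics import classification_report
print(classification_report(y_test, yv_pred, labels=[0, 1]))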

# Visualization helpers: percentile stretch for display, and band selection
# for the already-extracted 3-channel arrays (note the index order differs
# from create_rgb above).
def scaleMinMax(x):
    """Stretch to [0, 1] between the 2nd and 98th percentiles, clipped."""
    lo = np.nanpercentile(x, 2)
    hi = np.nanpercentile(x, 98)
    return np.clip((x - lo) / (hi - lo), 0, 1)

def create_rgb_vis(x, channel):
    """Display-oriented band selection for 3-channel arrays."""
    if channel == 'red':
        return scaleMinMax(x[:, :, 2])
    if channel == 'green':
        return scaleMinMax(x[:, :, 1])
    if channel == 'blue':
        return scaleMinMax(x[:, :, 0])
    if channel == 'rgb':
        return np.dstack((scaleMinMax(x[:, :, 2]),
                          scaleMinMax(x[:, :, 1]),
                          scaleMinMax(x[:, :, 0])))

# Show six random test pairs with true vs. predicted overlap labels
fig, ax = plt.subplots(2, 3, figsize=(8, 4), constrained_layout=True)
for a in ax.flat:
    pair_pos = random.randint(0, len(y_test) - 1)
    pair = tf.concat([create_rgb_vis(X_test1[pair_pos], channel),
                      create_rgb_vis(X_test2[pair_pos], channel)], axis=1)
    a.imshow(pair)
    a.set_title("True: {} | Pred: {}".format(y_test[pair_pos], yv_pred[pair_pos]))
    a.axis('off')
plt.show()

today = date.today()
print("Today's date:", today)

# #Save our model
# model_name = "8epoch_pre_1_normtrue_unclouded_nodatagen_drop015_callback007_norgb_seed1.h5"
# model.save(model_name)
# print("Saved model to disk")
model_id = generate_short_id()
# save weights / TODO: the random model names should eventually be removed
model_path = '/home/dvalsamis/Documents/projects/Change_detection_SSL_Siamese/saved_models/'
model_name = "Pretext_1_NoPool_Grad_CAM" + "S2MTCP_" + model_id + '.h5'
model1 = os.path.join(model_path, model_name)
model.save_weights(model1)
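# save_weights stores weights only, so reloading requires rebuilding the same
# architecture first, as the reload check further below does.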

log_params('S2MTCP', model_id, model_name, LEARNING_RATE, 'Adam', 'binary_crossentropy',
           EPOCHS, BATCH_SIZE, len(df), len(train), len(test), len(validation),
           results_train[1], results_train[0], results_val[1], results_val[0],
           results_test[1], results_test[0], NORM, elapsed_time_minutes)
df_params = pd.read_csv('/home/dvalsamis/Documents/projects/Change_detection_SSL_Siamese/training/pretext_tasks/pretext_task_one_models.csv')

########################## Test the saved weights #######################
model_2 = pretext_task_one_nopool(dropout, decay, 96, 96, n_ch)
model_2.load_weights(model1)
LEARNING_RATE = 0.001
optimizer = Adam(learning_rate=LEARNING_RATE)
model_2.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=['accuracy'])

print("Evaluate on val data")
results_val = model_2.evaluate([X_val1, X_val2], y_val)
print("val loss, val acc:", results_val)
print("Evaluate on test data")
results_test = model_2.evaluate([X_test1, X_test2], y_test)
print("test loss, test acc:", results_test)
print("Evaluate on train data")
results_train = model_2.evaluate([X_train1, X_train2], y_train, batch_size=5)
print("train loss, train acc:", results_train)

yv_pred = (model_2.predict([X_test1, X_test2]) >= 0.5).astype(int)
print(confusion_matrix(y_test, yv_pred, labels=[0, 1]))

##############################
# testing on Onera
onera_pretext_target = '/home/aleoikon/Documents/data/ssl/onera_npys/patches/task1/'
onera_df = pd.read_csv(onera_pretext_target + 'dataset.csv', dtype=str)

X_on1 = np.ndarray(shape=(len(onera_df), 96, 96, n_ch))
X_on2 = np.ndarray(shape=(len(onera_df), 96, 96, n_ch))
y_on = np.ndarray(shape=(len(onera_df), 1))

print(X_on1.shape)
print(X_on2.shape)
print(y_on.shape)

def create_rgb_onera(x, channel):
    """Band selection for the Onera patches."""
    if channel == 'red':
        return np.expand_dims(x[:, :, 2], axis=2)
    if channel == 'green':
        return np.expand_dims(x[:, :, 1], axis=2)
    if channel == 'blue':
        return np.expand_dims(x[:, :, 0], axis=2)
    if channel == 'rgb':
        return np.dstack((x[:, :, 0], x[:, :, 1], x[:, :, 2]))
    return x

pos = 0
for index in onera_df.index:
    print(index)
    img1 = np.load(onera_pretext_target + onera_df['pair1'][index])
    img2 = np.load(onera_pretext_target + onera_df['pair2'][index])
    X1 = create_rgb_onera(img1, channel)
    X2 = create_rgb_onera(img2, channel)
    # Standardize each patch (same as feature_scaling with "STAND")
    X1 = (X1 - X1.mean()) / X1.std()
    X2 = (X2 - X2.mean()) / X2.std()
    X_on1[pos] = X1
    X_on2[pos] = X2
    y_on[pos] = onera_df['overlap'][index]
    pos += 1
print("Evaluate on Onera data")
results = model.evaluate([X_on1, X_on2], y_on, batch_size=5)
print("loss, acc:", results)

y_on_pred = (model.predict([X_on1, X_on2]) >= 0.5).astype(int)
print(confusion_matrix(y_on, y_on_pred, labels=[0, 1]))

log_params('Onera', model_id, model_name, LEARNING_RATE, 'Adam', 'binary_crossentropy',
           EPOCHS, BATCH_SIZE, len(onera_df), 0, 0, 0, results[1], results[0],
           0, 0, 0, 0, NORM, elapsed_time_minutes)

# Show six random Onera pairs with true vs. predicted overlap labels
fig, ax = plt.subplots(2, 3, figsize=(8, 4), constrained_layout=True)
for a in ax.flat:
    pair_pos = random.randint(0, len(y_on) - 1)
    pair = tf.concat([create_rgb_vis(X_on1[pair_pos], channel),
                      create_rgb_vis(X_on2[pair_pos], channel)], axis=1)
    a.imshow(pair)
    a.set_title("True: {} | Pred: {}".format(y_on[pair_pos], y_on_pred[pair_pos]))
    a.axis('off')
plt.show()

print("Finished")