# training/pretext_tasks/pretext_task1.py
  1  import sys
  2  sys.path.append('/home/dvalsamis/Documents/projects/Change_detection_SSL_Siamese')
  3  
  4  
  5  import os
  6  import pandas as pd
  7  from architectures.similarity_detection import pretext_task_one_nopool, pretext_one, pretext_task_one_aspp
  8  from utils.log_params import log_params
  9  import matplotlib
 10  matplotlib.use('TkAgg')
 11  import matplotlib.pyplot as plt
 12  import numpy as np
 13  import tensorflow as tf
 14  from tensorflow.keras.optimizers import Adam
 15  from scipy.ndimage import zoom
 16  import time
 17  import uuid
 18  from datetime import datetime
 19  
 20  os.environ["CUDA_VISIBLE_DEVICES"]="1"
 21  
 22  def feature_scaling(img, method):
 23      I = img
 24      if method == "STAND":
 25          I = (I - I.mean()) / I.std()
 26          return I
 27      if method == "MINMAX":
 28          I = ((I - np.nanmin(I))/(np.nanmax(I) - np.nanmin(I)))
 29          return I
 30      if method == "MEAN":
 31          I = (I - I.mean()) / (np.nanmax(I) - np.nanmin(I))
 32      else:
 33          return I
 34      
 35  # ένα callback για να σταματήσει η εκπαίδευση όταν δούμε 90% accuracy
 36  class myCallback(tf.keras.callbacks.Callback):
 37      def on_epoch_end(self, epoch, logs={}):
 38          if(logs.get('val_loss')<0.080):
 39              print("\nReached 0.080 validation loss so cancelling training!")
 40              self.model.stop_training = True
 41  
 42  def adjust_shape(I, s):
 43      """Adjust shape of grayscale image I to s."""
 44      # crop if necesary
 45      I = I[:s[0],:s[1]]
 46      si = I.shape
 47  
 48      # pad if necessary 
 49      p0 = max(0,s[0] - si[0])
 50      p1 = max(0,s[1] - si[1])
 51  
 52      return np.pad(I,((0,p0),(0,p1)),'edge')
 53  
 54  
 55  
 56  def generate_short_id():
 57      # Generate a UUID
 58      unique_id = uuid.uuid4()
 59  
 60      # Convert UUID to a hex string and take the first 4 characters
 61      short_id = str(unique_id.hex)[:4]
 62  
 63      return short_id
 64  
 65  s2mtcp_target = '/data/aleoikon_data/change_detection/ssl/s2mtcp/patches/task1/'
 66  
 67  #s2mtcp_target = '/home/aleoikon/Documents/data/ssl/s2mtcp/patches_colorshifted/task1/'
 68  
 69  df = pd.read_csv(s2mtcp_target+'dataset_unclouded.csv', dtype=str)
 70  train = df.sample(frac=0.85,random_state=1)
 71  validation = df.drop(train.index)
 72  test = validation.sample(frac = 0.33, random_state=1)
 73  validation = validation.drop(test.index)
 74  
 75  print("Data", len(df))
 76  print("85% of Data = Train", len(train))
 77  print("10% of Data = Validation", len(validation))
 78  print("5% of Data = Test", len(test))
 79  
 80  test_balance = validation['overlap']
 81  (unique, counts) = np.unique(test_balance , return_counts=True)
 82  frequencies = np.asarray((unique, counts)).T
 83  print(frequencies)
 84  
 85  n_ch = 3
 86  channel = 'rgb'
 87  method = 'STAND'
 88  
 89  X_train1 = np.ndarray(shape=(len(train),96,96,n_ch))
 90  X_train2 = np.ndarray(shape=(len(train),96,96,n_ch))
 91  y_train = np.ndarray(shape=(len(train),1))
 92  X_val1 = np.ndarray(shape=(len(validation),96,96,n_ch))
 93  X_val2 = np.ndarray(shape=(len(validation),96,96,n_ch))
 94  y_val = np.ndarray(shape=(len(validation),1))
 95  
 96  def create_rgb(x,channel):
 97      if channel == 'red':
 98          r = x[:,:,1]
 99          r = np.expand_dims(r, axis=2)
100          return r
101      if channel == 'green':
102          g = x[:,:,2]
103          g = np.expand_dims(g, axis=2)
104          return g
105      if channel == 'blue':
106          b  = x[:,:,3]
107          b = np.expand_dims(b, axis=2)
108          return b
109      if channel == 'rgb':
110          r = x[:,:,1]
111          g = x[:,:,2]
112          b  = x[:,:,3]
113          rgb = np.dstack((r,g,b))
114          return(rgb)
115      if channel == 'rgbvnir':
116          r = x[:,:,1]
117          g = x[:,:,2]
118          b  = x[:,:,3]
119          vnir = x[:,:,8]
120          rgbvnir = np.stack((r,g,b,vnir),axis=2).astype('float')
121          #rgb = np.dstack((r,g,b))
122          return(rgbvnir)
123      if channel == 'eq20':
124          r = x[:,:,1]
125          s = r.shape
126          ir1 = adjust_shape(zoom(x[:,:,4],2),s)
127          ir2 = adjust_shape(zoom(x[:,:,5],2),s)
128          ir3 = adjust_shape(zoom(x[:,:,6],2),s)
129          nir2 = adjust_shape(zoom(x[:,:,8],2),s)
130          swir2 = adjust_shape(zoom(x[:,:,11],2),s)
131          swir3 = adjust_shape(zoom(x[:,:,12],2),s)
132          x = np.stack((ir1,ir2,ir3,nir2,swir2,swir3),axis=2).astype('float') 
133          return x
134      else:
135          return x
136          print("NOT CORRECT CHANNELS")
def _load_pairs(frame, out1, out2, labels):
    """Fill out1, out2 and labels in place from the rows of frame.

    Each row names two .npy patches ('pair1', 'pair2', relative to
    s2mtcp_target) and an 'overlap' label. Both patches go through
    create_rgb(..., channel) and then feature_scaling(..., method). Output
    row i corresponds to the i-th row of frame.
    """
    # create_rgb is resolved at call time. The calls below run immediately,
    # i.e. before the visualisation section redefines create_rgb.
    for pos, index in enumerate(frame.index):
        img1 = np.load(s2mtcp_target + frame['pair1'][index])
        img2 = np.load(s2mtcp_target + frame['pair2'][index])
        out1[pos] = feature_scaling(create_rgb(img1, channel), method)
        out2[pos] = feature_scaling(create_rgb(img2, channel), method)
        labels[pos] = frame['overlap'][index]


# Load every split with identical pre-processing.
_load_pairs(train, X_train1, X_train2, y_train)
_load_pairs(validation, X_val1, X_val2, y_val)

X_test1 = np.empty((len(test), 96, 96, n_ch))
X_test2 = np.empty((len(test), 96, 96, n_ch))
y_test = np.empty((len(test), 1))
_load_pairs(test, X_test1, X_test2, y_test)
179  
NORM = method  # normalisation name recorded by log_params
# NOTE(review): SHUFFLE is never passed to model.fit(), and Keras shuffles
# by default (shuffle=True). Pass shuffle=SHUFFLE if no shuffling is intended.
SHUFFLE = False
BATCH_SIZE = 5
dropout = 0.1
decay = 0.0001
model = pretext_task_one_aspp(dropout, decay, 96, 96, n_ch)
model.summary()

# To resume from previously saved weights:
# model_name='/saved_models/pretext_tasks/model_pretext1_unclouded_results.h5'
# model.load_weights(model_name)

## Callbacks
callbacks = myCallback()

LEARNING_RATE = 0.001
EPOCHS = 6
optimizer = Adam(learning_rate=LEARNING_RATE)
model.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=['accuracy'])

# perf_counter is a monotonic clock meant for measuring durations. Unlike
# time.time(), it is unaffected by system-clock adjustments.
start_time = time.perf_counter()

history = model.fit(
    [X_train1, X_train2],
    y_train,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    validation_data=([X_val1, X_val2], y_val),
    callbacks=[callbacks],
)

end_time = time.perf_counter()

elapsed_time = end_time - start_time
elapsed_time_minutes = elapsed_time / 60

print(f"Training time: {elapsed_time_minutes:.2f} minutes")
220  
# Plot train vs. validation curves, one figure per metric: accuracy first,
# then loss. Each figure is shown before the next is drawn.
for metric in ('accuracy', 'loss'):
    plt.plot(history.history[metric])
    plt.plot(history.history['val_' + metric])
    plt.title('model ' + metric)
    plt.ylabel(metric)
    plt.xlabel('epoch')
    plt.legend(['train', 'val'], loc='upper left')
    plt.show()
238  
print("Evaluate on val data")
results_val = model.evaluate([X_val1, X_val2], y_val)
print("val loss, val acc:", results_val)
print("Evaluate on test data")
results_test = model.evaluate([X_test1, X_test2], y_test)
print("test loss, test acc:", results_test)
print("Evaluate on train data")
results_train = model.evaluate([X_train1, X_train2], y_train, batch_size=BATCH_SIZE)
print("train loss, train acc:", results_train)

from sklearn.metrics import confusion_matrix
yv_pred = model.predict([X_test1, X_test2])
# Threshold the predicted scores at 0.5: < 0.5 -> 0, anything else
# (including NaN) -> 1. Shape and dtype of the prediction array are kept.
yv_pred = np.where(yv_pred < 0.5, 0, 1).astype(yv_pred.dtype)
# A bare expression is discarded when run as a script, so print the matrix.
print(confusion_matrix(y_test, yv_pred, labels=[0, 1]))
257  
# Visualisation helpers.
def scaleMinMax(x):
    """Contrast-stretch x between its 2nd and 98th percentiles (NaNs ignored).

    The 2nd percentile maps to 0 and the 98th to 1. Values outside that range
    are not clipped, so the result can fall outside [0, 1].
    """
    low = np.nanpercentile(x, 2)
    high = np.nanpercentile(x, 98)
    return (x - low) / (high - low)
261  
def create_rgb(x, channel):
    """Build a display image from a pre-processed patch (visualisation only).

    This redefinition replaces the earlier create_rgb for all code after it.
    Bands are taken in reverse order (2, 1, 0), and each is contrast-stretched
    with scaleMinMax.

    Args:
        x: patch indexed as x[row, col, band] with at least 3 bands.
        channel: 'red' (band 2), 'green' (band 1) or 'blue' (band 0) -> 2-D
            array; 'rgb' -> (H, W, 3) stack of bands 2, 1, 0.
            Any other value returns None.
    """
    for name, band in (('red', 2), ('green', 1), ('blue', 0)):
        if channel == name:
            return scaleMinMax(x[:, :, band])
    if channel == 'rgb':
        return np.dstack([scaleMinMax(x[:, :, band]) for band in (2, 1, 0)])
    return None
284      
import random

# Show 6 random test pairs (image 1 | image 2, side by side), each titled
# with its true label and thresholded prediction.
fig, ax = plt.subplots(2, 3, figsize=(8, 4), constrained_layout=True)
for panel in ax.flat:
    # randrange excludes len(y_test). randint(0, len(y_test)) is inclusive
    # and could return an out-of-range index (IndexError).
    pair_pos = random.randrange(len(y_test))
    pair = np.concatenate(
        [create_rgb(X_test1[pair_pos], channel), create_rgb(X_test2[pair_pos], channel)],
        axis=1,
    )
    panel.imshow(pair)
    panel.set_title("True: {} | Pred: {}".format(y_test[pair_pos], yv_pred[pair_pos]))
    panel.axis('off')
# Display the figure, like the training-curve plots above.
plt.show()
317  
from datetime import date

today = date.today()
print("Today's date:", today)

# Save the trained weights under a short random run id.
# TODO: the random names should be replaced by a more descriptive scheme.
model_id = generate_short_id()
model_path = '/home/dvalsamis/Documents/projects/Change_detection_SSL_Siamese/saved_models/'
# The trained model is built with pretext_task_one_aspp, so the file name
# carries that architecture rather than "NoPool".
model_name = "Pretext_1_ASPP_" + "S2MTCP_" + model_id + '.h5'
model1 = os.path.join(model_path, model_name)
model.save_weights(model1)

log_params(
    'S2MTCP', model_id, model_name, LEARNING_RATE, 'Adam', 'binary_crossentropy',
    EPOCHS, BATCH_SIZE, len(df), len(train), len(test), len(validation),
    results_train[1], results_train[0],
    results_val[1], results_val[0],
    results_test[1], results_test[0],
    NORM, elapsed_time_minutes,
)
# Loaded but not used below; presumably for interactive inspection.
df_params = pd.read_csv('/home/dvalsamis/Documents/projects/Change_detection_SSL_Siamese/training/pretext_tasks/pretext_task_one_models.csv')
338  
########################## Test the saved weights #######################
# Rebuild the architecture that was actually trained (pretext_task_one_aspp)
# before loading its weights. A different builder such as
# pretext_task_one_nopool would presumably not match the saved layers.
model_2 = pretext_task_one_aspp(dropout, decay, 96, 96, n_ch)
model_2.load_weights(model1)
# Compilation is needed for evaluate(); reuse the training learning rate.
optimizer = Adam(learning_rate=LEARNING_RATE)
model_2.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=['accuracy'])

print("Evaluate on val data")
results_val = model_2.evaluate([X_val1, X_val2], y_val)
print("val loss, val acc:", results_val)
print("Evaluate on test data")
results_test = model_2.evaluate([X_test1, X_test2], y_test)
print("test loss, test acc:", results_test)
print("Evaluate on train data")
results_train = model_2.evaluate([X_train1, X_train2], y_train, batch_size=BATCH_SIZE)
print("train loss, train acc:", results_train)

from sklearn.metrics import confusion_matrix
yv_pred = model_2.predict([X_test1, X_test2])
# Threshold at 0.5: < 0.5 -> 0, anything else (including NaN) -> 1.
yv_pred = np.where(yv_pred < 0.5, 0, 1).astype(yv_pred.dtype)
print(confusion_matrix(y_test, yv_pred, labels=[0, 1]))
364  
##############################
# Testing on Onera: the S2MTCP-trained model is evaluated on Onera patch pairs.
onera_pretext_target = '/home/aleoikon/Documents/data/ssl/onera_npys/patches/task1/'
onera_df = pd.read_csv(onera_pretext_target + 'dataset.csv', dtype=str)

# Uninitialised float64 buffers, filled pair by pair below.
X_on1 = np.empty((len(onera_df), 96, 96, n_ch))
X_on2 = np.empty((len(onera_df), 96, 96, n_ch))
y_on = np.empty((len(onera_df), 1))

print(X_on1.shape)
print(X_on2.shape)
print(y_on.shape)
377  
def create_rgb_onera(x, channel):
    """Select input bands from an Onera patch.

    Args:
        x: patch indexed as x[row, col, band].
        channel: 'red' -> band 2, 'green' -> band 1, 'blue' -> band 0, each
            returned as (H, W, 1); 'rgb' -> bands 0, 1, 2 stacked as (H, W, 3).
            Any other value returns x unchanged.

    NOTE(review): 'rgb' keeps the stored band order 0, 1, 2 even though 'red'
    maps to band 2, so under these labels 'rgb' is ordered blue, green, red.
    That may be deliberate, to match the S2MTCP input order. Confirm before
    changing.
    """
    for name, band in (('red', 2), ('green', 1), ('blue', 0)):
        if channel == name:
            return x[:, :, band][:, :, np.newaxis]
    if channel == 'rgb':
        return np.dstack([x[:, :, band] for band in (0, 1, 2)])
    return x
399          
for pos, index in enumerate(onera_df.index):
    print(index)
    img1 = np.load(onera_pretext_target + onera_df['pair1'][index])
    img2 = np.load(onera_pretext_target + onera_df['pair2'][index])
    # Scale with the same method as the S2MTCP data. For method == 'STAND'
    # this is exactly (X - mean) / std.
    X_on1[pos] = feature_scaling(create_rgb_onera(img1, channel), method)
    X_on2[pos] = feature_scaling(create_rgb_onera(img2, channel), method)
    y_on[pos] = onera_df['overlap'][index]

print("Evaluate on Onera data")
results = model.evaluate([X_on1, X_on2], y_on, batch_size=BATCH_SIZE)
print("loss, acc:", results)

y_on_pred = model.predict([X_on1, X_on2])
# Threshold at 0.5: < 0.5 -> 0, anything else (including NaN) -> 1.
y_on_pred = np.where(y_on_pred < 0.5, 0, 1).astype(y_on_pred.dtype)
print(confusion_matrix(y_on, y_on_pred, labels=[0, 1]))

log_params(
    'Onera', model_id, model_name, LEARNING_RATE, 'Adam', 'binary_crossentropy',
    EPOCHS, BATCH_SIZE, len(onera_df), 0, 0, 0,
    results[1], results[0], 0, 0, 0, 0,
    NORM, elapsed_time_minutes,
)
429  
430  
import random

# Show 6 random Onera pairs (image 1 | image 2, side by side), each titled
# with its true label and thresholded prediction. create_rgb here is the
# display version defined for visualisation.
fig, ax = plt.subplots(2, 3, figsize=(8, 4), constrained_layout=True)
for panel in ax.flat:
    # randrange excludes len(y_on). randint(0, len(y_on)) is inclusive and
    # could return an out-of-range index (IndexError).
    pair_pos = random.randrange(len(y_on))
    pair = np.concatenate(
        [create_rgb(X_on1[pair_pos], channel), create_rgb(X_on2[pair_pos], channel)],
        axis=1,
    )
    panel.imshow(pair)
    panel.set_title("True: {} | Pred: {}".format(y_on[pair_pos], y_on_pred[pair_pos]))
    panel.axis('off')
# Display the figure; otherwise a non-interactive run never shows it.
plt.show()

print("Finished")