# architectures/model_triplet_loss.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue May 3 07:34:39 2022

@author: aleoikon
"""
import sys

# NOTE(review): hard-coded, machine-specific path; consider proper packaging.
sys.path.append(r'C:\Users\dvalsamis\change\Change_detection_SSL_Siamese')

import os

# Pin training to GPU 1 (must be set before TensorFlow initializes CUDA).
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

from architectures.branch import branches_triplet
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras import Model
from tensorflow.keras import metrics


# distance layer according to paper
class DistanceLayer(layers.Layer):
    """Compute the distances between the triplet embeddings.

    Given the anchor, positive and negative embeddings, returns a tuple of:
      * squared L2 distance anchor<->positive (reduced over the last axis,
        i.e. one value per sample),
      * squared L2 distance anchor<->negative (same shape),
      * L1 distance anchor<->positive.

    NOTE(review): the L1 term is reduced over *all* axes (batch included),
    so it is a scalar while the L2 terms are per-sample — confirm this is
    the intended reduction before changing it.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, anchor, positive, negative):
        ap_distance = tf.reduce_sum(tf.square(anchor - positive), -1)
        an_distance = tf.reduce_sum(tf.square(anchor - negative), -1)
        l1_distance = tf.reduce_sum(tf.abs(anchor - positive))
        return (ap_distance, an_distance, l1_distance)


# paper model
class SiameseModel(Model):
    """The Siamese Network model with custom training and testing loops.

    Computes the triplet loss using the three embeddings produced by the
    Siamese Network:

        L(A, P, N) = max(‖f(A) - f(P)‖² - ‖f(A) - f(N)‖² + margin, 0)
                     + gamma * |f(A) - f(P)|

    Parameters
    ----------
    siamese_network : tf.keras.Model
        Network returning ``(ap_distance, an_distance, l1_distance)``.
    margin : float
        Additive margin of the triplet hinge term.
    gamma : float
        Weight of the L1 (anchor-positive) regularization term.
    """

    def __init__(self, siamese_network, margin=0.2, gamma=1):
        super(SiameseModel, self).__init__()
        self.siamese_network = siamese_network
        self.margin = margin
        self.gamma = gamma
        self.loss_tracker = metrics.Mean(name="loss")

    def call(self, inputs):
        return self.siamese_network(inputs)

    def train_step(self, data):
        # GradientTape records every operation run inside it, so we can get
        # the gradients of the loss and apply them with the optimizer that
        # was specified in `compile()`.
        with tf.GradientTape() as tape:
            loss = self._compute_loss(data)

        # Gradients of the loss w.r.t. the network weights.
        gradients = tape.gradient(loss, self.siamese_network.trainable_weights)

        # Apply the gradients using the compiled optimizer.
        self.optimizer.apply_gradients(
            zip(gradients, self.siamese_network.trainable_weights)
        )

        # Update and report the running training loss.
        self.loss_tracker.update_state(loss)
        return {"loss": self.loss_tracker.result()}

    def test_step(self, data):
        loss = self._compute_loss(data)

        # Update and report the running validation loss.
        self.loss_tracker.update_state(loss)
        return {"loss": self.loss_tracker.result()}

    def _compute_loss(self, data):
        # The network outputs the anchor-positive and anchor-negative
        # distances plus the anchor-positive L1 distance.
        ap_distance, an_distance, l1_distance = self.siamese_network(data)

        # Triplet hinge term: push the positive closer than the negative
        # by at least `margin`, clamped at zero.
        triplet_loss = tf.maximum(ap_distance - an_distance + self.margin, 0.0)

        # BUG FIX: the gamma-weighted L1 term was previously computed into a
        # dead local (`l1_loss`) but never added to the returned loss,
        # contradicting the documented objective
        # L = max(...) + gamma * |f(A) - f(P)|.
        return triplet_loss + self.gamma * l1_distance

    @property
    def metrics(self):
        # Listing the metric here lets Keras call `reset_states()` on it
        # automatically at the start of each epoch.
        return [self.loss_tracker]


def pretext_task_2_model(dropout, decay, IMG_HEIGHT, IMG_WIDTH, n_ch):
    """Build the triplet Siamese model for pretext task 2.

    Parameters
    ----------
    dropout, decay : hyper-parameters forwarded to `branches_triplet`.
    IMG_HEIGHT, IMG_WIDTH, n_ch : input patch height, width and channel count
        (coerced to int for the Keras input shapes).

    Returns
    -------
    (siamese_model, embedding)
        The trainable `SiameseModel` and the shared embedding sub-network.
    """
    base_cnn = branches_triplet(dropout, decay, IMG_HEIGHT, IMG_WIDTH, n_ch)
    flatten = layers.Flatten()(base_cnn.output)
    dense1 = layers.Dense(512, activation="relu")(flatten)
    dense1 = layers.BatchNormalization()(dense1)
    dense2 = layers.Dense(256, activation="relu")(dense1)
    dense2 = layers.BatchNormalization()(dense2)
    output = layers.Dense(256)(dense2)
    embedding = Model(base_cnn.input, output, name="Embedding")

    input_shape = (int(IMG_HEIGHT), int(IMG_WIDTH), int(n_ch))
    anchor_input = layers.Input(input_shape)
    positive_input = layers.Input(input_shape)
    negative_input = layers.Input(input_shape)

    # The three inputs go through ONE shared embedding network
    # (weight sharing is what makes the network "Siamese").
    distances = DistanceLayer()(
        embedding(anchor_input),
        embedding(positive_input),
        embedding(negative_input),
    )

    siamese_network = Model(
        inputs=[anchor_input, positive_input, negative_input], outputs=distances
    )

    siamese_model = SiameseModel(siamese_network)

    return siamese_model, embedding