/ architecture / model_triplet_loss.py
model_triplet_loss.py
  1  #!/usr/bin/env python3
  2  # -*- coding: utf-8 -*-
  3  """
  4  Created on Tue May  3 07:34:39 2022
  5  
  6  @author: aleoikon
  7  """
  8  import sys
  9  sys.path.append(r'C:\Users\dvalsamis\change\Change_detection_SSL_Siamese')
 10  import os
 11  os.environ["CUDA_VISIBLE_DEVICES"]="1"
 12  from architectures.branch import branches_triplet
 13  import tensorflow as tf
 14  from tensorflow.keras import layers
 15  from tensorflow.keras import Model
 16  from tensorflow.keras import metrics
 17  
 18  
 19  
 20  #distance layer according to paper
 21  class DistanceLayer(layers.Layer):
 22      """
 23      This layer is responsible for computing the distance between the anchor
 24      embedding and the positive embedding, and the anchor embedding and the
 25      negative embedding.
 26      """
 27  
 28      def __init__(self, **kwargs):
 29          super().__init__(**kwargs)
 30  
 31      def call(self, anchor, positive, negative):
 32          ap_distance = tf.reduce_sum(tf.square(anchor - positive), -1)
 33          an_distance = tf.reduce_sum(tf.square(anchor - negative), -1)
 34          l1_distance = tf.reduce_sum(tf.abs(anchor - positive))
 35          return (ap_distance, an_distance, l1_distance)
 36      
 37  #paper model
 38  class SiameseModel(Model):
 39      """The Siamese Network model with a custom training and testing loops.
 40  
 41      Computes the triplet loss using the three embeddings produced by the
 42      Siamese Network.
 43  
 44      The triplet loss is defined as:
 45         L(A, P, N) = max(‖f(A) - f(P)‖² - ‖f(A) - f(N)‖² + margin, 0) + gamma*|f(A) - f(P)|
 46      """
 47  
 48      def __init__(self, siamese_network, margin=0.2, gamma=1):
 49          super(SiameseModel, self).__init__()
 50          self.siamese_network = siamese_network
 51          self.margin = margin
 52          self.gamma = gamma
 53          self.loss_tracker = metrics.Mean(name="loss")
 54  
 55      def call(self, inputs):
 56          return self.siamese_network(inputs)
 57  
 58      def train_step(self, data):
 59          # GradientTape is a context manager that records every operation that
 60          # you do inside. We are using it here to compute the loss so we can get
 61          # the gradients and apply them using the optimizer specified in
 62          # `compile()`.
 63          with tf.GradientTape() as tape:
 64              loss = self._compute_loss(data)
 65  
 66          # Storing the gradients of the loss function with respect to the
 67          # weights/parameters.
 68          gradients = tape.gradient(loss, self.siamese_network.trainable_weights)
 69  
 70          # Applying the gradients on the model using the specified optimizer
 71          self.optimizer.apply_gradients(
 72              zip(gradients, self.siamese_network.trainable_weights)
 73          )
 74  
 75          # Let's update and return the training loss metric.
 76          self.loss_tracker.update_state(loss)
 77          return {"loss": self.loss_tracker.result()}
 78  
 79      def test_step(self, data):
 80          loss = self._compute_loss(data)
 81  
 82          # Let's update and return the loss metric.
 83          self.loss_tracker.update_state(loss)
 84          return {"loss": self.loss_tracker.result()}
 85  
 86      def _compute_loss(self, data):
 87          # The output of the network is a tuple containing the distances
 88          # between the anchor and the positive example, and the anchor and
 89          # the negative example.
 90          ap_distance, an_distance, l1_distance = self.siamese_network(data)
 91  
 92          # Computing the Triplet Loss by subtracting both distances and
 93          # making sure we don't get a negative value.
 94          l1_loss = self.gamma*l1_distance
 95          loss = ap_distance - an_distance
 96          loss = tf.maximum(loss + self.margin, 0.0)
 97          return loss
 98  
 99      @property
100      def metrics(self):
101          # We need to list our metrics here so the `reset_states()` can be
102          # called automatically.
103          return [self.loss_tracker]
104  
def pretext_task_2_model(dropout, decay, IMG_HEIGHT, IMG_WIDTH, n_ch,
                         margin=0.2, gamma=1):
    """Build the triplet-loss Siamese model and its shared embedding network.

    A branch from ``branches_triplet`` is topped with a Flatten + MLP head
    (Dense 512 -> BN -> Dense 256 -> BN -> Dense 256) to form the embedding
    network. The same embedding model is applied to the anchor, positive and
    negative inputs, the three embeddings feed a ``DistanceLayer``, and the
    result is wrapped in a ``SiameseModel``.

    Args:
        dropout, decay: Forwarded unchanged to ``branches_triplet``.
        IMG_HEIGHT, IMG_WIDTH, n_ch: Input dimensions; each is cast with
            ``int()`` and used as the input shape
            ``(IMG_HEIGHT, IMG_WIDTH, n_ch)``.
        margin: Triplet margin passed to ``SiameseModel`` (default 0.2).
        gamma: Weight of the L1 term passed to ``SiameseModel`` (default 1).

    Returns:
        ``(siamese_model, embedding)``: the trainable ``SiameseModel`` and the
        shared embedding ``Model`` named "Embedding".
    """
    base_cnn = branches_triplet(dropout, decay, IMG_HEIGHT, IMG_WIDTH, n_ch)

    # Projection head on top of the convolutional branch.
    x = layers.Flatten()(base_cnn.output)
    x = layers.Dense(512, activation="relu")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dense(256, activation="relu")(x)
    x = layers.BatchNormalization()(x)
    output = layers.Dense(256)(x)
    embedding = Model(base_cnn.input, output, name="Embedding")

    # One input per triplet member; all three go through the same
    # `embedding` model, so they share its weights.
    input_shape = (int(IMG_HEIGHT), int(IMG_WIDTH), int(n_ch))
    anchor_input = layers.Input(input_shape)
    positive_input = layers.Input(input_shape)
    negative_input = layers.Input(input_shape)

    distances = DistanceLayer()(
        embedding(anchor_input),
        embedding(positive_input),
        embedding(negative_input),
    )

    siamese_network = Model(
        inputs=[anchor_input, positive_input, negative_input], outputs=distances
    )

    siamese_model = SiameseModel(siamese_network, margin=margin, gamma=gamma)
    return siamese_model, embedding
132      
133