# architecture/branch.py
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.regularizers import l2


# Module-level defaults; most builder functions below take `dropout` and
# `decay` as arguments, so these values serve only as convenient defaults.
decay = 0.0001
dropout = 0.1

def abs_diff(vects):
    """Element-wise absolute difference of two tensors (the Siamese distance)."""
    x, y = vects
    return tf.math.abs(tf.math.subtract(x, y))

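# Usage sketch (illustrative, not from the original pipeline): abs_diff is meant
# to be wrapped in a Lambda layer so the distance becomes part of the Keras
# graph; `feat_a` and `feat_b` are hypothetical feature maps from the two towers.
#
#     distance = layers.Lambda(abs_diff, name='abs_diff')([feat_a, feat_b])
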
def branches(IMG_HEIGHT=96, IMG_WIDTH=96, IMG_CHANNELS=3):
    """Three conv/dropout/max-pool stages; dropout is fixed at 0.1 here
    (see branches_nopool for the parameterized variant)."""
    input_b = layers.Input((IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS))
    x_b = layers.BatchNormalization()(input_b)

    br_conv_1_b = layers.Conv2D(32, (3, 3), activation="relu", padding="same", name='conv1')(x_b)
    x_b = layers.Dropout(0.1, seed=1, name='dropout1')(br_conv_1_b)
    x_b = layers.MaxPooling2D(pool_size=(2, 2), name='pool1')(x_b)

    br_conv_2_b = layers.Conv2D(32, (3, 3), activation="relu", padding="same", name='conv2')(x_b)
    x_b = layers.Dropout(0.1, seed=1, name='dropout2')(br_conv_2_b)
    x_b = layers.MaxPooling2D(pool_size=(2, 2), name='pool2')(x_b)

    br_conv_3_b = layers.Conv2D(32, (3, 3), activation="relu", padding="same", name='conv3')(x_b)
    x_b = layers.Dropout(0.1, seed=1, name='dropout3')(br_conv_3_b)
    x_b = layers.MaxPooling2D(pool_size=(2, 2), name='pool3')(x_b)

    branch_network_b = keras.Model(input_b, x_b, name='branch')

    return branch_network_b

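# Quick shape check (a minimal sketch): with the default 96x96x3 input and
# three 2x2 poolings, 96 -> 48 -> 24 -> 12, so:
#
#     branch = branches()
#     print(branch.output_shape)  # -> (None, 12, 12, 32)
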
def branches_nopool(dropout, decay, IMG_HEIGHT=96, IMG_WIDTH=96, IMG_CHANNELS=3):
    """Same conv/dropout stack as branches() but without pooling, so the full
    input resolution is preserved; adds L2 regularization on kernels and biases."""
    input_b = layers.Input((IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS))
    x_b = layers.BatchNormalization()(input_b)

    br_conv_1_b = layers.Conv2D(32, (3, 3), activation="relu", kernel_regularizer=l2(decay), bias_regularizer=l2(decay), padding="same", name='conv1')(x_b)
    x_b = layers.Dropout(dropout, seed=1, name='dropout1')(br_conv_1_b)

    br_conv_2_b = layers.Conv2D(32, (3, 3), activation="relu", kernel_regularizer=l2(decay), bias_regularizer=l2(decay), padding="same", name='conv2')(x_b)
    x_b = layers.Dropout(dropout, seed=1, name='dropout2')(br_conv_2_b)

    br_conv_3_b = layers.Conv2D(32, (3, 3), activation="relu", kernel_regularizer=l2(decay), bias_regularizer=l2(decay), padding="same", name='conv3')(x_b)
    x_b = layers.Dropout(dropout, seed=1, name='dropout3')(br_conv_3_b)

    branch_network_b = keras.Model(input_b, x_b, name='branch')

    return branch_network_b


def branch_cva(dropout, decay, depth, IMG_HEIGHT=96, IMG_WIDTH=96, IMG_CHANNELS=3):
    """Conv-BN-ReLU-Dropout branch; `depth` (1-3) selects which stage's output
    the model exposes. All three stages are defined, but only those up to
    `depth` end up in the returned model."""
    input_1 = layers.Input((IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS))

    ## depth 1
    x_1 = layers.BatchNormalization(name='norm_1')(input_1)

    conv1_1 = layers.Conv2D(32, (3, 3), kernel_regularizer=l2(decay), bias_regularizer=l2(decay), padding="same", name='conv1_1')(x_1)
    batch_norm1_1 = layers.BatchNormalization(name='norm1_1')(conv1_1)
    activation1_1 = layers.Activation('relu', name='relu1_1')(batch_norm1_1)
    drop1_1 = layers.Dropout(dropout, seed=1, name='dropout1_1')(activation1_1)

    ## depth 2
    conv2_1 = layers.Conv2D(32, (3, 3), kernel_regularizer=l2(decay), bias_regularizer=l2(decay), padding="same", name='conv2_1')(drop1_1)
    batch_norm2_1 = layers.BatchNormalization(name='norm2_1')(conv2_1)
    activation2_1 = layers.Activation('relu', name='relu2_1')(batch_norm2_1)
    drop2_1 = layers.Dropout(dropout, seed=1, name='dropout2_1')(activation2_1)

    ## depth 3
    conv3_1 = layers.Conv2D(32, (3, 3), kernel_regularizer=l2(decay), bias_regularizer=l2(decay), padding="same", name='conv3_1')(drop2_1)
    batch_norm3_1 = layers.BatchNormalization(name='norm3_1')(conv3_1)
    activation3_1 = layers.Activation('relu', name='relu3_1')(batch_norm3_1)
    drop3_1 = layers.Dropout(dropout, seed=1, name='dropout3_1')(activation3_1)

    if depth == 1:
        branch_network = keras.Model(input_1, drop1_1, name='branch')
    elif depth == 2:
        branch_network = keras.Model(input_1, drop2_1, name='branch')
    elif depth == 3:
        branch_network = keras.Model(input_1, drop3_1, name='branch')
    else:
        raise ValueError(f"depth must be 1, 2 or 3, got {depth}")

    return branch_network

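# Usage sketch (illustrative): no pooling is applied, so the output keeps the
# full input resolution regardless of depth.
#
#     branch = branch_cva(dropout=0.1, decay=1e-4, depth=2)
#     print(branch.output_shape)  # -> (None, 96, 96, 32)
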
def create_nspp_block(input_feature, IMG_HEIGHT=96, IMG_WIDTH=96, filters=32, scales=[2, 4, 8, 16]):
    """Multi-scale pooling block: per scale, an average-pooling path and a
    strided separable-conv path are merged, globally pooled, and resized back
    to IMG_HEIGHT x IMG_WIDTH before channel-wise concatenation."""
    nspp_features = []

    for scale in scales:
        # Pooling block: downsample by `scale`, then a 1x1 conv
        pooled = layers.AveragePooling2D(pool_size=(scale, scale), padding='same', name=f'pooled_{scale}')(input_feature)
        pooling_block = layers.Conv2D(filters // 4, (1, 1), padding='same', activation='relu', name=f'pooling_block_{scale}')(pooled)

        # Strided convolution block: same downsampling ratio via strides
        strided_conv = layers.SeparableConv2D(filters // 4, (3, 3), strides=(scale, scale), padding='same', name=f'strided_conv_{scale}')(input_feature)
        merged_features = layers.Add(name=f'merged_features_{scale}')([pooling_block, strided_conv])

        # Global pooling block: spatial mean, then a pointwise conv
        reduced_mean = layers.Lambda(lambda x: tf.reduce_mean(x, axis=[1, 2], keepdims=True), name=f'reduced_mean_{scale}')(merged_features)
        pointwise_conv = layers.Conv2D(filters // 4, (1, 1), padding='same', name=f'pointwise_conv_{scale}')(reduced_mean)

        # Resize the 1x1 tensor back to the full spatial resolution
        resized_feature = tf.image.resize(pointwise_conv, (IMG_HEIGHT, IMG_WIDTH), name=f'resized_feature_{scale}')

        nspp_features.append(resized_feature)

    # Channel-wise concatenation of the features from all scales
    nspp_output = layers.Concatenate(axis=-1, name='nspp_concat')(nspp_features)
    # 1x1 convolution to match the channel count of the original input feature
    final_nspp_output = layers.Conv2D(filters, (1, 1), padding='same', name='final_nspp_output')(nspp_output)

    return final_nspp_output

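# Shape probe (a minimal sketch): each scale contributes filters // 4 channels
# resized back to full resolution, so four scales concatenate to `filters`
# channels and the final 1x1 conv keeps that count.
#
#     feat = layers.Input((96, 96, 32))
#     probe = keras.Model(feat, create_nspp_block(feat))
#     print(probe.output_shape)  # -> (None, 96, 96, 32)
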
def branch_cva_with_nspp(dropout, decay, depth, IMG_HEIGHT=96, IMG_WIDTH=96, IMG_CHANNELS=3, filters=32, scales=[2, 4, 8, 16]):
    """branch_cva variant whose output at the chosen depth (1-3) is fed through
    the NSPP block."""
    if depth not in (1, 2, 3):
        raise ValueError(f"depth must be 1, 2 or 3, got {depth}")

    input_1 = layers.Input((IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS))

    # Depth 1
    x_1 = layers.BatchNormalization(name='norm_1')(input_1)
    conv1_1 = layers.Conv2D(filters, (3, 3), kernel_regularizer=l2(decay), bias_regularizer=l2(decay), padding="same", name='conv1_1')(x_1)
    batch_norm1_1 = layers.BatchNormalization(name='norm1_1')(conv1_1)
    activation1_1 = layers.Activation('relu', name='relu1_1')(batch_norm1_1)
    drop1_1 = layers.Dropout(dropout, seed=1, name='dropout1_1')(activation1_1)
    output_feature = drop1_1

    # Depth 2
    if depth >= 2:
        conv2_1 = layers.Conv2D(filters, (3, 3), kernel_regularizer=l2(decay), bias_regularizer=l2(decay), padding="same", name='conv2_1')(drop1_1)
        batch_norm2_1 = layers.BatchNormalization(name='norm2_1')(conv2_1)
        activation2_1 = layers.Activation('relu', name='relu2_1')(batch_norm2_1)
        drop2_1 = layers.Dropout(dropout, seed=1, name='dropout2_1')(activation2_1)
        output_feature = drop2_1

    # Depth 3
    if depth == 3:
        conv3_1 = layers.Conv2D(filters, (3, 3), kernel_regularizer=l2(decay), bias_regularizer=l2(decay), padding="same", name='conv3_1')(drop2_1)
        batch_norm3_1 = layers.BatchNormalization(name='norm3_1')(conv3_1)
        activation3_1 = layers.Activation('relu', name='relu3_1')(batch_norm3_1)
        drop3_1 = layers.Dropout(dropout, seed=1, name='dropout3_1')(activation3_1)
        output_feature = drop3_1

    # NSPP block is added after the last dropout layer of the selected depth
    nspp_output = create_nspp_block(output_feature, IMG_HEIGHT, IMG_WIDTH, filters, scales)

    branch_network = keras.Model(inputs=input_1, outputs=nspp_output, name='branch_nspp')

    return branch_network

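# Combined sketch (illustrative): at any depth the branch stays at 96x96, so the
# NSPP head always sees, and returns, a (None, 96, 96, filters) tensor.
#
#     m = branch_cva_with_nspp(0.1, 1e-4, depth=3)
#     print(m.output_shape)  # -> (None, 96, 96, 32)
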
def branches_triplet(dropout, decay, IMG_HEIGHT=96, IMG_WIDTH=96, IMG_CHANNELS=3):
    """Pool-free branch used for triplet training; same stack as
    branches_nopool, but registered under its own model name."""
    input_b = layers.Input((IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS))
    x_b = layers.BatchNormalization()(input_b)

    br_conv_1_b = layers.Conv2D(32, (3, 3), activation="relu", kernel_regularizer=l2(decay), bias_regularizer=l2(decay), padding="same", name='conv1')(x_b)
    x_b = layers.Dropout(dropout, seed=1, name='dropout1')(br_conv_1_b)

    br_conv_2_b = layers.Conv2D(32, (3, 3), activation="relu", kernel_regularizer=l2(decay), bias_regularizer=l2(decay), padding="same", name='conv2')(x_b)
    x_b = layers.Dropout(dropout, seed=1, name='dropout2')(br_conv_2_b)

    br_conv_3_b = layers.Conv2D(32, (3, 3), activation="relu", kernel_regularizer=l2(decay), bias_regularizer=l2(decay), padding="same", name='conv3')(x_b)
    x_b = layers.Dropout(dropout, seed=1, name='dropout3')(br_conv_3_b)

    branch_network_b = keras.Model(input_b, x_b, name='branch_triplet')

    return branch_network_b


def ASPP(inputs, filters, dilation_rates):
    """Atrous Spatial Pyramid Pooling head. Layer names are fixed, so call this
    at most once per model to avoid name collisions."""
    # 1x1 convolution
    conv_1x1 = layers.Conv2D(filters, (1, 1), padding='same', name='aspp_conv_1x1')(inputs)
    conv_1x1_bn = layers.BatchNormalization(name='aspp_conv_1x1_bn')(conv_1x1)
    conv_1x1_relu = layers.Activation('relu', name='aspp_conv_1x1_relu')(conv_1x1_bn)

    # Atrous convolutions with different dilation rates
    atrous_layers = [conv_1x1_relu]
    for rate in dilation_rates:
        atrous_conv = layers.Conv2D(filters, (3, 3), dilation_rate=rate, padding='same', name=f'aspp_conv_{rate}')(inputs)
        atrous_conv_bn = layers.BatchNormalization(name=f'aspp_conv_{rate}_bn')(atrous_conv)
        atrous_conv_relu = layers.Activation('relu', name=f'aspp_conv_{rate}_relu')(atrous_conv_bn)
        atrous_layers.append(atrous_conv_relu)

    # Concatenate the parallel paths
    concatenated = layers.Concatenate(axis=-1, name='aspp_concat')(atrous_layers)

    # Reduce the number of channels back to `filters`
    reduced = layers.Conv2D(filters, (1, 1), padding='same', name='aspp_reduced')(concatenated)
    reduced_bn = layers.BatchNormalization(name='aspp_reduced_bn')(reduced)
    reduced_relu = layers.Activation('relu', name='aspp_reduced_relu')(reduced_bn)

    return reduced_relu

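# Notes on the dilation rates (informal): a 3x3 conv with dilation rate r spans
# a (2r + 1) x (2r + 1) window, so rates [6, 12] see 13x13 and 25x25
# neighbourhoods without any downsampling. Shape probe (a minimal sketch):
#
#     x = layers.Input((96, 96, 32))
#     probe = keras.Model(x, ASPP(x, filters=32, dilation_rates=[6, 12]))
#     print(probe.output_shape)  # -> (None, 96, 96, 32)
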
def branch_cva_aspp(dropout, decay, depth, IMG_HEIGHT=96, IMG_WIDTH=96, IMG_CHANNELS=3, aspp_filters=32, aspp_rates=[6, 12]):
    """Single-input branch ending in an ASPP head at the chosen depth (1-3).
    Unlike branch_cva, dropout here is unseeded and only kernels are regularized."""
    if depth not in (1, 2, 3):
        raise ValueError(f"depth must be 1, 2 or 3, got {depth}")

    input_1 = layers.Input((IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS))

    # Depth 1
    x_1 = layers.BatchNormalization(name='norm_1')(input_1)
    conv1_1 = layers.Conv2D(32, (3, 3), kernel_regularizer=l2(decay), padding="same", name='conv1_1')(x_1)
    batch_norm1_1 = layers.BatchNormalization(name='norm1_1')(conv1_1)
    activation1_1 = layers.Activation('relu', name='relu1_1')(batch_norm1_1)
    drop1_1 = layers.Dropout(dropout, name='dropout1_1')(activation1_1)

    if depth == 1:
        aspp_output_1 = ASPP(drop1_1, aspp_filters, aspp_rates)
        return keras.Model(inputs=input_1, outputs=aspp_output_1, name='branch_cva_depth1')

    # Depth 2
    conv2_1 = layers.Conv2D(32, (3, 3), kernel_regularizer=l2(decay), padding="same", name='conv2_1')(drop1_1)
    batch_norm2_1 = layers.BatchNormalization(name='norm2_1')(conv2_1)
    activation2_1 = layers.Activation('relu', name='relu2_1')(batch_norm2_1)
    drop2_1 = layers.Dropout(dropout, name='dropout2_1')(activation2_1)

    if depth == 2:
        aspp_output_2 = ASPP(drop2_1, aspp_filters, aspp_rates)
        return keras.Model(inputs=input_1, outputs=aspp_output_2, name='branch_cva_depth2')

    # Depth 3
    conv3_1 = layers.Conv2D(32, (3, 3), kernel_regularizer=l2(decay), padding="same", name='conv3_1')(drop2_1)
    batch_norm3_1 = layers.BatchNormalization(name='norm3_1')(conv3_1)
    activation3_1 = layers.Activation('relu', name='relu3_1')(batch_norm3_1)
    drop3_1 = layers.Dropout(dropout, name='dropout3_1')(activation3_1)

    aspp_output_3 = ASPP(drop3_1, aspp_filters, aspp_rates)
    branch_network = keras.Model(inputs=input_1, outputs=aspp_output_3, name='branch_cva_depth3')

    return branch_network


def two_branch_cva_with_aspp(dropout, decay, depth, IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS, aspp_filters=32, aspp_rates=[6, 12]):
    """Two independent conv towers (weights are NOT shared) whose element-wise
    absolute difference at the chosen depth (1-3) is fed through an ASPP head."""
    input_1 = layers.Input((IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS))
    input_2 = layers.Input((IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS))

    ## tower 1, depth 1
    x_1 = layers.BatchNormalization(name='norm_1')(input_1)

    conv1_1 = layers.Conv2D(32, (3, 3), kernel_regularizer=l2(decay), bias_regularizer=l2(decay), padding="same", name='conv1_1')(x_1)
    batch_norm1_1 = layers.BatchNormalization(name='norm1_1')(conv1_1)
    activation1_1 = layers.Activation('relu', name='relu1_1')(batch_norm1_1)
    drop1_1 = layers.Dropout(dropout, seed=1, name='dropout1_1')(activation1_1)

    ## tower 1, depth 2
    conv2_1 = layers.Conv2D(32, (3, 3), kernel_regularizer=l2(decay), bias_regularizer=l2(decay), padding="same", name='conv2_1')(drop1_1)
    batch_norm2_1 = layers.BatchNormalization(name='norm2_1')(conv2_1)
    activation2_1 = layers.Activation('relu', name='relu2_1')(batch_norm2_1)
    drop2_1 = layers.Dropout(dropout, seed=1, name='dropout2_1')(activation2_1)

    ## tower 1, depth 3
    conv3_1 = layers.Conv2D(32, (3, 3), kernel_regularizer=l2(decay), bias_regularizer=l2(decay), padding="same", name='conv3_1')(drop2_1)
    batch_norm3_1 = layers.BatchNormalization(name='norm3_1')(conv3_1)
    activation3_1 = layers.Activation('relu', name='relu3_1')(batch_norm3_1)
    drop3_1 = layers.Dropout(dropout, seed=1, name='dropout3_1')(activation3_1)

    x_2 = layers.BatchNormalization(name='norm_2')(input_2)

    ## tower 2, depth 1
    conv1_2 = layers.Conv2D(32, (3, 3), kernel_regularizer=l2(decay), bias_regularizer=l2(decay), padding="same", name='conv1_2')(x_2)
    batch_norm1_2 = layers.BatchNormalization(name='norm1_2')(conv1_2)
    activation1_2 = layers.Activation('relu', name='relu1_2')(batch_norm1_2)
    drop1_2 = layers.Dropout(dropout, seed=1, name='dropout1_2')(activation1_2)

    ## tower 2, depth 2
    conv2_2 = layers.Conv2D(32, (3, 3), kernel_regularizer=l2(decay), bias_regularizer=l2(decay), padding="same", name='conv2_2')(drop1_2)
    batch_norm2_2 = layers.BatchNormalization(name='norm2_2')(conv2_2)
    activation2_2 = layers.Activation('relu', name='relu2_2')(batch_norm2_2)
    drop2_2 = layers.Dropout(dropout, seed=1, name='dropout2_2')(activation2_2)

    ## tower 2, depth 3
    conv3_2 = layers.Conv2D(32, (3, 3), kernel_regularizer=l2(decay), bias_regularizer=l2(decay), padding="same", name='conv3_2')(drop2_2)
    batch_norm3_2 = layers.BatchNormalization(name='norm3_2')(conv3_2)
    activation3_2 = layers.Activation('relu', name='relu3_2')(batch_norm3_2)
    drop3_2 = layers.Dropout(dropout, seed=1, name='dropout3_2')(activation3_2)

    ####################################### DISTANCE + ASPP HEAD #######################################
    if depth == 1:
        distance = layers.Lambda(abs_diff)([drop1_1, drop1_2])
    elif depth == 2:
        distance = layers.Lambda(abs_diff)([drop2_1, drop2_2])
    elif depth == 3:
        distance = layers.Lambda(abs_diff)([drop3_1, drop3_2])
    else:
        raise ValueError(f"depth must be 1, 2 or 3, got {depth}")

    encoder_output = distance  # Output from the encoder stage

    # Apply the ASPP block after the element-wise distance
    aspp_output = ASPP(encoder_output, aspp_filters, aspp_rates)

    # Create model
    branch_model = keras.Model(inputs=[input_1, input_2], outputs=aspp_output, name='branch')

    return branch_model

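# Usage sketch (illustrative): the model takes a pair of co-registered patches,
# e.g. "before" and "after" images for change detection, and returns the ASPP
# feature map at full resolution.
#
#     model = two_branch_cva_with_aspp(0.1, 1e-4, 3, 96, 96, 3)
#     print(model.output_shape)  # -> (None, 96, 96, 32)
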
def two_branch_cva_with_aspp_fmaps(dropout, decay, depth, IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS, aspp_filters=32, aspp_rates=[6, 12]):
    """Same architecture as two_branch_cva_with_aspp, plus a per-pixel softmax
    change map and a second model exposing intermediate feature maps."""
    input_1 = layers.Input((IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS))
    input_2 = layers.Input((IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS))

    ## tower 1, depth 1
    x_1 = layers.BatchNormalization(name='norm_1')(input_1)

    conv1_1 = layers.Conv2D(32, (3, 3), kernel_regularizer=l2(decay), bias_regularizer=l2(decay), padding="same", name='conv1_1')(x_1)
    batch_norm1_1 = layers.BatchNormalization(name='norm1_1')(conv1_1)
    activation1_1 = layers.Activation('relu', name='relu1_1')(batch_norm1_1)
    drop1_1 = layers.Dropout(dropout, seed=1, name='dropout1_1')(activation1_1)

    ## tower 1, depth 2
    conv2_1 = layers.Conv2D(32, (3, 3), kernel_regularizer=l2(decay), bias_regularizer=l2(decay), padding="same", name='conv2_1')(drop1_1)
    batch_norm2_1 = layers.BatchNormalization(name='norm2_1')(conv2_1)
    activation2_1 = layers.Activation('relu', name='relu2_1')(batch_norm2_1)
    drop2_1 = layers.Dropout(dropout, seed=1, name='dropout2_1')(activation2_1)

    ## tower 1, depth 3
    conv3_1 = layers.Conv2D(32, (3, 3), kernel_regularizer=l2(decay), bias_regularizer=l2(decay), padding="same", name='conv3_1')(drop2_1)
    batch_norm3_1 = layers.BatchNormalization(name='norm3_1')(conv3_1)
    activation3_1 = layers.Activation('relu', name='relu3_1')(batch_norm3_1)
    drop3_1 = layers.Dropout(dropout, seed=1, name='dropout3_1')(activation3_1)

    x_2 = layers.BatchNormalization(name='norm_2')(input_2)

    ## tower 2, depth 1
    conv1_2 = layers.Conv2D(32, (3, 3), kernel_regularizer=l2(decay), bias_regularizer=l2(decay), padding="same", name='conv1_2')(x_2)
    batch_norm1_2 = layers.BatchNormalization(name='norm1_2')(conv1_2)
    activation1_2 = layers.Activation('relu', name='relu1_2')(batch_norm1_2)
    drop1_2 = layers.Dropout(dropout, seed=1, name='dropout1_2')(activation1_2)

    ## tower 2, depth 2
    conv2_2 = layers.Conv2D(32, (3, 3), kernel_regularizer=l2(decay), bias_regularizer=l2(decay), padding="same", name='conv2_2')(drop1_2)
    batch_norm2_2 = layers.BatchNormalization(name='norm2_2')(conv2_2)
    activation2_2 = layers.Activation('relu', name='relu2_2')(batch_norm2_2)
    drop2_2 = layers.Dropout(dropout, seed=1, name='dropout2_2')(activation2_2)

    ## tower 2, depth 3
    conv3_2 = layers.Conv2D(32, (3, 3), kernel_regularizer=l2(decay), bias_regularizer=l2(decay), padding="same", name='conv3_2')(drop2_2)
    batch_norm3_2 = layers.BatchNormalization(name='norm3_2')(conv3_2)
    activation3_2 = layers.Activation('relu', name='relu3_2')(batch_norm3_2)
    drop3_2 = layers.Dropout(dropout, seed=1, name='dropout3_2')(activation3_2)

    ####################################### DISTANCE + ASPP HEAD #######################################
    # Each Lambda gets an explicit name so get_layer() below can find it at any depth.
    if depth == 1:
        distance = layers.Lambda(abs_diff, name='abs_diff_1')([drop1_1, drop1_2])
    elif depth == 2:
        distance = layers.Lambda(abs_diff, name='abs_diff_2')([drop2_1, drop2_2])
    elif depth == 3:
        distance = layers.Lambda(abs_diff, name='abs_diff_3')([drop3_1, drop3_2])
    else:
        raise ValueError(f"depth must be 1, 2 or 3, got {depth}")

    encoder_output = distance  # Output from the encoder stage

    aspp_output = ASPP(encoder_output, aspp_filters, aspp_rates)

    # After ASPP, reduce the channels
    reduced_aspp_output = layers.Conv2D(32, (1, 1), padding="same", name='reduced_aspp_output')(aspp_output)
    reduced_aspp_output = layers.BatchNormalization()(reduced_aspp_output)
    reduced_aspp_output = layers.Activation('relu')(reduced_aspp_output)

    # Final output layer: per-pixel two-class softmax
    output = layers.Conv2D(2, (1, 1), activation="softmax", padding="same", name="output")(reduced_aspp_output)

    # Output both the ASPP feature map and the final prediction
    final_model = keras.Model(inputs=[input_1, input_2], outputs=[aspp_output, output], name='branch_with_aspp_features')

    # Layers whose feature maps we want to inspect
    layers_of_interest = ['aspp_reduced_relu', f'abs_diff_{depth}', 'reduced_aspp_output', 'output']
    feature_map_outputs = [final_model.get_layer(name).output for name in layers_of_interest]

    # Create a new model that outputs the selected feature maps
    feature_model = keras.Model(inputs=final_model.inputs, outputs=feature_map_outputs)

    return final_model, feature_model

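if __name__ == "__main__":
    # Smoke test (an added sketch, not part of the original training pipeline):
    # build each variant with the module-default hyperparameters and print its
    # output shape(s). Values are illustrative only.
    print(branches().output_shape)
    print(branches_nopool(dropout, decay).output_shape)
    print(branch_cva(dropout, decay, depth=2).output_shape)
    print(branch_cva_with_nspp(dropout, decay, depth=2).output_shape)
    print(two_branch_cva_with_aspp(dropout, decay, 3, 96, 96, 3).output_shape)
    final_model, feature_model = two_branch_cva_with_aspp_fmaps(dropout, decay, 2, 96, 96, 3)
    print([tuple(t.shape) for t in feature_model.outputs])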