/ pytlib / loss_functions / segmenter_loss.py
segmenter_loss.py
1  import torch
2  import torch.nn as nn
3  import torch.nn.functional as F
4  
def recurrent_segmenter_loss(output_array, target_masks):
    """Score the final output of a recurrent segmenter against target masks.

    Only the last entry of ``output_array`` contributes to the loss; every
    earlier entry is ignored.

    Args:
        output_array: Indexable sequence of prediction tensors; only
            ``output_array[-1]`` is read. Presumably per-step probability
            maps with values in [0, 1], as ``F.binary_cross_entropy``
            requires -- TODO confirm against the model.
        target_masks: Target tensor, assumed to match the shape of
            ``output_array[-1]`` -- TODO confirm against the caller.

    Returns:
        The binary cross-entropy between the final prediction and
        ``target_masks``, using PyTorch's default ``'mean'`` reduction.
    """
    # Only the last recurrent step is supervised.
    final_prediction = output_array[-1]
    return F.binary_cross_entropy(final_prediction, target_masks)