# segmenter_loss.py
import torch
import torch.nn as nn
import torch.nn.functional as F


def recurrent_segmenter_loss(output_array, target_masks):
    """Binary cross-entropy on the final output of a recurrent segmenter.

    Only the last element of ``output_array`` contributes to the loss; all
    earlier entries are ignored.

    Args:
        output_array: Indexable sequence of predictions (e.g. a list of
            tensors, or a tensor whose first dimension indexes steps).
            ``output_array[-1]`` is passed to
            ``torch.nn.functional.binary_cross_entropy``, so its values must
            already be probabilities in ``[0, 1]`` (e.g. sigmoid outputs).
        target_masks: Ground-truth mask tensor with the same shape as
            ``output_array[-1]``. Integer or boolean masks are accepted; they
            are cast to the prediction's dtype before computing the loss.

    Returns:
        Scalar tensor holding the binary cross-entropy with PyTorch's
        default ``'mean'`` reduction.

    Raises:
        IndexError: if ``output_array`` is empty (there is no final output).
        ValueError: raised by PyTorch if the target shape differs from the
            prediction shape.
    """
    prediction = output_array[-1]
    # binary_cross_entropy rejects integer/bool targets and mismatched float
    # dtypes; .to() returns the same tensor when the dtype already matches.
    target = target_masks.to(dtype=prediction.dtype)
    return F.binary_cross_entropy(prediction, target)