# fcsiaDiff.py
from collections.abc import Sequence
from typing import Any, Callable, Optional, Union

import segmentation_models_pytorch as smp
import torch
from segmentation_models_pytorch import Unet
from segmentation_models_pytorch.base.model import SegmentationModel
from torch import Tensor


class FCSiamDiff(smp.Unet):
    """Fully-convolutional Siamese difference network built on ``smp.Unet``.

    One encoder (``self.encoder``, so weights are shared) is run separately
    on two input images. The skip features passed to the U-Net decoder are:

    * level 0: the second image's encoder output, unchanged;
    * every deeper level: second image's features minus first image's.

    Constructor arguments are forwarded to ``smp.Unet``, except that
    ``aux_params`` is always forced to ``None``.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Build the underlying ``smp.Unet`` with ``aux_params`` disabled.

        Args:
            *args: Positional arguments forwarded unchanged to ``smp.Unet``.
            **kwargs: Keyword arguments forwarded to ``smp.Unet``; any
                caller-supplied ``aux_params`` is overwritten with ``None``.
        """
        kwargs["aux_params"] = None
        super().__init__(*args, **kwargs)

    def forward(self, x: Tensor) -> Tensor:
        """Run the Siamese-difference U-Net on an image pair.

        Args:
            x: Tensor whose second dimension indexes the image pair; only
                ``x[:, 0]`` and ``x[:, 1]`` are read (any further entries
                along that dimension are ignored). Presumably shaped
                (batch, 2, channels, height, width) -- TODO confirm with callers.

        Returns:
            The output of ``self.segmentation_head`` applied to the decoder
            output.
        """
        first, second = x[:, 0], x[:, 1]
        # Same encoder module for both inputs; first image is encoded first.
        feats_first = self.encoder(first)
        feats_second = self.encoder(second)

        # Shallowest level is taken from the second image as-is; deeper
        # levels are replaced by the (second - first) feature difference.
        skips = [feats_second[0]]
        skips.extend(
            hi - lo for hi, lo in zip(feats_second[1:], feats_first[1:])
        )

        # NOTE(review): unpacking matches decoders whose forward() takes
        # *features; some newer smp releases take a single list instead --
        # confirm against the pinned segmentation_models_pytorch version.
        decoded = self.decoder(*skips)
        out: Tensor = self.segmentation_head(decoded)
        return out