/ FCSiam / fcsiaDiff.py
fcsiaDiff.py
 1  from collections.abc import Sequence
 2  from typing import Any, Callable, Optional, Union
 3  
 4  import segmentation_models_pytorch as smp
 5  import torch
 6  from segmentation_models_pytorch import Unet
 7  from segmentation_models_pytorch.base.model import SegmentationModel
 8  from torch import Tensor
 9  
10  
11  
12  
class FCSiamDiff(smp.Unet):
    """Siamese U-Net that decodes the difference of two images' encoder features.

    A single encoder (shared weights) is applied separately to both images of a
    pair. Each pair of feature levels after the first is subtracted, and the
    resulting list is passed through the U-Net decoder and segmentation head.
    All constructor arguments are forwarded to ``smp.Unet``.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Build the underlying U-Net.

        Args:
            *args: positional arguments forwarded unchanged to ``smp.Unet``.
            **kwargs: keyword arguments forwarded to ``smp.Unet``; any
                caller-supplied ``aux_params`` is overridden with ``None``.
        """
        # forward() only uses the segmentation head, so any auxiliary head is
        # forced off (in smp, aux_params=None means no classification head).
        kwargs["aux_params"] = None
        super().__init__(*args, **kwargs)

    def forward(self, x: Tensor) -> Tensor:
        """Run both images through the shared encoder and decode their difference.

        Args:
            x: stacked image pair of shape ``(b, 2, c, h, w)``; ``x[:, 0]`` is
                the first image and ``x[:, 1]`` the second.

        Returns:
            The segmentation head's output for the feature differences
            (presumably ``(b, classes, h, w)`` -- confirm against smp).

        Raises:
            ValueError: if ``x`` is not 5-D or its second dimension is not 2.
                Previously, extra frames were silently ignored and 4-D input
                failed (or misbehaved) deep inside the encoder.
        """
        if x.ndim != 5 or x.shape[1] != 2:
            raise ValueError(
                "FCSiamDiff expects input of shape (b, 2, c, h, w), "
                f"got {tuple(x.shape)}"
            )
        x1, x2 = x[:, 0], x[:, 1]

        # Siamese: the same encoder module (same weights) encodes both images.
        features1 = self.encoder(x1)
        features2 = self.encoder(x2)

        # Level 0 is passed through from the second image rather than
        # differenced; every deeper level is (second - first).
        features = [features2[0]] + [
            f2 - f1 for f1, f2 in zip(features1[1:], features2[1:])
        ]

        # NOTE(review): some newer segmentation_models_pytorch releases take the
        # feature list as a single argument (self.decoder(features)); confirm
        # against the installed version before upgrading.
        decoder_output = self.decoder(*features)
        masks: Tensor = self.segmentation_head(decoder_output)
        return masks