/ pytlib / configuration / semantic_segmentation_coco_config.py
semantic_segmentation_coco_config.py
 1  from configuration.train_configuration import TrainConfiguration
 2  from data_loading.sources.coco_source import COCOSource 
 3  from data_loading.loaders.semantic_segmentation_loader import SegmentationLoader
 4  from data_loading.loaders.multi_loader import MultiLoader
 5  from networks.attention_segmenter import AttentionSegmenter
 6  from loss_functions.segmenter_loss import recurrent_segmenter_loss
 7  import torch.optim as optim
 8  import torch.nn as nn
 9  import random
10  
def get_loader(mode='train', *,
               root='/home/ray/Data/COCO/val2017',
               annos='/home/ray/Data/COCO/annotations/instances_val2017.json',
               max_frames=100,
               crop_size=None,
               obj_types=None):
    """Build a SegmentationLoader over a COCO split.

    All keyword arguments default to the values this config originally
    hard-coded, so ``get_loader()`` behaves exactly as before.

    Args:
        mode: Accepted for interface compatibility but currently UNUSED --
            every mode reads whatever ``root``/``annos`` point at (val2017 by
            default). Pass ``root``/``annos`` explicitly to use another split.
        root: Directory of COCO images, forwarded to ``COCOSource``.
        annos: Path to the COCO instances annotation JSON, forwarded to
            ``COCOSource``.
        max_frames: Forwarded to ``SegmentationLoader(max_frames=...)``.
        crop_size: Forwarded to ``SegmentationLoader(crop_size=...)``;
            ``None`` means ``[255, 255]``.
        obj_types: Forwarded to ``SegmentationLoader(obj_types=...)``;
            ``None`` means ``['person']``.

    Returns:
        The ``SegmentationLoader`` wrapping a ``COCOSource`` built from
        ``root`` and ``annos``.
    """
    # Build fresh lists per call instead of using list defaults in the
    # signature, so a callee mutating them cannot leak state across calls.
    if crop_size is None:
        crop_size = [255, 255]
    if obj_types is None:
        obj_types = ['person']
    source = COCOSource(root, annos)
    return SegmentationLoader(source,
                              max_frames=max_frames,
                              crop_size=crop_size,
                              obj_types=obj_types)
16  
# Each component below is a (factory, kwargs) pair handed to
# TrainConfiguration. For a single-process loader, use
# ``(get_loader, {})`` instead of the MultiLoader tuple.
loader = (
    MultiLoader,
    {'loader': get_loader, 'loader_args': {}, 'num_procs': 8},
)
model = (
    AttentionSegmenter,
    {'num_classes': 1, 'timesteps': 5, 'attn_grid_size': 50},
)
optimizer = (optim.Adam, {'lr': 1e-3})
loss = recurrent_segmenter_loss
train_config = TrainConfiguration(loader, optimizer, model, loss, cuda=True)