"""Training configuration for semantic segmentation on COCO.

File: semantic_segmentation_coco_config.py

Defines the module-level names ``loader``, ``model``, ``optimizer``,
``loss`` and ``train_config``.  ``loader``, ``model`` and ``optimizer`` are
``(callable, kwargs)`` pairs that are handed to ``TrainConfiguration``.
"""
import random

import torch.nn as nn
import torch.optim as optim

from configuration.train_configuration import TrainConfiguration
from data_loading.loaders.multi_loader import MultiLoader
from data_loading.loaders.semantic_segmentation_loader import SegmentationLoader
from data_loading.sources.coco_source import COCOSource
from loss_functions.segmenter_loss import recurrent_segmenter_loss
from networks.attention_segmenter import AttentionSegmenter


def get_loader(mode='train'):
    """Build a ``SegmentationLoader`` over the COCO ``person`` annotations.

    Args:
        mode: Accepted for interface compatibility but currently unused.
            NOTE(review): the val2017 split is loaded regardless of
            ``mode`` -- confirm whether a train split was intended.

    Returns:
        A ``SegmentationLoader`` wrapping a ``COCOSource`` built from the
        hard-coded image root and annotation file below.
    """
    root = '/home/ray/Data/COCO/val2017'
    annos = '/home/ray/Data/COCO/annotations/instances_val2017.json'
    source = COCOSource(root, annos)
    return SegmentationLoader(
        source,
        max_frames=100,
        crop_size=[255, 255],
        obj_types=['person'],
    )


# Single-loader alternative (no MultiLoader wrapper):
# loader = (get_loader, dict())
loader = (
    MultiLoader,
    dict(loader=get_loader, loader_args=dict(), num_procs=8),
)
model = (
    AttentionSegmenter,
    dict(num_classes=1, timesteps=5, attn_grid_size=50),
)
optimizer = (optim.Adam, dict(lr=1e-3))
loss = recurrent_segmenter_loss
train_config = TrainConfiguration(loader, optimizer, model, loss, cuda=True)