"""Training configuration for the KITTI autoencoder (autoencoder_config.py).

Defines the module-level objects ``loader``, ``model``, ``optimizer``, ``loss``
and ``train_config``. ``loader``, ``model`` and ``optimizer`` are
``(factory, kwargs)`` pairs handed to ``TrainConfiguration``; presumably it
instantiates each as ``factory(**kwargs)`` -- confirm against
``configuration.train_configuration``.
"""
import os
import random  # NOTE(review): unused in this file as shown; kept in case importers rely on it.

import torch.nn as nn
import torch.optim as optim

from configuration.train_configuration import TrainConfiguration
from data_loading.loaders.autoencoder_loader import AutoEncoderLoader
from data_loading.loaders.multi_loader import MultiLoader
from data_loading.sources.kitti_source import KITTISource
from networks.autoencoder import AutoEncoder

# Root of the KITTI "training" directory. Set the KITTI_ROOT environment
# variable to point elsewhere instead of editing this file; the fallback is the
# path this config originally hard-coded.
KITTI_ROOT = os.environ.get('KITTI_ROOT', '/home/ray/Data/KITTI/training')


def get_loader(root=None, max_frames=10000, crop_size=None, obj_types=None):
    """Build an ``AutoEncoderLoader`` over a ``KITTISource``.

    Called with no arguments (as configured via ``loader_args=dict()`` below),
    this reproduces the original hard-coded setup; the keyword arguments allow
    reuse with other paths or settings.

    Args:
        root: directory passed to ``KITTISource``; defaults to ``KITTI_ROOT``.
        max_frames: forwarded as ``KITTISource(..., max_frames=...)``.
        crop_size: forwarded to ``AutoEncoderLoader``; defaults to ``[255, 255]``.
        obj_types: forwarded to ``AutoEncoderLoader``; defaults to ``['Car']``.

    Returns:
        The constructed ``AutoEncoderLoader``.
    """
    if root is None:
        root = KITTI_ROOT
    # Build fresh lists on every call (same values as the original literals) so
    # no mutable default is ever shared between callers.
    if crop_size is None:
        crop_size = [255, 255]
    if obj_types is None:
        obj_types = ['Car']
    source = KITTISource(root, max_frames=max_frames)
    return AutoEncoderLoader(source, crop_size=crop_size, obj_types=obj_types)


# MultiLoader is given the factory function itself (not a loader instance) plus
# the arguments to call it with. NOTE(review): num_procs presumably sets the
# number of worker processes -- confirm in MultiLoader.
loader = (MultiLoader, dict(loader=get_loader, loader_args=dict(), num_procs=10))
model = (AutoEncoder, dict())
optimizer = (optim.Adam, dict(lr=1e-3))

# nn.BCELoss expects predictions that are already probabilities in [0, 1].
# NOTE(review): confirm AutoEncoder's output ends in a sigmoid; if it emits raw
# logits, nn.BCEWithLogitsLoss is the numerically stable choice.
loss = nn.BCELoss()

# NOTE(review): the meaning of the trailing positional ``True`` is defined by
# TrainConfiguration's signature; passing it by keyword would make it readable.
train_config = TrainConfiguration(loader, optimizer, model, loss, True)