/ pytlib / configuration / autoencoder_config.py
autoencoder_config.py
 1  from configuration.train_configuration import TrainConfiguration
 2  from data_loading.sources.kitti_source import KITTISource
 3  from data_loading.loaders.autoencoder_loader import AutoEncoderLoader
 4  from data_loading.loaders.multi_loader import MultiLoader
 5  import torch.optim as optim
 6  import torch.nn as nn
 7  from networks.autoencoder import AutoEncoder
 8  import random
 9  
def get_loader(data_dir='/home/ray/Data/KITTI/training', max_frames=10000,
               crop_size=None, obj_types=None):
    """Build the KITTI-backed AutoEncoderLoader for this training config.

    This config's ``loader`` spec pairs this factory with an empty
    ``loader_args``, so presumably it is invoked with no arguments. Every
    parameter therefore defaults to the value this configuration was
    originally written with. Pass overrides to reuse the factory on another
    machine (the default ``data_dir`` is a developer-local path) or with
    different crop/object settings.

    Args:
        data_dir: Directory handed to ``KITTISource`` as its first argument.
        max_frames: Forwarded to ``KITTISource`` as ``max_frames``.
        crop_size: Forwarded to ``AutoEncoderLoader`` as ``crop_size``;
            ``None`` means ``[255, 255]``.
        obj_types: Forwarded to ``AutoEncoderLoader`` as ``obj_types``;
            ``None`` means ``['Car']``.

    Returns:
        An ``AutoEncoderLoader`` constructed over the ``KITTISource``.
    """
    # None sentinels instead of list defaults, so no mutable default object
    # is shared across calls (e.g. if a loader mutated it in place).
    if crop_size is None:
        crop_size = [255, 255]
    if obj_types is None:
        obj_types = ['Car']
    source = KITTISource(data_dir, max_frames=max_frames)
    return AutoEncoderLoader(source, crop_size=crop_size, obj_types=obj_types)
13  
# Component specs: each entry is a (callable, kwargs) pair handed to
# TrainConfiguration below.
# NOTE(review): presumably TrainConfiguration builds each one as
# callable(**kwargs) -- confirm against its implementation.
loader = (
    MultiLoader,
    {'loader': get_loader, 'loader_args': {}, 'num_procs': 10},
)
model = (AutoEncoder, {})
optimizer = (optim.Adam, {'lr': 1e-3})

# Unlike the specs above, the loss is passed as an already-built module.
# NOTE(review): BCELoss expects inputs in [0, 1]; this assumes AutoEncoder
# ends in a sigmoid-like activation -- confirm.
loss = nn.BCELoss()

# Positional order: loader, optimizer, model, loss, then a boolean flag.
# NOTE(review): the meaning of the trailing True comes from
# TrainConfiguration's signature, which is not visible here -- confirm.
train_config = TrainConfiguration(loader, optimizer, model, loss, True)