/ pytlib / configuration / mnist_ae_config.py
mnist_ae_config.py
 1  from configuration.train_configuration import TrainConfiguration
 2  from data_loading.sources.mnist_source import MNISTSource
 3  from data_loading.loaders.autoencoder_loader import AutoEncoderLoader
 4  from data_loading.loaders.multi_loader import MultiLoader
 5  import torch.optim as optim
 6  import torch.nn as nn
 7  from networks.autoencoder import AutoEncoder
 8  import random
 9  
def get_loader():
    """Create the data loader used to train the MNIST autoencoder.

    Builds an ``MNISTSource`` named ``'mnist_example'`` (constructed with
    ``download=True``) and wraps it in an ``AutoEncoderLoader`` configured
    with ``crop_size=[28, 28]``.

    Returns:
        The configured ``AutoEncoderLoader`` instance.
    """
    mnist = MNISTSource('mnist_example', download=True)
    crop = [28, 28]
    return AutoEncoderLoader(mnist, crop_size=crop)
13  
# Components are given as (callable, kwargs) pairs; presumably
# TrainConfiguration invokes callable(**kwargs) itself -- TODO confirm.
loader = (get_loader, {})
# Alternative loader: wrap get_loader in MultiLoader, presumably to load data
# in parallel across num_procs worker processes (unverified here).
# loader = (MultiLoader, dict(loader=get_loader, loader_args=dict(), num_procs=16))
model = (AutoEncoder, {'inchans': 1})  # MNIST images are single-channel
optimizer = (optim.Adam, {'lr': 1e-3})
# nn.BCELoss requires model outputs in [0, 1]; assumes AutoEncoder ends in a
# sigmoid-like activation -- NOTE(review): confirm in networks/autoencoder.
loss = nn.BCELoss()
# Positional order expected by TrainConfiguration: loader, optimizer, model, loss.
train_config = TrainConfiguration(loader, optimizer, model, loss, cuda=False)