"""Training configuration for an MNIST autoencoder (mnist_ae_config.py).

Defines module-level ``loader``, ``model`` and ``optimizer`` entries as
``(callable, kwargs_dict)`` tuples, a ``loss`` instance, and the resulting
``train_config``. Presumably ``TrainConfiguration`` builds each component by
calling ``callable(**kwargs_dict)`` -- confirm against its implementation.
"""
import random  # noqa: F401 -- currently unused in this config; kept as-is

import torch.nn as nn
import torch.optim as optim

from configuration.train_configuration import TrainConfiguration
from data_loading.sources.mnist_source import MNISTSource
from data_loading.loaders.autoencoder_loader import AutoEncoderLoader
from data_loading.loaders.multi_loader import MultiLoader  # used by the alternative loader below
from networks.autoencoder import AutoEncoder


def get_loader(root='mnist_example', download=True, crop_size=(28, 28)):
    """Build the autoencoder data loader over the MNIST source.

    All parameters are optional and default to the values this config has
    always used, so ``get_loader()`` behaves exactly as before.

    Args:
        root: First positional argument passed to ``MNISTSource``
            (presumably the dataset directory -- confirm).
        download: Forwarded to ``MNISTSource`` as ``download``.
        crop_size: Two-element crop size forwarded to ``AutoEncoderLoader``
            as a list. Defaults to 28x28, the native MNIST image size. A
            tuple default is used to avoid a shared mutable default argument.

    Returns:
        An ``AutoEncoderLoader`` wrapping the MNIST source.
    """
    source = MNISTSource(root, download=download)
    return AutoEncoderLoader(source, crop_size=list(crop_size))


# (factory, kwargs) pairs; an empty kwargs dict means get_loader's defaults.
loader = (get_loader, dict())
# Alternative: wrap get_loader in MultiLoader with 16 processes
# (presumably parallel data loading -- confirm MultiLoader semantics).
# loader = (MultiLoader, dict(loader=get_loader, loader_args=dict(), num_procs=16))

# Single-channel input, since MNIST images are grayscale.
model = (AutoEncoder, dict(inchans=1))
optimizer = (optim.Adam, dict(lr=1e-3))

# NOTE(review): nn.BCELoss expects model outputs (and targets) in [0, 1],
# e.g. a final sigmoid in AutoEncoder. Confirm this; if the network emits
# raw logits, nn.BCEWithLogitsLoss is the numerically stable choice.
loss = nn.BCELoss()

train_config = TrainConfiguration(loader, optimizer, model, loss, cuda=False)