diff --git a/examples/auto_encoder/train.py b/examples/auto_encoder/train.py index 1c3ab3b..5acb557 100644 --- a/examples/auto_encoder/train.py +++ b/examples/auto_encoder/train.py @@ -37,6 +37,9 @@ def get_argparser(): def main(): + if not os.path.exists('results'): + os.mkdir('results') + opts = get_argparser().parse_args() # dataset @@ -58,7 +61,7 @@ def main(): val_loader = DataLoader( ImageDataset(root='datasets/CLIC/valid', transform=val_transform), - batch_size=opts.batch_size, shuffle=False, num_workers=0) + batch_size=1, shuffle=False, num_workers=0) print("Train set: %d, Val set: %d" % (len(train_loader.dataset), len(val_loader.dataset))) @@ -105,7 +108,7 @@ def main(): # ===== Validation ===== print("Val...") best_score = 0.0 - cur_score = test(opts, model, val_loader) + cur_score = test(opts, model, val_loader, cur_epoch) print("%s = %.6f" % (opts.loss_type, cur_score)) # ===== Save Best Model ===== if cur_score > best_score: # save best model @@ -114,7 +117,10 @@ def main(): print("Best model saved as best_model.pt") -def test(opts, model, val_loader): +def test(opts, model, val_loader, epoch): + save_dir = os.path.join('results', 'epoch_%d' % epoch) + if not os.path.exists(save_dir): + os.mkdir(save_dir) model.eval() cur_score = 0.0 @@ -124,11 +130,10 @@ def test(opts, model, val_loader): for i, (images, ) in enumerate(val_loader): outputs = model(images) # save the first reconstructed image - if i == 20: - Image.fromarray((outputs*255).squeeze(0).detach().numpy().astype( - 'uint8').transpose(1, 2, 0)).save('recons_%s.png' % (opts.loss_type)) cur_score += metric(outputs, images, data_range=1.0) + Image.fromarray((outputs*255).squeeze(0).detach().numpy().astype('uint8').transpose(1, 2, 0)).save(os.path.join(save_dir, 'recons_%s_%d.png' % (opts.loss_type, i))) cur_score /= len(val_loader.dataset) + return cur_score