This repository is a TensorFlow implementation of DiscoGAN (ICML 2017).
- All samples in README.md are generated by the neural network, except the first image in each row.
- tensorflow 1.10.0
- python 3.5.3
- numpy 1.14.2
- opencv 3.2.0
- matplotlib 2.2.2
- scipy 0.19.1
- pillow 5.0.0
- Generator
- Discriminator
Results from 2-dimensional Gaussian mixture models (IPython Notebook).
(A) Original GAN
(B) GAN with Reconstruction Loss
(C) Domain A to B of DiscoGAN
(D) Domain B to A of DiscoGAN
- handbag -> shoe -> handbag
- shoe -> handbag -> shoe
- edge -> shoe -> edge
- shoe -> edge -> shoe
- edge -> handbag -> edge
- handbag -> edge -> handbag
- RGB image -> segmentation label -> RGB image
- segmentation label -> RGB image -> segmentation label
- RGB image -> segmentation label -> RGB image
- segmentation label -> RGB image -> segmentation label
- RGB image -> segmentation label -> RGB image
- segmentation label -> RGB image -> segmentation label
Download the edges2shoes, edges2handbags, cityscapes, facades, and maps datasets from pix2pix first. Use the following command to download the datasets, then copy them into the corresponding folders as described in the Directory Hierarchy information.
python download.py
.
│ DiscoGAN
│ ├── src
│ │ ├── dataset.py
│ │ ├── discogan.py
│ │ ├── download.py
│ │ ├── main.py
│ │ ├── reader.py
│ │ ├── solver.py
│ │ ├── tensorflow_utils.py
│ │ └── utils.py
│ Data
│ ├── cityscapes
│ ├── edges2handbags
│ ├── edge2shoes
│ ├── facades
│ └── maps
src: source code of the DiscoGAN
The implementation uses TensorFlow to train the DiscoGAN. The generator and discriminator networks are the same as those described in the DiscoGAN paper. We applied learning rate control that holds the rate at 2e-4 for the first 1e5 iterations and then decays it linearly to zero, as in CycleGAN. This helps to overcome the mode collapse problem.
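The learning rate control described above can be sketched in plain Python. This is only an illustration, not the repository's code: the function name `learning_rate_at` and the `decay_steps` parameter are assumptions, since the README does not state how long the linear decay phase lasts.

```python
def learning_rate_at(step, base_lr=2e-4, constant_steps=100000, decay_steps=100000):
    """Hold base_lr for constant_steps, then decay linearly to zero.

    decay_steps is a hypothetical parameter; the README only says that
    the rate decays linearly to zero after the constant phase.
    """
    if step < constant_steps:
        return base_lr
    # Fraction of the decay phase already completed, clamped to [0, 1].
    progress = min((step - constant_steps) / decay_steps, 1.0)
    return base_lr * (1.0 - progress)


if __name__ == "__main__":
    for step in (0, 50000, 100000, 150000, 200000):
        print(step, learning_rate_at(step))
```

In a TensorFlow 1.x graph, a value computed this way would typically be fed to the optimizer through a placeholder or computed from the global step.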
To respect the original DiscoGAN paper, we set the balance between the GAN loss and the reconstruction loss to 1:1. Therefore, DiscoGAN is not good at the A -> B -> A reconstruction. In CycleGAN, however, the ratio is 1:10, so the reconstructed image stays very similar to the input image.
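To make the effect of the loss balance concrete, here is a minimal sketch of a weighted generator objective. The function names and the example loss values are hypothetical and are not taken from the repository. It only shows how a 1:1 versus a 1:10 weighting changes the share of the reconstruction term in the total.

```python
def generator_objective(gan_loss, recon_loss, gan_weight=1.0, recon_weight=1.0):
    """Weighted sum of the adversarial loss and the reconstruction loss."""
    return gan_weight * gan_loss + recon_weight * recon_loss


def recon_share(gan_loss, recon_loss, gan_weight=1.0, recon_weight=1.0):
    """Fraction of the total objective contributed by the reconstruction term."""
    total = generator_objective(gan_loss, recon_loss, gan_weight, recon_weight)
    return recon_weight * recon_loss / total


if __name__ == "__main__":
    gan_loss, recon_loss = 0.7, 0.2  # made-up values for illustration only
    print("1:1  share:", recon_share(gan_loss, recon_loss, 1.0, 1.0))
    print("1:10 share:", recon_share(gan_loss, recon_loss, 1.0, 10.0))
```

With the larger reconstruction weight, the optimizer is pushed harder toward reproducing the input, which is consistent with the closer A -> B -> A reconstructions noted for CycleGAN.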
The official DiscoGAN code is implemented in PyTorch and uses weight decay. Unfortunately, as far as I know, TensorFlow does not support weight decay. I used a regularization term instead of weight decay, so the performance may be a little different from the original.
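As a rough illustration of replacing weight decay with a regularization term, the sketch below adds an L2 penalty on the weights to a data loss. The helper names are hypothetical, and the exact form of the penalty used in this repository (for example, whether it includes a factor of 1/2) is an assumption.

```python
def l2_penalty(weights, weight_decay=1e-4):
    """L2 regularization term: weight_decay times the sum of squared weights.

    Assumption: no 1/2 factor is applied; the repository may differ.
    """
    return weight_decay * sum(w * w for w in weights)


def regularized_loss(data_loss, weights, weight_decay=1e-4):
    """Data loss plus the L2 penalty, used here in place of optimizer weight decay."""
    return data_loss + l2_penalty(weights, weight_decay)


if __name__ == "__main__":
    weights = [0.5, -1.0, 2.0]  # made-up parameter values
    print(regularized_loss(0.3, weights))
```

With an adaptive optimizer such as Adam, an L2 term in the loss is generally not equivalent to weight decay applied inside the optimizer, which is one reason the results may differ from the PyTorch version.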
Use main.py to train a DiscoGAN network. Example usage:
python main.py
- gpu_index: gpu index, default: 0
- batch_size: batch size for one feed forward, default: 200
- dataset: dataset name from [edges2handbags, edges2shoes, handbags2shoes, maps, cityscapes, facades], default: facades
- is_train: training or inference mode, default: True
- learning_rate: initial learning rate for Adam, default: 0.0002
- beta1: beta1 momentum term of Adam, default: 0.5
- beta2: beta2 momentum term of Adam, default: 0.999
- weight_decay: hyper-parameter for regularization term, default: 1e-4
- iters: number of iterations, default: 100000
- print_freq: print frequency for loss, default: 100
- save_freq: save frequency for model, default: 10000
- sample_freq: sample frequency for saving image, default: 500
- sample_batch: number of sampling images for checking generator quality, default: 200
- load_model: folder of the saved model that you wish to test (e.g. 20180907-1739), default: None
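For readers who want to see how these options fit together, the following is an illustrative sketch that mirrors the arguments and defaults listed above using the standard-library `argparse` module. It is not the repository's actual flag handling, which may differ. The helpers `str2bool` and `build_parser` are hypothetical names, and the choice of types (for example, `gpu_index` as a string) is an assumption.

```python
import argparse


def str2bool(value):
    """Interpret common textual spellings of a boolean flag."""
    return str(value).lower() in ("true", "1", "yes")


def build_parser():
    datasets = ["edges2handbags", "edges2shoes", "handbags2shoes",
                "maps", "cityscapes", "facades"]
    p = argparse.ArgumentParser(description="DiscoGAN options (illustrative sketch)")
    p.add_argument("--gpu_index", type=str, default="0", help="gpu index")
    p.add_argument("--batch_size", type=int, default=200, help="batch size")
    p.add_argument("--dataset", type=str, default="facades", choices=datasets)
    p.add_argument("--is_train", type=str2bool, default=True,
                   help="training or inference mode")
    p.add_argument("--learning_rate", type=float, default=2e-4)
    p.add_argument("--beta1", type=float, default=0.5)
    p.add_argument("--beta2", type=float, default=0.999)
    p.add_argument("--weight_decay", type=float, default=1e-4)
    p.add_argument("--iters", type=int, default=100000)
    p.add_argument("--print_freq", type=int, default=100)
    p.add_argument("--save_freq", type=int, default=10000)
    p.add_argument("--sample_freq", type=int, default=500)
    p.add_argument("--sample_batch", type=int, default=200)
    p.add_argument("--load_model", type=str, default=None,
                   help="folder of the saved model to test")
    return p


if __name__ == "__main__":
    args = build_parser().parse_args()
    print(vars(args))
```

Parsing `["--is_train=false", "--load_model=20180907-1739"]` with this sketch yields `is_train` as `False` and leaves every other option at the default listed above.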
Use main.py to test a DiscoGAN network. Example usage:
python main.py --is_train=false --load_model=folder/you/wish/to/test/e.g./20180926-1739
Please refer to the above arguments.
@misc{chengbinjin2018discogan,
author = {Cheng-Bin Jin},
title = {DiscoGAN-tensorflow},
year = {2018},
howpublished = {\url{https://github.com/ChengBinJin/DiscoGAN-TensorFlow}},
note = {commit xxxxxxx}
}
- This project borrowed some code from carpedm20 and GunhoChoi.
- Some readme formatting was borrowed from Logan Engstrom.
Copyright (c) 2018 Cheng-Bin Jin. Contact me for commercial use (or rather any use that is not academic research) (email: sbkim0407@gmail.com). Free for research use, as long as proper attribution is given and this copyright notice is retained.