SCGAN

Code for reproducing experiments in "SCGAN: Disentangled Representation Learning by Adding Similarity Constraint on Generative Adversarial Nets"

Prerequisites

  • Python
  • NumPy
  • SciPy
  • Matplotlib
  • TensorFlow >= 1.5

Preparing datasets

The expected directory tree is as follows:

  • DATA_DIR
    • mnist
    • fashion-mnist
    • svhn
    • cifar10
    • celeba

DATA_DIR is the root directory for these datasets and can be any path you choose, e.g. /home/username. Then download each dataset as follows:

  1. Download the MNIST dataset from MNIST, and put train-images-idx3-ubyte.gz, train-labels-idx1-ubyte.gz, t10k-images-idx3-ubyte.gz and t10k-labels-idx1-ubyte.gz into DATA_DIR/mnist.

  2. Download the Fashion-MNIST dataset from Fashion-MNIST, and put train-images-idx3-ubyte.gz, train-labels-idx1-ubyte.gz, t10k-images-idx3-ubyte.gz and t10k-labels-idx1-ubyte.gz into DATA_DIR/fashion-mnist.

  3. Download the SVHN dataset from SVHN, and put train_32x32.mat into DATA_DIR/svhn.

  4. Download the CIFAR-10 dataset from CIFAR, extract cifar-10-batches-py from cifar-10-python.tar.gz, and put cifar-10-batches-py into DATA_DIR/cifar10.

  5. Download the CelebA dataset from Baidu or Google, and put celeba_0, celeba_1, celeba_2 and celeba_3 (already converted to TFRecords) into DATA_DIR/celeba.
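Before training, it can help to confirm that DATA_DIR matches the layout above. The following is a minimal sketch and is not a script shipped with this repository. `EXPECTED_FILES` and `missing_dataset_files` are hypothetical names, and the file list is transcribed from the steps above. The README does not say whether celeba_0 to celeba_3 are files or directories, so the check only tests that each path exists.

```python
import os

# Expected layout under DATA_DIR, transcribed from the download steps above.
# (Hypothetical helper; not part of the SCGAN repository.)
MNIST_STYLE_FILES = [
    "train-images-idx3-ubyte.gz",
    "train-labels-idx1-ubyte.gz",
    "t10k-images-idx3-ubyte.gz",
    "t10k-labels-idx1-ubyte.gz",
]
EXPECTED_FILES = {
    "mnist": MNIST_STYLE_FILES,
    "fashion-mnist": MNIST_STYLE_FILES,
    "svhn": ["train_32x32.mat"],
    "cifar10": ["cifar-10-batches-py"],  # extracted directory
    "celeba": ["celeba_0", "celeba_1", "celeba_2", "celeba_3"],  # TFRecords
}


def missing_dataset_files(data_dir):
    """Return every expected path under data_dir that does not exist.

    os.path.exists accepts both files and directories, so this works
    whether an entry (e.g. cifar-10-batches-py) is a file or a folder.
    """
    missing = []
    for subdir, names in EXPECTED_FILES.items():
        for name in names:
            path = os.path.join(data_dir, subdir, name)
            if not os.path.exists(path):
                missing.append(path)
    return missing
```

For example, `missing_dataset_files("/home/username")` returns an empty list once every dataset is in place.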

MNIST & Fashion-MNIST

Train on MNIST

python cgan_mnist.py --data_dir=DATA_DIR
python infogan_mnist.py --data_dir=DATA_DIR
python scgan_mnist.py --data_dir=DATA_DIR

Test on MNIST

This creates a *.npy file containing 10,000 synthetic images.

python cgan_mnist.py --train=False --data_dir=DATA_DIR
python scgan_mnist.py --train=False --data_dir=DATA_DIR
python infogan_mnist.py --train=False --data_dir=DATA_DIR
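The generated file can be inspected with a few lines of NumPy. This sketch assumes only that the output is a standard .npy array whose first axis indexes the samples. The per-image shape and pixel range are not documented here, so the hypothetical helper `summarize_samples` simply reports them.

```python
import numpy as np


def summarize_samples(path):
    """Load a saved array of synthetic images and report basic statistics.

    Assumes the first axis indexes samples; the remaining axes are the
    per-image shape, whatever it happens to be.
    """
    samples = np.load(path)
    return {
        "count": samples.shape[0],
        "shape": samples.shape[1:],
        "min": float(samples.min()),
        "max": float(samples.max()),
    }
```

Applied to an output path such as result/cgan_mnist/cgan_mnist.npy (the path read by the Parzen window commands below), the count should be 10000.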

Train on Fashion-MNIST

python cgan_mnist.py --dataset_type=fashion-mnist --data_dir=DATA_DIR
python infogan_mnist.py --dataset_type=fashion-mnist --data_dir=DATA_DIR
python scgan_mnist.py --dataset_type=fashion-mnist --data_dir=DATA_DIR

Test on Fashion-MNIST

This creates a *.npy file containing 10,000 synthetic images.

python cgan_mnist.py --train=False --dataset_type=fashion-mnist --data_dir=DATA_DIR
python infogan_mnist.py --train=False --dataset_type=fashion-mnist --data_dir=DATA_DIR
python scgan_mnist.py --train=False --dataset_type=fashion-mnist --data_dir=DATA_DIR

On MNIST and Fashion-MNIST, we find that the categorical conditional variables capture class labels, e.g., digit type on MNIST and clothing type on Fashion-MNIST.

Gaussian Parzen Window

We use a Gaussian Parzen window to estimate the log-likelihood from the 10,000 synthetic images generated by each model.

python gaussian_parzen_window.py --data_dir DATA_DIR/mnist --gen_data_path result/cgan_mnist/cgan_mnist.npy --file result/cgan_mnist/cgan_mnist.txt
python gaussian_parzen_window.py --data_dir DATA_DIR/mnist --gen_data_path result/scgan_mnist/scgan_mnist.npy --file result/scgan_mnist/scgan_mnist.txt
python gaussian_parzen_window.py --data_dir DATA_DIR/mnist --gen_data_path result/infogan_mnist/infogan_mnist.npy --file result/infogan_mnist/infogan_mnist.txt
python gaussian_parzen_window.py --data_dir DATA_DIR/fashion-mnist --gen_data_path result/cgan_fashion-mnist/cgan_fashion-mnist.npy --file result/cgan_fashion-mnist/cgan_fashion-mnist.txt
python gaussian_parzen_window.py --data_dir DATA_DIR/fashion-mnist --gen_data_path result/scgan_fashion-mnist/scgan_fashion-mnist.npy --file result/scgan_fashion-mnist/scgan_fashion-mnist.txt
python gaussian_parzen_window.py --data_dir DATA_DIR/fashion-mnist --gen_data_path result/infogan_fashion-mnist/infogan_fashion-mnist.npy --file result/infogan_fashion-mnist/infogan_fashion-mnist.txt
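For reference, the standard Gaussian Parzen window estimate used by Goodfellow et al. in the original GAN evaluation works as follows. It places an isotropic Gaussian kernel of width σ on each of the m generated samples g_j. It then scores each real d-dimensional test image x under the resulting mixture: log p(x) = logsumexp_j(-‖x - g_j‖² / 2σ²) - log m - (d/2) log(2πσ²). The sketch below implements that formula with NumPy and SciPy. It is not the code in gaussian_parzen_window.py, whose bandwidth selection and preprocessing may differ. Here σ is passed in directly, although it is usually tuned on a validation set. Images are assumed to be flattened and scaled identically for test and generated data. Reading the ± values in the table below as standard errors of the mean is likewise an assumption.

```python
import numpy as np
from scipy.special import logsumexp


def parzen_log_likelihood(test_x, samples, sigma, batch_size=100):
    """Gaussian Parzen window log-likelihood of test_x under kernels on samples.

    test_x:  (n, d) array of real test images, flattened.
    samples: (m, d) array of generated images, flattened and scaled like test_x.
    sigma:   kernel bandwidth (assumed already chosen, e.g. on a validation set).
    Returns (mean log-likelihood, standard error of the mean) over the n points.
    """
    test_x = np.asarray(test_x, dtype=np.float64)
    samples = np.asarray(samples, dtype=np.float64)
    m, d = samples.shape
    # log m for the uniform mixture weight 1/m, plus the Gaussian normalizer.
    log_norm = np.log(m) + 0.5 * d * np.log(2.0 * np.pi * sigma ** 2)
    samples_sq = (samples ** 2).sum(axis=1)
    scores = []
    for start in range(0, len(test_x), batch_size):
        x = test_x[start:start + batch_size]
        # Squared distances via ||x||^2 - 2 x.g + ||g||^2, shape (batch, m),
        # which avoids materializing a (batch, m, d) difference tensor.
        sq_dist = (x ** 2).sum(axis=1)[:, None] - 2.0 * x @ samples.T + samples_sq[None, :]
        sq_dist = np.maximum(sq_dist, 0.0)  # clip tiny negatives from rounding
        scores.append(logsumexp(-sq_dist / (2.0 * sigma ** 2), axis=1) - log_norm)
    scores = np.concatenate(scores)
    return scores.mean(), scores.std(ddof=1) / np.sqrt(len(scores))
```

Computing distances in batches keeps memory at roughly batch_size × m values, which matters with 10,000 generated samples of 784 pixels each.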

Gaussian Parzen window-based log-likelihood estimates for MNIST and Fashion-MNIST.

Model      MNIST          Fashion-MNIST
CGAN       228.1 ± 2.2    312.7 ± 2
InfoGAN    231 ± 2.2      320.2 ± 2
SCGAN      233.6 ± 2.2    321.7 ± 2

SVHN & CIFAR10

Train on SVHN

python scgan_svhn.py --data_dir=DATA_DIR

Test on SVHN

Fix one conditional variable and vary the other.

python scgan_svhn.py --train=False --data_dir=DATA_DIR --con_dim=0
python scgan_svhn.py --train=False --data_dir=DATA_DIR --con_dim=1

We selected some good-quality synthetic images. The continuous variables capture variation in lighting and digit type.
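The "fix one, vary the other" protocol amounts to building a batch of generator inputs. In that batch, a single latent code sweeps across its range while the noise vector and the other code stay constant. The sketch below shows this under stated assumptions: two continuous codes in [-1, 1] and a uniform noise vector, in the style of InfoGAN. The repository's actual code dimensions, value ranges and input ordering, and the exact meaning of --con_dim, are not documented here. Both function names are hypothetical.

```python
import numpy as np


def continuous_traversal(num_steps, con_dim, num_codes=2, low=-1.0, high=1.0, fixed_value=0.0):
    """Codes where column con_dim sweeps linearly from low to high while the
    other continuous codes stay at fixed_value. Shape: (num_steps, num_codes)."""
    if not 0 <= con_dim < num_codes:
        raise ValueError("con_dim must index one of the continuous codes")
    codes = np.full((num_steps, num_codes), fixed_value, dtype=np.float32)
    codes[:, con_dim] = np.linspace(low, high, num_steps)
    return codes


def generator_inputs(codes, noise_dim, seed=0):
    """Prepend one shared noise vector to every code row, so that only the
    swept code differs between rows. (Noise-then-code order is an assumption.)"""
    rng = np.random.RandomState(seed)
    z = rng.uniform(-1.0, 1.0, size=(1, noise_dim)).astype(np.float32)
    return np.concatenate([np.repeat(z, len(codes), axis=0), codes], axis=1)
```

For example, `generator_inputs(continuous_traversal(10, con_dim=0), noise_dim=62)` gives ten rows that differ only in the first continuous code. The noise size of 62 is an arbitrary illustration.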

Train on CIFAR10

python scgan_cifar.py --data_dir=DATA_DIR

Test on CIFAR10

Fix one conditional variable and vary the other.

python scgan_cifar.py --train=False --data_dir=DATA_DIR --con_dim=0
python scgan_cifar.py --train=False --data_dir=DATA_DIR --con_dim=1

We selected some good-quality synthetic images. The continuous variables capture variation in object size and colour.
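Traversal results like these are usually shown as an image grid, with one row per fixed setting and one column per step. The following hypothetical helper is not part of this repository. It tiles a batch of images shaped (N, H, W, C) into a single array that Matplotlib can display or save.

```python
import numpy as np


def tile_images(images, rows, cols):
    """Arrange images of shape (N, H, W, C), N == rows * cols, into one
    (rows * H, cols * W, C) grid, filling row by row."""
    images = np.asarray(images)
    n, h, w, c = images.shape
    if n != rows * cols:
        raise ValueError("need exactly rows * cols images")
    grid = images.reshape(rows, cols, h, w, c)
    grid = grid.transpose(0, 2, 1, 3, 4)  # -> (rows, H, cols, W, C)
    return grid.reshape(rows * h, cols * w, c)
```

With pixel values in [0, 1], `matplotlib.pyplot.imsave("grid.png", tile_images(batch, 10, 10))` writes a 10 × 10 grid. Single-channel images need the channel axis squeezed first.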

CelebA

Train on CelebA

python scgan_celeba.py --data_dir=DATA_DIR

Test on CelebA

Fix three conditional variables and vary the remaining one.

python scgan_celeba.py --train=False --data_dir=DATA_DIR --con_dim=0
python scgan_celeba.py --train=False --data_dir=DATA_DIR --con_dim=1
python scgan_celeba.py --train=False --data_dir=DATA_DIR --con_dim=2
python scgan_celeba.py --train=False --data_dir=DATA_DIR --con_dim=3

We selected some good-quality synthetic images. The categorical variables capture variation in pose, glasses, hair and sex.

Note: There is no guarantee that each conditional variable captures only one kind of representation, which is why InfoGAN uses ten 10-dimensional categorical variables to capture these representations.
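The CelebA protocol works the same way with categorical codes. Each variable is a one-hot vector, and one variable steps through all of its categories while the others stay fixed. The sketch below builds such a batch. This README does not state the number of categories per variable, so num_cats is a free parameter here. `categorical_traversal` is a hypothetical helper rather than the repository's code.

```python
import numpy as np


def categorical_traversal(num_vars, num_cats, vary_index, fixed_cats=None):
    """One-hot codes in which variable vary_index steps through all num_cats
    categories while every other variable keeps its fixed category.

    Returns an array of shape (num_cats, num_vars * num_cats).
    """
    if fixed_cats is None:
        fixed_cats = [0] * num_vars  # arbitrary default: category 0 everywhere
    codes = np.zeros((num_cats, num_vars, num_cats), dtype=np.float32)
    for step in range(num_cats):
        for var in range(num_vars):
            cat = step if var == vary_index else fixed_cats[var]
            codes[step, var, cat] = 1.0
    # Concatenate the per-variable one-hot vectors into one flat code per row.
    return codes.reshape(num_cats, num_vars * num_cats)
```

Take four variables, matching --con_dim=0 to 3, and, say, ten categories. Then `categorical_traversal(4, 10, vary_index=0)` yields ten 40-dimensional codes. The InfoGAN setup described in the note corresponds to num_vars=10 and num_cats=10.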

Cite SCGAN

If you use SCGAN in your research, we would appreciate references to the following paper:

Biblatex entry:

@article{li2018scgan,
  title={SCGAN: Disentangled Representation Learning by Adding Similarity Constraint on Generative Adversarial Nets},
  author={Li, Xiaoqiang and Chen, Liangbo and Wang, Lu and Wu, Pin and Tong, Weiqin},
  journal={IEEE Access},
  volume={7},
  pages={147928--147938},
  year={2019},
  publisher={IEEE}
}