Implement GANs with PyTorch.
Unconditional image generation (CIFAR-10):
- DCGAN (vanilla GAN)
- DCGAN + R1 regularization
- WGAN
- WGAN-GP
- SNGAN
- LSGAN
Conditional image generation (CIFAR-10):
- CGAN
- ACGAN
Unsupervised decomposition (MNIST, FFHQ):
- InfoGAN
- EigenGAN
Mode collapse study (Ring8, MNIST):
- GAN (vanilla GAN)
- GAN + R1 regularization
- WGAN
- WGAN-GP
- SNGAN
- LSGAN
- VEEGAN
Notes:
Model | G. Arch. | D. Arch. | Loss | Configs & Args |
---|---|---|---|---|
DCGAN | SimpleCNN | SimpleCNN | Vanilla | config file |
DCGAN + R1 reg | SimpleCNN | SimpleCNN | Vanilla + R1 regularization | config file; additional args: `--train.loss_fn.params.lambda_r1_reg 10.0` |
WGAN | SimpleCNN | SimpleCNN | Wasserstein (weight clipping) | config file |
WGAN-GP | SimpleCNN | SimpleCNN | Wasserstein (gradient penalty) | config file |
SNGAN | SimpleCNN | SimpleCNN (SN) | Vanilla | config file |
SNGAN | SimpleCNN | SimpleCNN (SN) | Hinge | config file |
LSGAN | SimpleCNN | SimpleCNN | Least Squares | config file |
- SN stands for "Spectral Normalization".
- For simplicity, the network architecture in all experiments is SimpleCNN, namely a stack of `nn.Conv2d` or `nn.ConvTranspose2d` layers. The results could be improved by adding more parameters and using advanced architectures (e.g., residual connections), but I decided to use the simplest setup here.
- All models except LSGAN are trained for 40k generator update steps. However, the optimizers and learning rates are not tuned for each model, so some models may not reach their optimal performance.
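To make the setup concrete, here is a minimal sketch of what a "SimpleCNN" generator/discriminator pair for 32x32 CIFAR-10 images might look like. The class names, channel widths, and layer counts are illustrative assumptions, not the repo's actual code.

```python
import torch
import torch.nn as nn

class SimpleCNNGenerator(nn.Module):
    """Hypothetical SimpleCNN generator: a stack of nn.ConvTranspose2d layers."""
    def __init__(self, z_dim: int = 100, base: int = 64):
        super().__init__()
        self.net = nn.Sequential(
            nn.ConvTranspose2d(z_dim, base * 4, 4, 1, 0),   # 1x1 -> 4x4
            nn.BatchNorm2d(base * 4), nn.ReLU(True),
            nn.ConvTranspose2d(base * 4, base * 2, 4, 2, 1),  # 4x4 -> 8x8
            nn.BatchNorm2d(base * 2), nn.ReLU(True),
            nn.ConvTranspose2d(base * 2, base, 4, 2, 1),      # 8x8 -> 16x16
            nn.BatchNorm2d(base), nn.ReLU(True),
            nn.ConvTranspose2d(base, 3, 4, 2, 1),             # 16x16 -> 32x32
            nn.Tanh(),
        )

    def forward(self, z):
        return self.net(z.view(z.size(0), -1, 1, 1))

class SimpleCNNDiscriminator(nn.Module):
    """Hypothetical SimpleCNN discriminator: a stack of nn.Conv2d layers."""
    def __init__(self, base: int = 64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, base, 4, 2, 1), nn.LeakyReLU(0.2, True),          # 32 -> 16
            nn.Conv2d(base, base * 2, 4, 2, 1), nn.LeakyReLU(0.2, True),   # 16 -> 8
            nn.Conv2d(base * 2, base * 4, 4, 2, 1), nn.LeakyReLU(0.2, True),  # 8 -> 4
            nn.Conv2d(base * 4, 1, 4, 1, 0),                               # 4 -> 1 logit
        )

    def forward(self, x):
        return self.net(x).view(-1)

G, D = SimpleCNNGenerator(), SimpleCNNDiscriminator()
fake = G(torch.randn(2, 100))   # fake.shape: (2, 3, 32, 32)
scores = D(fake)                # scores.shape: (2,)
```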
Quantitative results:
Model | FID ↓ | Inception Score ↑ |
---|---|---|
DCGAN | 24.7311 | 7.0339 ± 0.0861 |
DCGAN + R1 reg | 24.1535 | 7.0188 ± 0.1089 |
WGAN | 49.9169 | 5.6852 ± 0.0649 |
WGAN-GP | 28.7963 | 6.7241 ± 0.0784 |
SNGAN (vanilla loss) | 24.9151 | 6.8838 ± 0.0667 |
SNGAN (hinge loss) | 28.5197 | 6.7429 ± 0.0818 |
LSGAN | 28.4850 | 6.7465 ± 0.0911 |
- The FID is calculated between 50k generated samples and the CIFAR-10 training split (50k images).
- The Inception Score is calculated on 50k generated samples.
Visualization:
DCGAN | DCGAN + R1 reg | WGAN | WGAN-GP |
---|---|---|---|
SNGAN (vanilla loss) | SNGAN (hinge loss) | LSGAN | |
Notes:
Model | G. Arch. | D. Arch. | G. cond. | D. cond. | Loss | Configs & Args |
---|---|---|---|---|---|---|
CGAN | SimpleCNN | SimpleCNN | concat | concat | Vanilla | config file |
CGAN (cBN) | SimpleCNN | SimpleCNN | cBN | concat | Vanilla | config file |
ACGAN | SimpleCNN | SimpleCNN | cBN | AC | Vanilla | config file |
- cBN stands for "conditional Batch Normalization"; SN stands for "Spectral Normalization"; AC stands for "Auxiliary Classifier"; PD stands for "Projection Discriminator".
Quantitative results:
Model | FID ↓ | Intra FID ↓ | Inception Score ↑ |
---|---|---|---|
CGAN | 25.4999 | 47.7334 | 7.5597 ± 0.0909 |
CGAN (cBN) | 25.3466 | 47.4136 | 7.7541 ± 0.0944 |
ACGAN | 19.9154 | 49.9892 | 7.9903 ± 0.1038 |

Per-class intra FID:

Class | CGAN | CGAN (cBN) | ACGAN |
---|---|---|---|
0 | 53.4163 | 51.5959 | 47.3203 |
1 | 44.3311 | 46.6855 | 38.6481 |
2 | 53.1971 | 49.9857 | 62.5885 |
3 | 52.2223 | 53.6737 | 66.2386 |
4 | 36.9577 | 35.1658 | 64.5535 |
5 | 65.0020 | 65.7719 | 60.7876 |
6 | 37.9598 | 38.0958 | 58.9524 |
7 | 48.3610 | 44.7279 | 36.8940 |
8 | 41.8075 | 43.3078 | 28.5964 |
9 | 44.0796 | 45.1265 | 35.3120 |
- The FID is calculated between 50k generated samples (5k for each class) and the CIFAR-10 training split (50k images).
- The intra FID is calculated, within each class, between 5k generated samples and that class's images in the CIFAR-10 training split.
- The Inception Score is calculated on 50k generated samples.
Visualizations:
CGAN | CGAN (cBN) | ACGAN |
---|---|---|
InfoGAN
- Left: varying the discrete latent variable, which corresponds to the digit type.
- Right: varying one of the continuous latent variables from -1 to 1. However, the decomposition is not clear.
- Note: I found that batch normalization layers play an important role in InfoGAN. Without BN layers, the discrete latent variable tends to have a clear meaning as shown above, while the continuous variables have little effect. In contrast, with BN layers, it's harder for the discrete variable to capture the digit type and easier for the continuous ones to find rotation in digits.
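For reference, InfoGAN's extra objective is a mutual-information term: a Q head predicts the latent codes back from generated images. A minimal sketch of that loss (function and variable names are illustrative, not this repo's API; the continuous term assumes a fixed-variance Gaussian, which reduces to MSE up to a constant):

```python
import torch
import torch.nn.functional as F

def info_loss(q_disc_logits, disc_code, q_cont_mean, cont_code):
    # Categorical code: cross-entropy between Q's prediction and the sampled class.
    disc_loss = F.cross_entropy(q_disc_logits, disc_code)
    # Continuous code: Gaussian log-likelihood with fixed unit variance -> MSE.
    cont_loss = F.mse_loss(q_cont_mean, cont_code)
    return disc_loss + cont_loss

q_disc = torch.randn(4, 10)            # Q's logits over 10 discrete categories
c_disc = torch.randint(0, 10, (4,))    # sampled discrete code (digit type)
q_cont = torch.randn(4, 2)             # Q's predicted mean of 2 continuous codes
c_cont = torch.empty(4, 2).uniform_(-1, 1)  # sampled continuous codes
loss = info_loss(q_disc, c_disc, q_cont, c_cont)
```

Both the generator and the Q head are trained to minimize this term, encouraging the codes to stay recoverable from the output.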
EigenGAN
Random samples (no truncation):
Traverse:
Mode collapse is a notorious problem in GANs, where the model can only generate a few modes of the real data. Various methods have been proposed to alleviate it. To study this problem, I experimented with different methods on the following two datasets:
- Ring8: eight Gaussian distributions lying on a ring.
- MNIST: handwritten digit dataset.
For simplicity, the model architecture in all experiments is SimpleMLP, namely a stack of `nn.Linear` layers, so the quality of the generated MNIST images may not be very good. However, this section aims to demonstrate the mode collapse problem rather than to achieve the best image quality.
GAN
200 steps | 400 steps | 600 steps | 800 steps | 1000 steps |
---|---|---|---|---|
1000 steps | 2000 steps | 3000 steps | 4000 steps | 5000 steps |
---|---|---|---|---|
On the Ring8 dataset, it can be clearly seen that all the generated data gather at only one of the 8 modes.
In the MNIST case, the generated images eventually collapse to the digit 1.
GAN + R1 regularization
200 steps | 400 steps | 600 steps | 800 steps | 5000 steps |
---|---|---|---|---|
1000 steps | 3000 steps | 5000 steps | 7000 steps | 9000 steps |
---|---|---|---|---|
R1 regularization, a technique to stabilize the training process of GANs, can prevent mode collapse in vanilla GAN as well.
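The R1 penalty is the squared gradient norm of the discriminator's output with respect to real samples only. A minimal sketch (the stand-in discriminator and names here are illustrative; the repo's implementation may differ):

```python
import torch

def r1_penalty(d_real_logits: torch.Tensor, x_real: torch.Tensor) -> torch.Tensor:
    """E[ ||grad_x D(x)||^2 ] over real samples; requires x_real.requires_grad."""
    grad, = torch.autograd.grad(
        outputs=d_real_logits.sum(), inputs=x_real, create_graph=True,
    )
    return grad.flatten(1).pow(2).sum(dim=1).mean()

x_real = torch.randn(4, 3, 32, 32, requires_grad=True)
d_real_logits = (x_real ** 2).flatten(1).sum(dim=1)  # stand-in for D(x_real)
penalty = r1_penalty(d_real_logits, x_real)
# The D loss would then add 0.5 * lambda_r1 * penalty to the adversarial term.
```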
WGAN
200 steps | 400 steps | 600 steps | 800 steps | 5000 steps |
---|---|---|---|---|
1000 steps | 3000 steps | 5000 steps | 7000 steps | 9000 steps |
---|---|---|---|---|
WGAN indeed resolves the mode collapse problem, but converges much more slowly due to weight clipping.
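WGAN enforces the Lipschitz constraint by hard-clipping every discriminator weight after each update. A sketch of that step (the clip value 0.01 follows the original paper; this repo's config may use a different value):

```python
import torch
import torch.nn as nn

# Stand-in critic; in the experiments above this would be the SimpleMLP critic.
D = nn.Sequential(nn.Linear(2, 64), nn.ReLU(), nn.Linear(64, 1))

clip_value = 0.01
with torch.no_grad():
    # After each critic optimizer step, clamp all weights into [-c, c].
    for p in D.parameters():
        p.clamp_(-clip_value, clip_value)
```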
WGAN-GP
200 steps | 400 steps | 600 steps | 800 steps | 5000 steps |
---|---|---|---|---|
1000 steps | 3000 steps | 5000 steps | 7000 steps | 9000 steps |
---|---|---|---|---|
WGAN-GP improves on WGAN by replacing the hard weight clipping with a soft gradient penalty.
The pathological weight distribution in WGAN's discriminator does not appear in WGAN-GP, as shown below.
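The gradient penalty pushes the critic's gradient norm toward 1 on random interpolates between real and fake samples. A minimal sketch (names are illustrative; the weight 10 on this term follows the original WGAN-GP paper):

```python
import torch

def gradient_penalty(critic, x_real: torch.Tensor, x_fake: torch.Tensor) -> torch.Tensor:
    """E[ (||grad_x critic(x_hat)|| - 1)^2 ] on interpolates x_hat."""
    eps = torch.rand(x_real.size(0), 1, 1, 1)        # one mixing weight per sample
    x_hat = (eps * x_real + (1 - eps) * x_fake).requires_grad_(True)
    scores = critic(x_hat)
    grad, = torch.autograd.grad(scores.sum(), x_hat, create_graph=True)
    grad_norm = grad.flatten(1).norm(2, dim=1)
    return ((grad_norm - 1) ** 2).mean()

critic = lambda x: x.flatten(1).sum(dim=1)  # stand-in critic
gp = gradient_penalty(critic, torch.randn(4, 3, 32, 32), torch.randn(4, 3, 32, 32))
# Critic loss would be: d_fake.mean() - d_real.mean() + 10.0 * gp
```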
SNGAN
200 steps | 400 steps | 600 steps | 800 steps | 5000 steps |
---|---|---|---|---|
1000 steps | 3000 steps | 5000 steps | 7000 steps | 9000 steps |
---|---|---|---|---|
Note: The above SNGAN is trained with the vanilla GAN loss instead of the hinge loss.
SNGAN uses spectral normalization to control the Lipschitz constant of the discriminator. Even with the vanilla GAN loss, SNGAN can avoid the mode collapse problem.
LSGAN
200 steps | 400 steps | 600 steps | 800 steps | 5000 steps |
---|---|---|---|---|
1000 steps | 3000 steps | 5000 steps | 7000 steps | 9000 steps |
---|---|---|---|---|
LSGAN uses a least-squares (MSE) loss instead of cross-entropy to overcome the vanishing gradients in the vanilla GAN. However, it still suffers from the mode collapse problem. For example, as shown above, LSGAN fails to cover all 8 modes on the Ring8 dataset.
Note: Contrary to the claim in the paper, I found that LSGAN w/o batch normalization does not converge on MNIST.
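The LSGAN objectives can be sketched as follows, using the common 0/1 target coding (real target 1, fake target 0, generator target 1); names are illustrative:

```python
import torch
import torch.nn.functional as F

def d_loss_lsgan(d_real: torch.Tensor, d_fake: torch.Tensor) -> torch.Tensor:
    """Discriminator: push real scores to 1 and fake scores to 0 in L2."""
    return 0.5 * (F.mse_loss(d_real, torch.ones_like(d_real))
                  + F.mse_loss(d_fake, torch.zeros_like(d_fake)))

def g_loss_lsgan(d_fake: torch.Tensor) -> torch.Tensor:
    """Generator: push fake scores to 1 in L2."""
    return 0.5 * F.mse_loss(d_fake, torch.ones_like(d_fake))

d_real, d_fake = torch.randn(8), torch.randn(8)  # stand-in discriminator outputs
dl, gl = d_loss_lsgan(d_real, d_fake), g_loss_lsgan(d_fake)
```

Unlike the sigmoid cross-entropy loss, the quadratic penalty also punishes samples that are classified correctly but lie far from the decision boundary, which is where the improved gradients come from.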
VEEGAN
200 steps | 400 steps | 600 steps | 800 steps | 5000 steps |
---|---|---|---|---|
1000 steps | 3000 steps | 5000 steps | 7000 steps | 10000 steps |
---|---|---|---|---|
VEEGAN uses an extra network to reconstruct the latent codes from the generated data.
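The core of that idea is a reconstructor network F mapping data back to latent space, with a reconstruction loss on F(G(z)) that discourages many different z's from collapsing onto the same output. A minimal sketch (network shapes and names are illustrative; VEEGAN's full objective also feeds (x, F(x)) pairs to the discriminator, which is omitted here):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

z_dim, x_dim = 8, 2  # e.g., 2-D data as in the Ring8 experiment

# Generator G: latent -> data; reconstructor F_net: data -> latent.
G = nn.Sequential(nn.Linear(z_dim, 64), nn.ReLU(), nn.Linear(64, x_dim))
F_net = nn.Sequential(nn.Linear(x_dim, 64), nn.ReLU(), nn.Linear(64, z_dim))

z = torch.randn(16, z_dim)
x_fake = G(z)
# If G collapses many z's to one point, F_net cannot recover z and this loss stays high.
recon_loss = F.mse_loss(F_net(x_fake), z)
```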
The checkpoints and training logs are stored in xyfJASON/GANs-Implementations on Hugging Face.
For GAN, WGAN-GP, SNGAN, LSGAN:
```shell
accelerate-launch scripts/train.py -c ./configs/xxx.yaml
```
For WGAN (weight clipping), InfoGAN, VEEGAN, CGAN, ACGAN and EigenGAN, use the script with the corresponding name instead:
```shell
accelerate-launch scripts/train_xxxgan.py -c ./configs/xxx.yaml
```
Unconditional GANs:
```shell
accelerate-launch scripts/sample.py \
    -c ./configs/xxx.yaml \
    --weights /path/to/saved/ckpt/model.pt \
    --n_samples N_SAMPLES \
    --save_dir SAVE_DIR
```
Conditional GANs:
```shell
accelerate-launch scripts/sample_cond.py \
    -c ./configs/xxx.yaml \
    --weights /path/to/saved/ckpt/model.pt \
    --n_classes N_CLASSES \
    --n_samples_per_class N_SAMPLES_PER_CLASS \
    --save_dir SAVE_DIR
```
EigenGAN:
```shell
accelerate-launch scripts/sample_eigengan.py \
    -c ./configs/xxx.yaml \
    --weights /path/to/saved/ckpt/model.pt \
    --n_samples N_SAMPLES \
    --save_dir SAVE_DIR \
    --mode MODE
```
Sample images following the instructions above, then use tools like torch-fidelity to calculate the FID / Inception Score.
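For example, a typical torch-fidelity invocation against the CIFAR-10 training split might look like the following (SAVE_DIR is the sample directory from the step above; check the torch-fidelity docs for the exact flags supported by your version):

```shell
# Compute FID and Inception Score between generated samples and CIFAR-10 train.
fidelity --gpu 0 --fid --isc --input1 SAVE_DIR --input2 cifar10-train
```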