CSE 455 Computer Vision Final Project

Generative Adversarial Models

Demo

Vanilla GAN on MNIST and Deep Convolutional GAN on MNIST: generated sample images
Vanilla GAN on CIFAR and Deep Convolutional GAN on CIFAR: generated sample images

Model

Our vanilla GANs are plain multi-layer perceptrons (linear transformations followed by ReLU nonlinearities). The CIFAR vanilla model has more layers than the MNIST one, and the generator has more layers than the discriminator. For the DCGAN, the generator and discriminator on both MNIST and CIFAR use deep convolutional architectures. The parameters of both MNIST models are taken from InfoGAN (Chen et al.); the CIFAR models are similar to MNIST's, only deeper.
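
For a concrete picture of the vanilla setup, below is a minimal sketch of what such an MLP generator and discriminator might look like in PyTorch. The latent dimension and layer widths are illustrative assumptions, not the repository's exact configuration (see GAN.MNIST.GAN for the real definitions).

import torch.nn as nn

NOISE_DIM = 96      # assumed latent dimension (illustrative only)
IMG_DIM = 28 * 28   # flattened MNIST image

def make_generator():
    # noise vector -> hidden ReLU layers -> flattened image in [-1, 1]
    return nn.Sequential(
        nn.Linear(NOISE_DIM, 1024), nn.ReLU(),
        nn.Linear(1024, 1024), nn.ReLU(),
        nn.Linear(1024, IMG_DIM), nn.Tanh(),
    )

def make_discriminator():
    # flattened image -> hidden ReLU layers -> single real/fake logit
    return nn.Sequential(
        nn.Linear(IMG_DIM, 256), nn.ReLU(),
        nn.Linear(256, 256), nn.ReLU(),
        nn.Linear(256, 1),
    )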

Usage

To reproduce our results, first run $ jupyter notebook and copy over the following setup code:

Note: if you are running on a CPU, change the dtype to torch.FloatTensor (uncomment the last line). However, be aware that our models include deep convolutional networks that run extremely slowly on a CPU.

import torch

import GAN
import GAN.MNIST.GAN, GAN.MNIST.DCGAN
import GAN.CIFAR.GAN, GAN.CIFAR.DCGAN
from GAN.utils import *
from dataloader import *

import matplotlib.pyplot as plt
%matplotlib inline
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'

mnist_loader, cifar_loader = get_data()  # data loaders for MNIST and CIFAR
dtype = torch.cuda.FloatTensor           # run on GPU
# dtype = torch.FloatTensor              # uncomment to run on CPU instead

To train a model and see the generated images for yourself, run the following code, which trains a vanilla GAN (Goodfellow et al.) on the MNIST dataset:

D = GAN.MNIST.GAN.get_discriminator().type(dtype)  # discriminator network
G = GAN.MNIST.GAN.get_generator().type(dtype)      # generator network

D_optim = GAN.utils.get_optimizer(D)               # one optimizer per network
G_optim = GAN.utils.get_optimizer(G)

GAN.MNIST.train(D, G, D_optim, G_optim, discriminator_loss,
                generator_loss, dtype, mnist_loader)
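
The discriminator_loss and generator_loss used above presumably come from the from GAN.utils import * line of the setup code. As a rough sketch (an assumption about their behavior, not the repository's exact code), they correspond to the standard binary cross-entropy GAN objectives of Goodfellow et al.:

import torch
import torch.nn.functional as F

def discriminator_loss(logits_real, logits_fake):
    # maximize log D(x) + log(1 - D(G(z))) via binary cross-entropy on logits
    ones = torch.ones_like(logits_real)
    zeros = torch.zeros_like(logits_fake)
    return (F.binary_cross_entropy_with_logits(logits_real, ones) +
            F.binary_cross_entropy_with_logits(logits_fake, zeros))

def generator_loss(logits_fake):
    # non-saturating generator objective: maximize log D(G(z))
    ones = torch.ones_like(logits_fake)
    return F.binary_cross_entropy_with_logits(logits_fake, ones)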

To switch to the CIFAR dataset, simply change every MNIST to CIFAR. To switch to the Deep Convolutional GAN (DCGAN, Radford et al.), use GAN.MNIST.DCGAN in the first two lines instead (this works for CIFAR as well). An example is sketched below.
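
For example, assuming the CIFAR and DCGAN modules expose the same interface as the MNIST code above, training a DCGAN on CIFAR would look like this:

D = GAN.CIFAR.DCGAN.get_discriminator().type(dtype)
G = GAN.CIFAR.DCGAN.get_generator().type(dtype)

D_optim = GAN.utils.get_optimizer(D)
G_optim = GAN.utils.get_optimizer(G)

GAN.CIFAR.train(D, G, D_optim, G_optim, discriminator_loss,
                generator_loss, dtype, cifar_loader)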

References

  1. X. Chen, Y. Duan, R. Houthooft, J. Schulman, I. Sutskever, and P. Abbeel, "InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets", in NIPS, 2016.
  2. I. Goodfellow, J. Pouget-Abadie, M. Mirza, B. Xu, D. Warde-Farley, S. Ozair, A. Courville, and Y. Bengio, "Generative Adversarial Nets", in NIPS, 2014, pp. 2672-2680.
  3. A. Radford, L. Metz, and S. Chintala, "Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks", in ICLR, 2016.
