A conditional generative adversarial network (CGAN) is a type of GAN that also takes advantage of labels during the training process.
- Generator: Given a label and an array of uniform random noise as input, this network learns a mapping from the prior noise (conditioned on the label) to the data space.
- Discriminator: Given batches of labeled data containing observations from both the training data and the generator's output, this network returns a single scalar representing the probability that x came from the training data rather than from the generator distribution.
The goal of the generator is to fool the discriminator, so the generative network is trained to maximise the final classification error (between true and generated data).
The goal of the discriminator is to detect fake generated data, so the discriminative network is trained to minimise the final classification error.
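For reference, this adversarial game is the conditional GAN minimax objective from Mirza & Osindero (2014), where both networks are conditioned on the label $y$:

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}\left[\log D(x \mid y)\right] + \mathbb{E}_{z \sim p_z(z)}\left[\log\left(1 - D(G(z \mid y))\right)\right]$$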
The MNIST database of handwritten digits has a training set of 60,000 examples and a test set of 10,000 examples. I used torchvision's datasets module to download the dataset:
```python
from torchvision import datasets, transforms

transform = transforms.ToTensor()  # convert PIL images to tensors (the original transform may also normalise)
train_dataset = datasets.MNIST('mnist/', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('mnist/', train=False, download=True, transform=transform)
```
Below is a model architecture diagram for a Conditional DCGAN. Note that the high-level architecture is essentially the same as in the cGAN, except the Generator and Discriminator contain additional layers, such as Convolutions and Transposed Convolutions.
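As a rough sketch of what such a Conditional DCGAN pair could look like in PyTorch — the layer sizes and the label-conditioning scheme (embedding the label and concatenating it with the noise vector or as an extra image channel) are illustrative assumptions, not necessarily this project's exact architecture:

```python
import torch
import torch.nn as nn

class Generator(nn.Module):
    # Illustrative conditional DCGAN generator: maps (noise, label) -> 28x28 image.
    def __init__(self, z_dim=100, n_classes=10):
        super().__init__()
        self.label_emb = nn.Embedding(n_classes, n_classes)
        self.net = nn.Sequential(
            # project the concatenated (noise + label) vector to a 7x7 feature map
            nn.ConvTranspose2d(z_dim + n_classes, 128, 7, 1, 0), nn.BatchNorm2d(128), nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1), nn.BatchNorm2d(64), nn.ReLU(True),  # 7x7 -> 14x14
            nn.ConvTranspose2d(64, 1, 4, 2, 1), nn.Tanh(),                            # 14x14 -> 28x28
        )

    def forward(self, z, labels):
        x = torch.cat([z, self.label_emb(labels)], dim=1)  # (B, z_dim + n_classes)
        return self.net(x.unsqueeze(-1).unsqueeze(-1))     # reshape to (B, C, 1, 1)

class Discriminator(nn.Module):
    # Illustrative conditional DCGAN discriminator: maps (image, label) -> real/fake probability.
    def __init__(self, n_classes=10):
        super().__init__()
        self.label_emb = nn.Embedding(n_classes, 28 * 28)  # label becomes an extra image channel
        self.net = nn.Sequential(
            nn.Conv2d(2, 64, 4, 2, 1), nn.LeakyReLU(0.2, inplace=True),                            # 28x28 -> 14x14
            nn.Conv2d(64, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2, inplace=True),     # 14x14 -> 7x7
            nn.Conv2d(128, 1, 7, 1, 0), nn.Sigmoid(),                                              # 7x7 -> 1x1 score
        )

    def forward(self, img, labels):
        label_map = self.label_emb(labels).view(-1, 1, 28, 28)
        return self.net(torch.cat([img, label_map], dim=1)).view(-1)
```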
The `Trainer` class does the main work: training the model, plotting the training process, and saving the model every n epochs. I used the Adam optimizer with a learning rate of 0.0002. Each generator training step happens in the `train_generator` function, each discriminator training step in the `train_descriminator` function, and the whole training loop in the `train` function.
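A minimal sketch of what those two training steps typically look like for a cGAN, assuming binary cross-entropy loss and conditional models like the sketch above (the actual Trainer implementation may differ):

```python
import torch
import torch.nn as nn

criterion = nn.BCELoss()

def train_discriminator(D, G, d_optimizer, real_imgs, labels, z_dim=100, device='cpu'):
    # Discriminator step: push real probability towards 1, fake towards 0.
    d_optimizer.zero_grad()
    batch = real_imgs.size(0)
    real_loss = criterion(D(real_imgs, labels), torch.ones(batch, device=device))
    z = torch.randn(batch, z_dim, device=device)
    fake_imgs = G(z, labels).detach()  # detach so the generator is not updated in this step
    fake_loss = criterion(D(fake_imgs, labels), torch.zeros(batch, device=device))
    d_loss = real_loss + fake_loss
    d_loss.backward()
    d_optimizer.step()
    return d_loss.item()

def train_generator(D, G, g_optimizer, batch, n_classes=10, z_dim=100, device='cpu'):
    # Generator step: push the discriminator to classify generated samples as real.
    g_optimizer.zero_grad()
    z = torch.randn(batch, z_dim, device=device)
    labels = torch.randint(0, n_classes, (batch,), device=device)
    g_loss = criterion(D(G(z, labels), labels), torch.ones(batch, device=device))
    g_loss.backward()
    g_optimizer.step()
    return g_loss.item()
```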
- You can set the number of epochs with `EPOCHS` and the batch size with `BATCH_SIZE`.
- Set the `device` you want to train the model on (defaults to CUDA if it's available).
- You can set the `verbose` level that controls what is printed => 0 == nothing || 1 == model architecture || 2 == print optimizer || 3 == model parameters size.
- Each time you train the model, the weights and the plot (if `save_plots` == True) are saved in `save_dir`.
- You can find a `configs` file in `save_dir` that contains some information about the run.
- You can choose the optimizer: `OPTIMIZER`.
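Putting those settings together, a run configuration might look like the following; the names match the list above, while the concrete values and the `save_dir` path are only illustrative:

```python
import torch

EPOCHS = 50                    # number of training epochs
BATCH_SIZE = 64                # samples per batch
device = 'cuda' if torch.cuda.is_available() else 'cpu'
OPTIMIZER = 'adam'             # chosen optimizer (lr 0.0002 per the text above)
verbose = 1                    # 0: nothing, 1: model architecture, 2: optimizer, 3: parameter sizes
save_plots = True              # save training plots alongside the weights
save_dir = 'runs/cgan_mnist'   # weights, plots, and the configs file land here (illustrative path)
```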