
Commit

Merge pull request #66 from wazeerzulfikar/add/dcgan
Added dcgan architecture with tests
satra authored Oct 13, 2021
2 parents 2630133 + d9e17cf commit 3f54abe
Showing 4 changed files with 230 additions and 0 deletions.
2 changes: 2 additions & 0 deletions nobrainer/models/__init__.py
@@ -1,4 +1,5 @@
from .autoencoder import autoencoder
from .dcgan import dcgan
from .highresnet import highresnet
from .meshnet import meshnet
from .progressivegan import progressivegan
@@ -25,6 +26,7 @@ def get(name):
"unet": unet,
"autoencoder": autoencoder,
"progressivegan": progressivegan,
"dcgan": dcgan,
}

try:
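
With this entry added to the mapping, the architecture can also be looked up by name. A minimal sketch, assuming get returns the registered constructor callable (only the beginning of get is visible in this diff, so treat the return behavior as an assumption):

from nobrainer.models import get

dcgan_builder = get("dcgan")  # assumed to return the `dcgan` function registered above
generator, discriminator = dcgan_builder((32, 32, 32, 1), z_dim=64)
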
96 changes: 96 additions & 0 deletions nobrainer/models/dcgan.py
@@ -0,0 +1,96 @@
"""Model definition for DCGAN.
"""
import math

from tensorflow.keras import layers, models


def dcgan(
output_shape,
z_dim=256,
n_base_filters=16,
batchnorm=True,
batch_size=None,
name="dcgan",
):
"""Instantiate DCGAN Architecture.
Parameters
----------
output_shape: list or tuple of three or four ints, the shape of the output images,
omitting the batch dimension and including the number of channels. Generated
images are scaled to [0, 1]. Currently, only squares and cubes are supported.
z_dim: int, dimension of the latent code. This translates to a latent vector of
shape (z_dim,).
n_base_filters: int, number of filters in the model's first convolutional layer.
Subsequent layers use multiples of n_base_filters.
batchnorm: bool, whether to use batch normalization in the network.
batch_size: int, number of samples in each batch. This must be set when
training on TPUs.
name: str, name to give to the resulting model object.
Returns
-------
Generator Model object.
Discriminator Model object.
"""

conv_kwds = {"kernel_size": 4, "activation": None, "padding": "same", "strides": 2}

conv_transpose_kwds = {
"kernel_size": 4,
"strides": 2,
"activation": None,
"padding": "same",
}

dimensions = output_shape[:-1]
n_dims = len(dimensions)

if not (n_dims in [2, 3] and dimensions[1:] == dimensions[:-1]):
raise ValueError("Dimensions should be of square or cube!")

Conv = getattr(layers, "Conv{}D".format(n_dims))
ConvTranspose = getattr(layers, "Conv{}DTranspose".format(n_dims))
n_layers = int(math.log(dimensions[0], 2))

# Generator
z_input = layers.Input(shape=(z_dim,), batch_size=batch_size)

project = layers.Dense(pow(4, n_dims) * z_dim)(z_input)
project = layers.ReLU()(project)
project = layers.Reshape((4,) * n_dims + (z_dim,))(project)
x = project

for i in range(n_layers - 2)[::-1]:
n_filters = min(n_base_filters * (2 ** (i)), z_dim)

x = ConvTranspose(n_filters, **conv_transpose_kwds)(x)
if batchnorm:
x = layers.BatchNormalization()(x)
x = layers.LeakyReLU()(x)

outputs = Conv(1, 3, activation="sigmoid", padding="same")(x)

generator = models.Model(
inputs=[z_input], outputs=[outputs], name=name + "_generator"
)

# PatchGAN Discriminator with output of 8x8(x8)
inputs = layers.Input(shape=(output_shape), batch_size=batch_size)
x = inputs
for i in range(n_layers - 3):
n_filters = min(n_base_filters * (2 ** (i)), z_dim)

x = Conv(n_filters, **conv_kwds)(x)
if batchnorm:
x = layers.BatchNormalization()(x)
x = layers.ReLU()(x)

pred = Conv(1, 3, padding="same", activation="sigmoid")(x)

discriminator = models.Model(
inputs=[inputs], outputs=[pred], name=name + "_discriminator"
)

return generator, discriminator
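
A minimal usage sketch of the function above; the shapes mirror the test below, and the z_dim value here is arbitrary:

import numpy as np

from nobrainer.models.dcgan import dcgan

# Build a 3D generator/discriminator pair for 32**3 single-channel volumes.
generator, discriminator = dcgan(output_shape=(32, 32, 32, 1), z_dim=64)

# Sample latent codes and run them through both models.
z = np.random.normal(size=(2, 64)).astype("float32")
fake_volumes = generator.predict(z)                 # (2, 32, 32, 32, 1), values in [0, 1]
patch_scores = discriminator.predict(fake_volumes)  # (2, 8, 8, 8, 1) PatchGAN output
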
20 changes: 20 additions & 0 deletions nobrainer/models/tests/models_test.py
@@ -7,6 +7,7 @@
from ..autoencoder import autoencoder
from ..bayesian_vnet import bayesian_vnet
from ..bayesian_vnet_semi import bayesian_vnet_semi
from ..dcgan import dcgan
from ..highresnet import highresnet
from ..meshnet import meshnet
from ..progressivegan import progressivegan
@@ -114,6 +115,25 @@ def test_progressivegan():
assert fake_labels_pred.shape == (real_image_input.shape[0], label_size)


def test_dcgan():
"""Special test for dcgan."""

output_shape = (1, 32, 32, 32, 1)
z_dim = 32
z = np.random.random((1, z_dim))

pred_shape = (1, 8, 8, 8, 1)

generator, discriminator = dcgan(output_shape[1:], z_dim=z_dim)
generator.compile(tf.optimizers.Adam(), "mse")
discriminator.compile(tf.optimizers.Adam(), "mse")

fake_images = generator.predict(z)
fake_pred = discriminator.predict(fake_images)

assert fake_images.shape == output_shape and fake_pred.shape == pred_shape


def test_vnet():
model_test(vnet, n_classes=1, input_shape=(1, 32, 32, 32, 1))

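The shapes asserted in test_dcgan follow directly from the architecture; a small sketch of the arithmetic:

import math

input_side = 32
n_layers = int(math.log(input_side, 2))  # 5, as computed in dcgan()
# Generator: a 4**3 projection upsampled by n_layers - 2 = 3 transposed convolutions: 4 -> 8 -> 16 -> 32.
# Discriminator: n_layers - 3 = 2 strided convolutions: 32 -> 16 -> 8.
patch_side = input_side // 2 ** (n_layers - 3)
print(patch_side)  # 8, hence pred_shape == (1, 8, 8, 8, 1)
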
112 changes: 112 additions & 0 deletions nobrainer/training.py
@@ -150,3 +150,115 @@ def save_weights(self, filepath, **kwargs):
self.generator.save(
os.path.join(filepath, "generator_res_{}".format(self.resolution))
)


class GANTrainer(tf.keras.Model):
"""Generative Adversarial Network Trainer.
Trains the discriminator and generator alternately in an adversarial manner to generate
brain MRI images.
Parameters
----------
discriminator : tf.keras.Model, instantiated using nobrainer.models.
generator : tf.keras.Model, instantiated using nobrainer.models.
gradient_penalty : boolean, whether to apply a gradient penalty to the discriminator for smoother training.
References
----------
Links
-----
"""

def __init__(self, discriminator, generator, gradient_penalty=False):
super(GANTrainer, self).__init__()
self.discriminator = discriminator
self.generator = generator
self.gradient_penalty = gradient_penalty
self.latent_size = generator.latent_size

def compile(self, d_optimizer, g_optimizer, g_loss_fn, d_loss_fn):
super(GANTrainer, self).compile()
self.d_optimizer = d_optimizer
self.g_optimizer = g_optimizer

self.g_loss_fn = compile_utils.LossesContainer(g_loss_fn)
self.d_loss_fn = compile_utils.LossesContainer(d_loss_fn)

if self.gradient_penalty:
self.gradient_penalty_fn = compile_utils.LossesContainer(gradient_penalty)

def train_step(self, reals):
if isinstance(reals, tuple):
reals = reals[0]

# get batch size dynamically
batch_size = tf.shape(reals)[0]

# normalize the real images using minmax to [-1, 1]
reals = _adjust_dynamic_range(reals, [0.0, 255.0], [-1.0, 1.0])

# train discriminator
latents = tf.random.normal((batch_size, self.latent_size))
fake_labels = tf.ones((batch_size, 1)) * -1
real_labels = tf.ones((batch_size, 1))

with tf.GradientTape() as tape:
fakes = self.generator(latents)
fakes_pred, labels_pred_fake = self.discriminator(fakes)
reals_pred, labels_pred_real = self.discriminator(reals)

fake_loss = self.d_loss_fn(fake_labels, fakes_pred)
real_loss = self.d_loss_fn(real_labels, reals_pred)
d_loss = 0.5 * (fake_loss + real_loss)

# calculate and add the gradient penalty loss using average samples for discriminator
if self.gradient_penalty:
weight_shape = (tf.shape(reals)[0],) + (
1,
1,
1,
1,
) # broadcasting to right shape
weight = tf.random.uniform(weight_shape, minval=0, maxval=1)
average_samples = (weight * reals) + ((1 - weight) * fakes)
average_pred = self.discriminator(average_samples)
gradients = tf.gradients(average_pred, average_samples)[0]
gp_loss = self.gradient_penalty_fn(gradients, reals_pred)
d_loss += gp_loss

d_gradients = tape.gradient(d_loss, self.discriminator.trainable_variables)
self.d_optimizer.apply_gradients(
zip(d_gradients, self.discriminator.trainable_variables)
)

# train generator
misleading_labels = tf.ones((batch_size, 1))

latents = tf.random.normal((batch_size, self.latent_size))
with tf.GradientTape() as tape:
fakes = self.generator(latents)
fakes_pred, labels_pred = self.discriminator(fakes)

g_loss = self.g_loss_fn(misleading_labels, fakes_pred)

g_gradients = tape.gradient(g_loss, self.generator.trainable_variables)
self.g_optimizer.apply_gradients(
zip(g_gradients, self.generator.trainable_variables)
)

return {"d_loss": d_loss, "g_loss": g_loss}

def save_weights(self, filepath, **kwargs):
"""
Override base class function to save the weights of the constituent models
"""
self.generator.save_weights(
os.path.join(filepath, "g_weights_res_{}.h5".format(self.resolution)),
**kwargs
)
self.discriminator.save_weights(
os.path.join(filepath, "d_weights_res_{}.h5".format(self.resolution)),
**kwargs
)
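
A minimal end-to-end sketch of the trainer above. The toy generator and discriminator are placeholders that only satisfy the interface train_step relies on (a latent_size attribute on the generator and a discriminator returning two outputs); the Wasserstein-style loss is an assumption chosen to match the ±1 labels used in train_step:

import tensorflow as tf

from nobrainer.training import GANTrainer

latent_size = 16

# Toy generator: maps a latent vector to an 8**3 single-channel volume.
z = tf.keras.Input(shape=(latent_size,))
g = tf.keras.layers.Dense(8 * 8 * 8, activation="tanh")(z)
g = tf.keras.layers.Reshape((8, 8, 8, 1))(g)
generator = tf.keras.Model(z, g)
generator.latent_size = latent_size  # train_step reads this attribute

# Toy discriminator: returns a realness score and an auxiliary label prediction,
# matching the two-output unpacking in train_step.
vol = tf.keras.Input(shape=(8, 8, 8, 1))
h = tf.keras.layers.Flatten()(vol)
discriminator = tf.keras.Model(
    vol, [tf.keras.layers.Dense(1)(h), tf.keras.layers.Dense(1)(h)]
)

# Wasserstein-style loss consistent with the -1/+1 labels built in train_step.
def wasserstein_loss(y_true, y_pred):
    return -tf.reduce_mean(y_true * y_pred)

trainer = GANTrainer(discriminator=discriminator, generator=generator)
trainer.compile(
    d_optimizer=tf.keras.optimizers.Adam(1e-4),
    g_optimizer=tf.keras.optimizers.Adam(1e-4),
    g_loss_fn=wasserstein_loss,
    d_loss_fn=wasserstein_loss,
)

# Real volumes are expected in [0, 255]; train_step rescales them to [-1, 1].
reals = tf.random.uniform((4, 8, 8, 8, 1), maxval=255.0)
trainer.fit(tf.data.Dataset.from_tensor_slices(reals).batch(2), epochs=1)
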
