Merge pull request #143 from jheek/fix-vae
Fix vae
jheek authored Apr 1, 2020
2 parents 8bf6e36 + c969cf3 commit 2f69b48
Showing 7 changed files with 210 additions and 175 deletions.
10 changes: 7 additions & 3 deletions examples/vae/README.md
@@ -8,10 +8,14 @@ pip install -r requirements.txt
 python main.py
 ```
 
-## Examples;
+## Examples
 
 If you run the code with the above command, you will get some generated images:
 
-![generated_mnist](./example.png)
+![generated_mnist](./sample.png)
 
-Also, you can obtain ELBO of a test set as `107.0544 ± 0.2496` (5 times of trials)
+and reconstructions of test set digits:
+
+![reconstruction_mnist](./reconstruction.png)
+
+The test set ELBO after 10 epochs should be around `106`.
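For context, the ELBO referenced here is the standard evidence lower bound, whose negation is exactly the `bce + kld` loss computed in `train.py` below:

$$\mathrm{ELBO}(x) = \mathbb{E}_{q_\phi(z \mid x)}\big[\log p_\theta(x \mid z)\big] - D_{\mathrm{KL}}\big(q_\phi(z \mid x) \,\big\|\, p(z)\big)$$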
Binary file removed examples/vae/example.png
171 changes: 0 additions & 171 deletions examples/vae/main.py

This file was deleted.

Binary file added examples/vae/reconstruction.png
2 changes: 1 addition & 1 deletion examples/vae/results/.gitignore
@@ -1 +1 @@
-*.png
+*.png

(The two lines render identically; the change is invisible, most likely a trailing newline.)
Binary file added examples/vae/sample.png
202 changes: 202 additions & 0 deletions examples/vae/train.py
@@ -0,0 +1,202 @@
# Copyright 2020 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from absl import app
from absl import flags

import numpy as np
import jax.numpy as jnp

import jax
from jax import random

from flax import nn
from flax import optim

import tensorflow as tf
import tensorflow_datasets as tfds

from utils import save_image


FLAGS = flags.FLAGS

flags.DEFINE_float(
    'learning_rate', default=1e-3,
    help=('The learning rate for the Adam optimizer.')
)

flags.DEFINE_integer(
    'batch_size', default=128,
    help=('Batch size for training.')
)

flags.DEFINE_integer(
    'num_epochs', default=30,
    help=('Number of training epochs.')
)

flags.DEFINE_integer(
    'latents', default=20,
    help=('Number of latent variables.')
)
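# Hypothetical invocation using the defaults above, e.g.:
#   python train.py --learning_rate=1e-3 --batch_size=128 --num_epochs=30 --latents=20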


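# Encoder: one 500-unit hidden layer followed by two linear heads producing
# the mean and log-variance of the diagonal Gaussian posterior q(z|x).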
class Encoder(nn.Module):

  def apply(self, x, latents):
    x = nn.Dense(x, 500, name='fc1')
    x = nn.relu(x)
    mean_x = nn.Dense(x, latents, name='fc2_mean')
    logvar_x = nn.Dense(x, latents, name='fc2_logvar')
    return mean_x, logvar_x


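# Decoder: one 500-unit hidden layer mapping a latent code z to 784
# Bernoulli logits, one per pixel of a flattened 28x28 image.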
class Decoder(nn.Module):

  def apply(self, z):
    z = nn.Dense(z, 500, name='fc1')
    z = nn.relu(z)
    z = nn.Dense(z, 784, name='fc2')
    return z


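# Full model: encode, sample z with the reparameterization trick, decode.
# The decoder is shared between apply() and generate() via Module.shared.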
class VAE(nn.Module):

  def apply(self, x, z_rng, latents=20):
    decoder = self._create_decoder()
    mean, logvar = Encoder(x, latents, name='encoder')
    z = reparameterize(z_rng, mean, logvar)
    recon_x = decoder(z)
    return recon_x, mean, logvar

  @nn.module_method
  def generate(self, z, **unused_kwargs):
    decoder = self._create_decoder()
    return nn.sigmoid(decoder(z))

  def _create_decoder(self):
    return Decoder.shared(name='decoder')


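# Reparameterization trick: z = mean + std * eps with eps ~ N(0, I), which
# keeps the sample differentiable with respect to mean and logvar.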
def reparameterize(rng, mean, logvar):
  std = jnp.exp(0.5 * logvar)
  eps = random.normal(rng, logvar.shape)
  return mean + eps * std


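# Per-example closed-form KL(q(z|x) || N(0, I)) for a diagonal Gaussian,
# vectorized over the batch by vmap.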
@jax.vmap
def kl_divergence(mean, logvar):
  return -0.5 * jnp.sum(1 + logvar - jnp.square(mean) - jnp.exp(logvar))


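# Numerically stable Bernoulli log-likelihood: with L = log_sigmoid(logits),
# log(1 - sigmoid(logits)) equals log(-expm1(L)), avoiding overflow for
# large-magnitude logits.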
@jax.vmap
def binary_cross_entropy_with_logits(logits, labels):
  logits = nn.log_sigmoid(logits)
  return -jnp.sum(labels * logits + (1. - labels) * jnp.log(-jnp.expm1(logits)))


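# The reported loss is the negative ELBO: reconstruction BCE plus KL,
# each averaged over the batch.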
def compute_metrics(recon_x, x, mean, logvar):
  bce_loss = binary_cross_entropy_with_logits(recon_x, x).mean()
  kld_loss = kl_divergence(mean, logvar).mean()
  return {
      'bce': bce_loss,
      'kld': kld_loss,
      'loss': bce_loss + kld_loss
  }


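# One optimization step: differentiate the negative ELBO with respect to the
# parameters and apply an Adam update; jit compiles the whole step.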
@jax.jit
def train_step(optimizer, batch, z_rng):
  def loss_fn(model):
    recon_x, mean, logvar = model(batch, z_rng)

    bce_loss = binary_cross_entropy_with_logits(recon_x, batch).mean()
    kld_loss = kl_divergence(mean, logvar).mean()
    loss = bce_loss + kld_loss
    return loss, recon_x
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  _, grad = grad_fn(optimizer.target)
  optimizer = optimizer.apply_gradient(grad)
  return optimizer


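# Evaluation: metrics on the full test set, a grid comparing originals with
# their reconstructions, and samples decoded from the fixed latents z.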
@jax.jit
def eval(model, images, z, z_rng):
  recon_images, mean, logvar = model(images, z_rng)

  comparison = jnp.concatenate([images[:8].reshape(-1, 28, 28, 1),
                                recon_images[:8].reshape(-1, 28, 28, 1)])

  generate_images = model.generate(z)
  generate_images = generate_images.reshape(-1, 28, 28, 1)
  metrics = compute_metrics(recon_images, images, mean, logvar)

  return metrics, comparison, generate_images


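# Cast each binarized MNIST image to float32 and flatten 28x28x1 to 784.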
def prepare_image(x):
  x = tf.cast(x['image'], tf.float32)
  x = tf.reshape(x, (-1,))
  return x


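# Training loop: stream binarized MNIST from TFDS, run steps_per_epoch
# gradient steps per epoch, then write metrics plus reconstruction and
# sample image grids under results/.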
def main(argv):
  del argv
  rng = random.PRNGKey(0)
  rng, key = random.split(rng)

  train_ds = tfds.load('binarized_mnist', split=tfds.Split.TRAIN)
  train_ds = train_ds.map(prepare_image)
  train_ds = train_ds.cache()
  train_ds = train_ds.repeat()
  train_ds = train_ds.shuffle(50000)
  train_ds = train_ds.batch(FLAGS.batch_size)
  train_ds = tfds.as_numpy(train_ds)

  test_ds = tfds.load('binarized_mnist', split=tfds.Split.TEST)
  test_ds = test_ds.map(prepare_image).batch(10000)
  test_ds = np.array(list(test_ds)[0])
  test_ds = jax.device_put(test_ds)

  model_def = VAE.partial(latents=FLAGS.latents)
  _, params = model_def.init_by_shape(
      key, [(FLAGS.batch_size, 784)], z_rng=random.PRNGKey(0))
  vae = nn.Model(model_def, params)

  optimizer = optim.Adam(learning_rate=FLAGS.learning_rate).create(vae)
  optimizer = jax.device_put(optimizer)

  rng, z_key, eval_rng = random.split(rng, 3)
  z = random.normal(z_key, (64, FLAGS.latents))

  steps_per_epoch = 50000 // FLAGS.batch_size

  for epoch in range(FLAGS.num_epochs):
    for _ in range(steps_per_epoch):
      batch = next(train_ds)
      rng, key = random.split(rng)
      optimizer = train_step(optimizer, batch, key)

    metrics, comparison, sample = eval(optimizer.target, test_ds, z, eval_rng)
    save_image(comparison, f'results/reconstruction_{epoch}.png', nrow=8)
    save_image(sample, f'results/sample_{epoch}.png', nrow=8)

    print('eval epoch: {}, loss: {:.4f}, BCE: {:.4f}, KLD: {:.4f}'.format(
        epoch + 1, metrics['loss'], metrics['bce'], metrics['kld']
    ))


if __name__ == '__main__':
app.run(main)
