Fix vae #143

Merged: 2 commits merged from the fix-vae branch into google:master on Apr 1, 2020

Conversation

@jheek (Member) commented Mar 30, 2020

No description provided.

@AlexeyG (Contributor) left a comment

Thanks for fixing the rng bug, @jheek !

I left a bunch of comments, most of which aren't actually related to your improvements but are rather requests for further improvements to the example. I'm happy to leave them for a later PR, if you prefer.

I also noticed that the result this VAE gets doesn't actually match the paper. We should be getting ~100, but I'm not sure after how many training epochs.

  return {
      'bce': jnp.mean(bce_loss),
      'kld': jnp.mean(kld_loss),
      'loss': jnp.mean(bce_loss + kld_loss)

Contributor:
Is it not better (in terms of op count) to sum the averaged bce and kld loss? Although it doesn't really matter on such a small-scale example.

Member Author:
done
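
For context, the suggestion amounts to averaging each term once and then adding the resulting scalars, rather than averaging the elementwise sum. A minimal sketch of that variant (assuming bce_loss and kld_loss are per-example arrays):

import jax.numpy as jnp

def metrics(bce_loss, kld_loss):
  # Average each term once, then add the scalars instead of averaging the
  # elementwise sum.
  bce = jnp.mean(bce_loss)
  kld = jnp.mean(kld_loss)
  return {'bce': bce, 'kld': kld, 'loss': bce + kld}

print(metrics(jnp.ones(4), jnp.zeros(4))['loss'])  # 1.0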

    'learning_rate', default=1e-3,
    help=('The learning rate for the Adam optimizer')

Contributor:
Add a dot at the end of the help string. Here and in other flags.

Member Author:
done
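
For reference, the corrected flag definition with a trailing period would look roughly like this (a sketch using absl's flags module, which the example already relies on):

from absl import flags

flags.DEFINE_float(
    'learning_rate', default=1e-3,
    help=('The learning rate for the Adam optimizer.'))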

@@ -127,10 +124,10 @@ def loss_fn(model):

@jax.jit
-def eval(model, eval_ds, z):
+def eval(model, eval_ds, z, z_rng):
  xs = eval_ds['image'] / 255.0

Contributor:
Should this not be a part of the input pipeline?

Member Author:
also part of the input pipeline now

  for epoch in range(FLAGS.num_epochs):
    for batch in tfds.as_numpy(train_ds):
      rng, key = random.split(rng)
      batch['image'] = batch['image'].reshape(-1, 784) / 255.0

Contributor:
Why are images flattened and then unflattened again in eval? If flattening serves a purpose, consider moving the first flattening op into the IO pipeline.

Member Author:
flattening is now part of the input pipeline
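
A minimal sketch of what moving the normalization and flattening into the TFDS input pipeline could look like (prepare_image is a hypothetical helper name, not necessarily what the example uses):

import tensorflow as tf
import tensorflow_datasets as tfds

def prepare_image(sample):
  # Scale uint8 pixels to [0, 1] and flatten 28x28x1 images to 784-vectors.
  image = tf.cast(sample['image'], tf.float32) / 255.0
  return {'image': tf.reshape(image, (784,))}

train_ds = tfds.load('mnist', split='train').map(prepare_image).batch(128)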

  for epoch in range(FLAGS.num_epochs):
    for batch in tfds.as_numpy(train_ds):
      rng, key = random.split(rng)

Contributor:
What's the reason for manually handling key splitting instead of using the stochastic context? Using the latter would showcase more of Flax's features.

Member Author:
I like to use the simple and explicit Jax APIs as much as possible. nn.stochastic was introduced to avoid having to pass around rngs through a complex model. But within a single function I think manually splitting rngs is still better.
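
For readers unfamiliar with the pattern being discussed, explicit key handling means threading a PRNG key through the loop and splitting off a fresh subkey per step, roughly:

from jax import random

rng = random.PRNGKey(0)
samples = []
for step in range(3):
  # Split off a fresh subkey each step so every step sees different noise;
  # reusing one key for all steps is exactly the bug this PR fixes.
  rng, key = random.split(rng)
  samples.append(random.normal(key, (2,)))

print(samples)  # three different draws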

  for epoch in range(FLAGS.num_epochs):
    for batch in tfds.as_numpy(train_ds):
      rng, key = random.split(rng)
      batch['image'] = batch['image'].reshape(-1, 784) / 255.0

Contributor:
After dividing by 255 the data becomes "continuous". You need to dynamically binarise it, both for train and eval. Many VAEs report results on dynamically binarised MNIST and on statically binarised MNIST (some fixed dataset, which I believe isn't a part of TFDS).

Contributor:
Scrap half of the previous comment - TFDS does have the binarized mnist dataset. It's called binarized_mnist :)

Member Author:
What is your recommendation here? To eval on both?

Contributor:
Just stick to the binarized_mnist in training and eval. This should be equivalent to the "static MNIST" setup from VAE literature.

Member Author:
done
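
For context, the statically binarized dataset is available in TFDS under the name binarized_mnist, and its images are already 0/1-valued, so no /255 scaling or thresholding is needed. A sketch of the switch:

import tensorflow as tf
import tensorflow_datasets as tfds

def prepare_image(sample):
  # binarized_mnist images are {0, 1}-valued uint8; just cast and flatten.
  return tf.reshape(tf.cast(sample['image'], tf.float32), (784,))

train_ds = tfds.load('binarized_mnist', split='train').map(prepare_image)
test_ds = tfds.load('binarized_mnist', split='test').map(prepare_image)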

examples/vae/main.py (outdated review comment, resolved)
@@ -12,6 +12,10 @@ python main.py

Contributor:
GitHub doesn't let me comment on parts that are not considered different, but could you please remove the semicolon after ## Examples?

vae = nn.Model(VAE, params)

optimizer = optim.Adam(learning_rate=FLAGS.learning_rate).create(vae)

rng, z_key, eval_rng = random.split(rng, 3)
z = random.normal(z_key, (64, 20))

Contributor:
64 -> FLAGS.batch_size ?

Contributor:
Can we also make 20 (the latent dimensionality) a parameter? I don't expect it to be used a lot, but giving it a name would improve readability IMO.

Member Author:
latents are a flag now; 64 is because the visualization is an 8x8 grid
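
A sketch of what that could look like, with the latent dimensionality exposed as a flag and 64 fixed samples so the generated digits fit an 8x8 grid (the flag name is assumed, not necessarily the one used in the example):

from absl import app, flags
from jax import random

flags.DEFINE_integer('latents', default=20, help='Number of latent dimensions.')
FLAGS = flags.FLAGS

def main(argv):
  del argv
  rng = random.PRNGKey(0)
  rng, z_key, eval_rng = random.split(rng, 3)
  # 64 fixed latent samples, laid out later as an 8x8 grid of digits.
  z = random.normal(z_key, (64, FLAGS.latents))
  print(z.shape)  # (64, 20) by default

if __name__ == '__main__':
  app.run(main)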

  xs = eval_ds['image'] / 255.0
  xs = xs.reshape(-1, 784)
-  recon_xs, mean, logvar = model(xs)
+  recon_xs, mean, logvar = model(xs, z_rng)

  comparison = jnp.concatenate([xs[:8].reshape(-1, 28, 28, 1),

Contributor:
Return all samples+reconstructions and subsample+concat outside of eval before saving the image?

Member Author:
There are a lot of reconstructions (10,000), so I think pre-processing the data for the comparison image is not a bad idea.

Contributor:
I didn't realise this. I thought it was going to be just the regular batch size.
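
For context, the comparison image under discussion is built inside eval from just the first 8 originals and their reconstructions, roughly as follows (a sketch; the zero arrays stand in for the real eval batch and model output):

import jax.numpy as jnp

def build_comparison(xs, recon_xs):
  # Stack 8 originals above their 8 reconstructions as 28x28x1 images.
  return jnp.concatenate([xs[:8].reshape(-1, 28, 28, 1),
                          recon_xs[:8].reshape(-1, 28, 28, 1)])

xs = jnp.zeros((10000, 784))        # full eval set
recon_xs = jnp.zeros((10000, 784))  # full set of reconstructions
print(build_comparison(xs, recon_xs).shape)  # (16, 28, 28, 1)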

- Fix a bug causing fixed noise during training.
- Switch to 30 epochs to better reproduce the paper.
- Switch to binarized MNIST.
- Add a reconstruction image to the README.

@AlexeyG (Contributor) left a comment

LGTM.

Thank you for reworking this example, @jheek!

@jheek merged commit 2f69b48 into google:master on Apr 1, 2020
@jheek deleted the fix-vae branch on April 1, 2020 at 11:50