Fix vae #143
Conversation
Thanks for fixing the rng bug, @jheek!
I left a bunch of comments, most of which aren't actually related to your improvements, but are rather requests for further improvements to the example. I'm happy to leave them for a later PR, if you prefer.
I also noticed that the result this VAE gets doesn't actually match the paper. We should be getting ~100, but I am not sure after how many training epochs.
examples/vae/main.py
Outdated
return {
    'bce': jnp.mean(bce_loss),
    'kld': jnp.mean(kld_loss),
    'loss': jnp.mean(bce_loss + kld_loss)
Is it not better (in terms of op count) to sum the averaged bce and kld loss? Although it doesn't really matter on such a small-scale example.
done
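A minimal sketch of the suggested change, assuming the bce_loss and kld_loss arrays from the snippet above; the helper name compute_metrics is illustrative:

```python
import jax.numpy as jnp

def compute_metrics(bce_loss, kld_loss):
  # Average each per-example term once, then add the two scalars,
  # instead of summing the full arrays and averaging the result.
  bce = jnp.mean(bce_loss)
  kld = jnp.mean(kld_loss)
  return {'bce': bce, 'kld': kld, 'loss': bce + kld}
```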
examples/vae/main.py
Outdated
'learning_rate', default=1e-3,
help=('The leanring rate for the Adam optimizer')
Add a dot at the end of the help string. Here and in other flags.
done
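For reference, a sketch of the flag with the trailing period added, assuming the example defines it with absl.flags (the typo in the help string is also fixed here):

```python
from absl import flags

flags.DEFINE_float(
    'learning_rate', default=1e-3,
    help='The learning rate for the Adam optimizer.')
```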
examples/vae/main.py
Outdated
@@ -127,10 +124,10 @@ def loss_fn(model):

 @jax.jit
-def eval(model, eval_ds, z):
+def eval(model, eval_ds, z, z_rng):
   xs = eval_ds['image'] / 255.0
Should this not be a part of the input pipeline?
also part of the input pipeline now
examples/vae/main.py
Outdated
for epoch in range(FLAGS.num_epochs):
  for batch in tfds.as_numpy(train_ds):
    rng, key = random.split(rng)
    batch['image'] = batch['image'].reshape(-1, 784) / 255.0
Why are images flattened and then unflattened again in eval? If flattening serves a purpose, consider moving the first flattening op into the IO pipeline.
flattening is now part of the input pipeline
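A sketch of what moving the normalisation and flattening into the input pipeline could look like; prepare_image is a hypothetical helper and the batch size is illustrative, this is the usual TFDS pattern rather than the exact code in the PR:

```python
import tensorflow as tf
import tensorflow_datasets as tfds

def prepare_image(x):
  image = tf.cast(x['image'], tf.float32) / 255.0  # normalise to [0, 1]
  return tf.reshape(image, (-1,))                  # flatten 28x28x1 -> 784

train_ds = tfds.load('mnist', split='train')
train_ds = train_ds.map(prepare_image).batch(128)
```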
examples/vae/main.py
Outdated
for epoch in range(FLAGS.num_epochs):
  for batch in tfds.as_numpy(train_ds):
    rng, key = random.split(rng)
What's the reason for manually handling key splitting instead of using the stochastic context? Using the latter would showcase more of Flax's features.
I like to use the simple and explicit JAX APIs as much as possible. nn.stochastic was introduced to avoid having to pass around rngs through a complex model. But within a single function I think manually splitting rngs is still better.
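The explicit splitting pattern in isolation, as a self-contained sketch (the loop body is a stand-in for the actual training step):

```python
from jax import random

rng = random.PRNGKey(0)
samples = []
for step in range(3):
  # Split off a fresh key each step so every step sees new noise;
  # reusing `rng` directly would produce the same samples every iteration.
  rng, key = random.split(rng)
  samples.append(random.normal(key, (2,)))
```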
examples/vae/main.py
Outdated
for epoch in range(FLAGS.num_epochs):
  for batch in tfds.as_numpy(train_ds):
    rng, key = random.split(rng)
    batch['image'] = batch['image'].reshape(-1, 784) / 255.0
After dividing by 255, the data becomes "continuous". You need to dynamically binarise it, both for train and eval. Many VAEs report results on dynamically binarised MNIST and on statically binarised MNIST (some fixed dataset, which I believe isn't a part of TFDS).
Scrap half of the previous comment: TFDS does have the binarised MNIST dataset. It's called binarized_mnist :)
What is your recommendation here? To eval on both?
Just stick to binarized_mnist in training and eval. This should be equivalent to the "static MNIST" setup from VAE literature.
done
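A sketch of what the switch could look like in the pipeline; the map and batch size are illustrative, but binarized_mnist is the actual TFDS dataset name and it has no labels, so only the image feature is used:

```python
import tensorflow as tf
import tensorflow_datasets as tfds

# binarized_mnist ships images already thresholded to {0, 1},
# so the /255 normalisation and any dynamic binarisation go away.
train_ds = tfds.load('binarized_mnist', split='train')
train_ds = train_ds.map(
    lambda x: tf.reshape(tf.cast(x['image'], tf.float32), (-1,))).batch(128)
```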
@@ -12,6 +12,10 @@ python main.py
GitHub doesn't let me comment on parts that are not considered different. But could you please remove the semicolon after ## Examples?
examples/vae/main.py
Outdated
vae = nn.Model(VAE, params)

optimizer = optim.Adam(learning_rate=FLAGS.learning_rate).create(vae)

rng, z_key, eval_rng = random.split(rng, 3)
z = random.normal(z_key, (64, 20))
64 -> FLAGS.batch_size?
Can we also make 20 (latent dimensionality) a parameter? I don't expect it to be used a lot, but giving it a name would improve readability IMO.
Latents are a flag now. The 64 stays because the visualization is an 8x8 grid.
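Roughly what that ends up looking like, as a hedged sketch; the flag name latents matches the reply above, while the helper name and default are illustrative:

```python
from absl import flags
from jax import random

FLAGS = flags.FLAGS

flags.DEFINE_integer(
    'latents', default=20,
    help='Number of latent dimensions.')

def make_fixed_inputs(rng):
  # 64 samples because the generated sample sheet is an 8x8 grid.
  rng, z_key, eval_rng = random.split(rng, 3)
  z = random.normal(z_key, (64, FLAGS.latents))
  return rng, z, eval_rng
```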
examples/vae/main.py
Outdated
 xs = eval_ds['image'] / 255.0
 xs = xs.reshape(-1, 784)
-recon_xs, mean, logvar = model(xs)
+recon_xs, mean, logvar = model(xs, z_rng)

 comparison = jnp.concatenate([xs[:8].reshape(-1, 28, 28, 1),
Return all samples+reconstructions and subsample+concat outside of eval before saving the image?
There are a lot of reconstructions (10,000), so I think pre-processing the data for the comparison image is not a bad idea.
I didn't realise this. I thought it was going to be just the regular batch size.
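For context, a sketch of how subsampling before building the comparison sheet might look; the helper name and the choice of 8 pairs are illustrative, not the exact code from the PR:

```python
import jax.numpy as jnp

def build_comparison(xs, recon_xs, n=8):
  # Keep only the first n originals and reconstructions, reshape the flat
  # 784-vectors back to 28x28 images, and stack them for a side-by-side sheet.
  originals = xs[:n].reshape(-1, 28, 28, 1)
  reconstructions = recon_xs[:n].reshape(-1, 28, 28, 1)
  return jnp.concatenate([originals, reconstructions])
```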
- Fixes bug causing fixed noise during training
- Switch to 30 epochs to better reproduce the paper
- Switch to binarized MNIST
- Add reconstruction image to the README
LGTM.
Thank you for re-working this example, @jheek!