Skip to content
This repository has been archived by the owner on Jun 20, 2022. It is now read-only.

Added GAN-JAX-CelebA-demo.ipynb #43

Merged
merged 2 commits into from
Mar 16, 2022
Merged

Added GAN-JAX-CelebA-demo.ipynb #43

merged 2 commits into from
Mar 16, 2022

Conversation

susnato
Copy link
Contributor

@susnato susnato commented Mar 14, 2022

Description

Added JAX GAN demo (Google Colab Notebook Included)

Issue

probml/pyprobml#675

Checklist:

[✓] Performed a self-review of the code
[✓] Tested on Google Colab.

Potential problems/Important remarks

As described in this issue probml/pyprobml#675 I implemented GANs using JAX [Tensorflow was only used for Datasets purpose]with the help of https://github.com/valentingol/GANJax.

Edit: There were some issues related to my first PR that @murphyk pointed out, I solved almost all of them except this one - "It would be great if you could make a detailed 'recipe' people could follow to train this model in GCP, and then just use Colab for visualizing it.", I couldn't do it because I am not familiar with GCS but I still described how you can train your model on Kaggle if Colab does not permit you to train long enough.
If there are still errors please let me know.

@susnato
Copy link
Contributor Author

susnato commented Mar 14, 2022

Can you please review my code? @murphyk @mjsML @gerdm

@murphyk
Copy link
Member

murphyk commented Mar 14, 2022

@susnato
Please can you add a link to a runnable version of your colab, since it is hard to code review notebooks in github.

@susnato
Copy link
Contributor Author

susnato commented Mar 14, 2022

https://colab.research.google.com/drive/1YS-G6DML1xq3AhlNpKmh18Dc4tkn9T4U?usp=sharing

It's the runnable version of the notebook.
@murphyk , let me know if you need anything else.

@murphyk
Copy link
Member

murphyk commented Mar 14, 2022

This is much better. But a few more things:

  • No need to install things like TF, Cuda or Jax into colab, they are already there.
    You just need !pip install dm-haiku optax pickle5 -q
  • No need for comments like "This project aims to bring the power of JAX, a Python framework developped by Google and DeepMind to train Generative Adversarial Networks for images generation."
  • Change title from "Implementation_of_GANs_using_JAX" to "GAN-JAX-CelebA-demo" or something like that.
  • This line is a bit hard to read:
  ax[r_, c_].imshow((((ei_-np.max(ei_))/(np.max(ei_)-np.min(ei_)))*255).astype(np.uint8))

Maybe rewrite as

im = ((ei_-np.max(ei_))/(np.max(ei_)-np.min(ei_))) # normalize 0..1
ax[r_, c_].imshow((im*255).astype(np.uint8))

@susnato
Copy link
Contributor Author

susnato commented Mar 15, 2022

Hi, @murphyk thanks for your feedback,

  1. I removed the TF and JAX installation code (Actually when I was going through the GanJAX repo I found these comments in the Installation section,

"However, Tensorfow allocate memory of the GPU on use (which is not optimal for running calculation with JAX). Therefore, you should install Tensorflow on the CPU instead of the GPU.")

That is why I was running that code to install tensorflow-cpu and other things however, it is working fine without those lines as you said.
I am sorry that I never checked without those lines.

  1. I removed comments like "This project aims to bring the power of JAX, a Python ... ".

  2. I changed from "Implementation_of_GANs_using_JAX" to "GAN-JAX-CelebA-demo".

  3. I rewrote the imshow line.

  4. The reason that I have to replicate so much of the config code from the original repo is because of a Pickle Error, Actually during the testing we load the pre-trained model using these two lines,

config = DCGANConfig().load(os.path.join(pretrained_model_path, 'config.pickle'))
params, state = load_jax_model(os.path.join(pretrained_model_path, 'generator'))

And I was getting this error, "ValueError: unsupported pickle protocol: 5", I searched online and found out that it happens because of "For pandas users who saved something to a pickle file with protocol 5 in python 3.8 and need to load it into python 3.6 which only supports protocol 4. Link for the StackOverflow thread is this : https://stackoverflow.com/questions/63329657/python-3-7-error-unsupported-pickle-protocol-5,
That's why I had to import pickle5 as pickle and change each block of code where the main repo loads something using pickle.

  1. Yes, Kaggle lets you train longer than Colab! For more information please visit https://www.kaggle.com/docs/notebooks and please refer to the "Technical Specifications" sections.

Pushing the commits in a moment!

@susnato
Copy link
Contributor Author

susnato commented Mar 15, 2022

Hi, @murphyk I did those changes as described above and pushed again, if you want to view through Google Colab, the link is the same, https://colab.research.google.com/drive/1YS-G6DML1xq3AhlNpKmh18Dc4tkn9T4U?usp=sharing .
Please let me know if I need to make any more changes or not.

@susnato susnato changed the title Added Implementation_of_GANs_using_JAX.ipynb Added GAN-JAX-CelebA-demo.ipynb Mar 15, 2022
@susnato
Copy link
Contributor Author

susnato commented Mar 15, 2022

Hi, @murphyk could you please tell me if the notebook I made is ok or not since I want to solve another issue but can't start until I know if I solved this properly or not.

@murphyk
Copy link
Member

murphyk commented Mar 16, 2022

LGTM!

@murphyk murphyk merged commit b8d2b5a into probml:main Mar 16, 2022
@murphyk
Copy link
Member

murphyk commented Mar 16, 2022

@susnato In the future, please submit your notebooks to https://github.com/probml/probml-notebooks/tree/main/notebooks.
The https://github.com/probml/probml-notebooks/tree/main/notebooks-d2l directory is reserved for notebooks derived from the d2l.ai book (these are mostly in torch, but are currently being converted to jax).

Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants