S5: Longer compilation times #25
Hey, thanks for providing purejaxrl, it's pretty awesome. I have used the experimental `S5` code that you provide for a part of my research, and after version 0.4.27 (same for 0.4.28) of `jaxlib` I have been getting 5x longer compilation times when I increase the `n_layers` of the `S5`. Any ideas why this might happen?
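A quick way to quantify this kind of regression is to time XLA compilation directly with JAX's ahead-of-time API while sweeping the depth. A minimal sketch with a hypothetical stand-in model (`make_forward` is an illustration, not the actual S5 code):

```python
# Minimal sketch: time XLA compilation as a function of depth.
# The model here is a hypothetical stand-in, not the S5 architecture.
import time
import jax
import jax.numpy as jnp

def make_forward(n_layers):
    def forward(params, x):
        for w in params:  # unrolled Python loop: the traced graph grows with depth
            x = jnp.tanh(x @ w)
        return x
    return forward

x = jnp.ones((32, 128))
for n_layers in (1, 2, 4, 8):
    params = [jnp.eye(128) for _ in range(n_layers)]
    lowered = jax.jit(make_forward(n_layers)).lower(params, x)  # trace + lower only
    t0 = time.perf_counter()
    lowered.compile()  # XLA compilation happens here
    print(f"n_layers={n_layers}: compiled in {time.perf_counter() - t0:.2f}s")
```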
Interesting! I'm not sure why that would happen. However, I do think the Flax RNN cell upgrade guide may be relevant: https://flax.readthedocs.io/en/latest/guides/converting_and_upgrading/rnncell_upgrade_guide.html
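For context, a hedged sketch of the post-upgrade Flax cell API that the linked guide covers (exact signatures depend on the Flax version; `GRUCell` is used here purely as an illustration): cells now take `features` at construction, and `initialize_carry` takes the input shape.

```python
# Sketch of the upgraded Flax recurrent-cell API (illustrative; check the
# linked guide for the Flax version you are on).
import jax
import jax.numpy as jnp
import flax.linen as nn

cell = nn.GRUCell(features=64)                   # features now set at construction
x = jnp.ones((8, 32))                            # (batch, input features)
carry = cell.initialize_carry(jax.random.PRNGKey(0), x.shape)
params = cell.init(jax.random.PRNGKey(1), carry, x)
new_carry, y = cell.apply(params, carry, x)      # y.shape == (8, 64)
```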
Well, there is no error technically speaking. The code runs fine and the model trains fine. The only problem is that compilation on the XLA side of things takes longer. Version 7.4 of XLA (JAX 0.4.26) is much faster than 8.3 (JAX 0.4.28) at generating the code for a GPU device (hope this adds clarity to the issue). I also face this issue on two different GPUs on two different machines. Anyway, maybe this will alert you in your future JAX endeavors. Thanks again.
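When comparing runs across jax/jaxlib pairs like this, a quick sanity check of what is actually installed helps keep the benchmarks comparable; a minimal sketch:

```python
# Confirm which jax/jaxlib versions and backend a benchmark actually ran on.
import jax
import jaxlib

print("jax    :", jax.__version__)
print("jaxlib :", jaxlib.__version__)
print("backend:", jax.default_backend())  # e.g. "gpu"
```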
I ended up fixing this by making the […]. Some compilation benchmarks:
I also attach the results from setting […].
Tests were run on a single NVIDIA RTX A5500:
Interestingly, I do not get exactly the same learning performance. That could mean there is a bug somewhere; however, I did test on […]. Let me know if you are interested in a PR for this.
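The truncated fix above is not recoverable from the page, but one common way to keep compile time roughly flat in layer count is to scan over a stack of identical layers with `flax.linen.scan` instead of unrolling a Python loop, so XLA compiles the layer body once. A hedged sketch of that technique (an illustration, not necessarily the commenter's change):

```python
# Hypothetical sketch: scan over n_layers identical blocks so the layer body
# is compiled once, instead of being inlined n_layers times.
import jax
import jax.numpy as jnp
import flax.linen as nn

class Block(nn.Module):
    features: int

    @nn.compact
    def __call__(self, carry, _):
        carry = nn.tanh(nn.Dense(self.features)(carry))
        return carry, None  # (new carry, per-step output)

class ScannedStack(nn.Module):
    features: int
    n_layers: int

    @nn.compact
    def __call__(self, x):
        stack = nn.scan(
            Block,
            variable_axes={"params": 0},   # stack each layer's params on axis 0
            split_rngs={"params": True},   # independent init RNG per layer
            length=self.n_layers,
        )(self.features)
        x, _ = stack(x, None)
        return x

model = ScannedStack(features=128, n_layers=8)
x = jnp.ones((32, 128))
params = model.init(jax.random.PRNGKey(0), x)
y = model.apply(params, x)
```

Note that scanning stacks parameters along a leading axis and splits initialization RNGs per layer, so the initial weights differ from those of an unrolled loop; that alone could explain small differences in learning curves like the one mentioned above.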