S5: Longer compilation times #25

Open
stergiosba opened this issue Jun 1, 2024 · 3 comments

Comments

@stergiosba

Hey, thanks for providing purejaxrl; it's pretty awesome.

I have been using the experimental S5 code you provide for part of my research, and since jaxlib version 0.4.27 (same for 0.4.28) I have been getting roughly 5x longer compilation times when I increase n_layers in the S5 model. Any ideas why this might happen?

stergiosba changed the title from "Longer compilation times" to "S5: Longer compilation times" on Jun 1, 2024
@luchris429
Owner

Interesting! I'm not sure why that would happen. However, I do think the flax RNN documentation/structure may have changed after version 0.4.27, which could be why you're getting the error.

https://flax.readthedocs.io/en/latest/guides/converting_and_upgrading/rnncell_upgrade_guide.html

@stergiosba
Author

stergiosba commented Jun 3, 2024

Well, technically speaking there is no error. The code runs fine and the model trains fine; the only problem is that compilation on the XLA side takes longer. XLA version 7.4 (JAX 0.4.26) is much faster than 8.3 (JAX 0.4.28) at generating code for a GPU device (hope this adds clarity to the issue). I also see this issue on two different GPUs on two different machines.

Anyway, maybe this will be useful to keep in mind for your future JAX work. Thanks again.

@stergiosba
Author

I ended up fixing this by turning the StackedEncoderModel into a scanned version of what you initially had (a sketch of the idea is below).
This requires some minimal code changes for the end user, which we could probably smooth over.
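For context, here is a minimal sketch of the idea (hypothetical stand-ins, not the actual purejaxrl S5 code): a single Block module is scanned over the layer axis with flax.linen's nn.scan, following the layer-scanning pattern from the Flax docs, so XLA traces and compiles the layer body once instead of once per layer. Block, ScannedStack, and the feature sizes below are illustrative only.

import jax
import jax.numpy as jnp
import flax.linen as nn

class Block(nn.Module):
    """Hypothetical stand-in for one S5 layer."""
    features: int

    @nn.compact
    def __call__(self, x, _):
        # nn.scan expects a (carry, xs) -> (carry, ys) signature; here the
        # activations are the carry and there are no per-step inputs/outputs.
        h = nn.gelu(nn.Dense(self.features)(x))
        return x + h, None

class ScannedStack(nn.Module):
    """Hypothetical stand-in for a scanned StackedEncoderModel."""
    n_layers: int
    features: int

    @nn.compact
    def __call__(self, x):
        # Parameters gain a leading n_layers axis, but the block body is only
        # traced once, so compile time stays roughly flat as n_layers grows.
        ScanBlock = nn.scan(
            Block,
            variable_axes={"params": 0},
            variable_broadcast=False,
            split_rngs={"params": True},
            length=self.n_layers,
        )
        x, _ = ScanBlock(self.features)(x, None)
        return x

model = ScannedStack(n_layers=20, features=64)
params = model.init(jax.random.PRNGKey(0), jnp.ones((4, 64)))

The trade-off is that every layer must share the same structure and hyperparameters, and the parameter pytree gains a leading layer axis, which is the source of the minimal end-user code changes mentioned above.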

Some compilation benchmarks:

  1. 1 S5 layer: 28 sec (old) vs 28 sec (new)
  2. 4 S5 layers: 90 sec (old) vs 29 sec (new)
  3. 20 S5 layers: 29 sec (new)
  4. 200 S5 layers: 29 sec (new) - an extreme case, just for testing

I also attach the results from setting jax.config.update("jax_log_compiles", True):
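For anyone reproducing this, that flag is just enabled at the top of the script, before the jitted train function first runs, e.g.:

import jax

# Enable logging of per-jit compilation times (jaxpr-to-MLIR conversion and
# XLA compilation), as shown in the output below.
jax.config.update("jax_log_compiles", True)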

20 layers S5:

Finished jaxpr to MLIR module conversion jit(train) in 1.3401036262512207 sec
Finished XLA compilation of jit(train) in 29.01173186302185 sec

200 layers S5:

Finished jaxpr to MLIR module conversion jit(train) in 1.3581228256225586 sec
2024-06-29 18:11:22.686136: W external/xla/xla/service/hlo_rematerialization.cc:3005] Can't reduce memory use below 8.90GiB (9555618931 bytes) by rematerialization; only reduced to 10.13GiB (10874848652 bytes), down from 10.13GiB (10875108940 bytes) originally
Finished XLA compilation of jit(train) in 29.203452587127686 sec

Tests were run on a single NVIDIA RTX A5500:

Sat Jun 29 18:12:42 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.90.07              Driver Version: 550.90.07      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA RTX A5500               Off |   00000000:01:00.0  On |                  Off |
| 30%   45C    P8             26W /  230W |     554MiB /  24564MiB |     22%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A      2311      G   /usr/lib/xorg/Xorg                            182MiB |
|    0   N/A  N/A      2497      G   /usr/bin/gnome-shell                          293MiB |
+-----------------------------------------------------------------------------------------+

Interestingly, I do not get exactly the same learning performance, which could mean there is a bug somewhere. However, I did test on CartPole-v1, Acrobot-v1, and MountainCar-v0, and it successfully learns these envs.

Let me know if you are interested in a PR for this.
