Flax LSTM Layers taking a lot of RAM. Not able to train big datasets. #2192

Answered by cgarciae
sunny2309 asked this question in Q&A

Hey @sunny2309! It's not totally clear that the LSTM layer is at fault here. I would suggest the following:

  1. Try to create a jitted `train_step` function as shown in the Annotated MNIST example. This can help lower memory consumption (see the first sketch after this list).
  2. Try to use `jax.profiler` to find out which portion of your code might be causing the leak (see the second sketch after this list).
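
For the first suggestion, here is a minimal sketch of a jitted `train_step` in the style of the Annotated MNIST example. The loss, the `batch` layout (`'inputs'`/`'labels'` keys), and the use of `flax.training.train_state.TrainState` are assumptions; adapt them to your actual LSTM model.

```python
import jax
import optax


@jax.jit
def train_step(state, batch):
    """One optimization step. Jitting the whole step keeps intermediates
    inside XLA rather than materializing them op-by-op from Python, which
    usually lowers peak memory. `state` is assumed to be a
    flax.training.train_state.TrainState."""

    def loss_fn(params):
        # `state.apply_fn` is the Flax module's apply function; `batch` is
        # assumed to be a dict with 'inputs' and integer 'labels'.
        logits = state.apply_fn({'params': params}, batch['inputs'])
        one_hot = jax.nn.one_hot(batch['labels'], logits.shape[-1])
        return optax.softmax_cross_entropy(logits=logits, labels=one_hot).mean()

    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    state = state.apply_gradients(grads=grads)
    return state, loss
```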
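
For the second suggestion, a simple way to start is to periodically dump device-memory profiles with `jax.profiler.save_device_memory_profile` and compare successive snapshots (e.g. with `pprof`) to see which arrays keep growing. The sketch below reuses the hypothetical `train_step`, `state`, and `train_batches` names from above.

```python
import jax

for step, batch in enumerate(train_batches):
    state, loss = train_step(state, batch)
    if step % 100 == 0:
        # Snapshot the live device buffers; a leak shows up as allocations
        # that grow from one snapshot to the next.
        jax.profiler.save_device_memory_profile(f"memory_step_{step}.prof")
```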

I'll be moving this into a discussion since it's still unclear where the bug is. If the error persists and you gather more information, feel free to open another issue.

This discussion was converted from issue #2187 on June 13, 2022 15:06.