Best way to convert PyTorch recurrent model to Flax? #3829
Unanswered · Chulabhaya asked this question in Q&A
Hi all! I'm in the process of converting some of my code to JAX, and one thing I'm currently struggling to figure out is how to set up my recurrent model correctly. In PyTorch I can do this fairly easily, as in the sketch below: an LSTM that handles batches of variable-length sequences, with the hidden state passed around during inference for action selection.
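Roughly the kind of setup I mean (a minimal sketch, not my exact code; names like `RecurrentPolicy`, `obs_dim`, etc. are just placeholders):

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

class RecurrentPolicy(nn.Module):
    """Toy recurrent model: an LSTM over padded batches of variable-length sequences."""

    def __init__(self, obs_dim, hidden_dim, act_dim):
        super().__init__()
        self.lstm = nn.LSTM(obs_dim, hidden_dim, batch_first=True)
        self.head = nn.Linear(hidden_dim, act_dim)

    def forward(self, obs, seq_lengths, hidden=None):
        # Pack the padded batch so the LSTM skips the padding steps.
        packed = pack_padded_sequence(
            obs, seq_lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        packed_out, hidden = self.lstm(packed, hidden)
        out, _ = pad_packed_sequence(packed_out, batch_first=True)
        # `hidden` is the (h, c) tuple that gets carried across inference steps.
        return self.head(out), hidden
```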
The part I'm struggling with in particular is how to convert the LSTM layer, handle batches of variable-length sequences, and pass the hidden state around during inference. I've looked at a few codebases on GitHub to see how they handle this, but they all seem to do it differently. Based on the docs here, it seems like I can use a `seq_lengths` argument, very similar to PyTorch, to handle the variable lengths. But how do I implement the LSTM itself? Is it with a `jax.lax.scan` across multiple cells? Thanks in advance for any help!
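For reference, this is the kind of Flax version I'm imagining based on those docs, using `nn.RNN` wrapped around an `nn.OptimizedLSTMCell` with `seq_lengths` instead of a hand-rolled `jax.lax.scan`. I'm not sure this is the recommended pattern, and the names and shapes are just placeholders:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class RecurrentPolicy(nn.Module):
    """Sketch: LSTM over padded batches, with seq_lengths masking out padding steps."""
    hidden_dim: int
    act_dim: int

    @nn.compact
    def __call__(self, obs, seq_lengths, initial_carry=None):
        cell = nn.OptimizedLSTMCell(features=self.hidden_dim)
        # nn.RNN scans the cell over the time axis; per the docs, seq_lengths
        # should keep the carry from updating past each sequence's real length.
        carry, out = nn.RNN(cell, return_carry=True)(
            obs, seq_lengths=seq_lengths, initial_carry=initial_carry
        )
        # `carry` is the (c, h) state I'd want to pass around during inference.
        return nn.Dense(self.act_dim)(out), carry

# Placeholder shapes, just to show how I'd call it:
model = RecurrentPolicy(hidden_dim=64, act_dim=4)
obs = jnp.zeros((8, 20, 10))                   # (batch, time, obs_dim)
lengths = jnp.full((8,), 20, dtype=jnp.int32)  # per-sequence lengths
params = model.init(jax.random.PRNGKey(0), obs, lengths)
logits, carry = model.apply(params, obs, lengths)
```

Is something like this the right direction, or should I be using `nn.scan` / `jax.lax.scan` over the cell directly?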