How to load PyTorch checkpoints into JAX/Flax? #927
marcvanzee asked this question in Q&A
Answered by marcvanzee on Jan 22, 2021
Answer selected by marcvanzee
PyTorch checkpoints contain a `state_dict` with all the weights/parameters for the model, and converting it to Flax involves, among other things, accounting for the `NCHW` dimension ordering PyTorch uses for conv weights.

Often `flax.traverse_util.flatten_dict` is useful, because you only need to operate on a flat dict instead of a nested dict. Once they align, you use `unflatten_dict` to get the normal form back.

@nikitakit wrote the following code for importing PyTorch BERT checkpoints into a Flax model: https://github.com/nikitakit/flax_bert/blob/master/import_weights.py