
Flax NNX and Orbax Checkpointing require hacking to work together #4383

Open
hdrwilkinson opened this issue Nov 15, 2024 · 3 comments

@hdrwilkinson

I'm building a system using flax.nnx and orbax.checkpoint. However, saving and restoring models is overly complicated because flax.nnx uses the new typed keys created by jax.random.key() rather than the legacy jax.random.PRNGKey().
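For context, here is a minimal sketch of how the two key styles differ, assuming a recent JAX with the default threefry PRNG implementation. jax.random.key() returns a scalar with an extended key<fry> dtype, while jax.random.PRNGKey() returns an ordinary uint32 array that array serializers can handle.

```python
import jax
import jax.numpy as jnp

# New-style typed key (what flax.nnx stores): a scalar with an extended dtype.
typed_key = jax.random.key(0)
# Legacy raw key: a plain uint32 array of shape (2,) under threefry.
legacy_key = jax.random.PRNGKey(0)

print(typed_key.dtype, typed_key.shape)    # key<fry> ()
print(legacy_key.dtype, legacy_key.shape)  # uint32 (2,)

# The typed key wraps the same raw bits; key_data() exposes them as uint32.
raw_bits = jax.random.key_data(typed_key)
print(bool(jnp.array_equal(raw_bits, legacy_key)))  # True
```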

I have had to create a workaround in which all layers with rng and key in their path are converted from dtype=key<fry> to a format suitable for saving. Upon restoration, they then have to be converted back.

I am attaching a link to a notebook explaining what I've done, but I would be keen to hear whether there are simpler workarounds or, preferably, a way to simply save and restore models.

https://colab.research.google.com/drive/1ozln9ejG7eRtxvbkqHYU3K6OyPvveH9w?usp=sharing

Note: I am also opening an issue on Orbax to see whether there is a fix on their side (#1337).

@cgarciae
Collaborator

Thank you! I'll contact the Orbax team to see if they can fix this on their end.

@mishmish66

mishmish66 commented Nov 17, 2024

Hey! Here's a quick and dirty workaround.

The general idea is to use nnx.split with an NNX filter to separate the nnx.RngState variables from the rest of the state, and then not save them:

```python
graphdef, rng_state, other_state = nnx.split(model, nnx.RngState, ...)
```

Then save only other_state instead of the full state. I've edited your Colab notebook to demonstrate this.

This means the RNG state will not be restored, which might be suboptimal in some scenarios but should work for most use cases. Hope it helps!

@jkyl

jkyl commented Nov 25, 2024

Another workaround for Dropout layers, and maybe custom layers too if they follow the same pattern, is to initialize them without the rngs arg, and only pass the RNG at __call__ time, like so:

```python
import jax.numpy as jnp
import orbax.checkpoint as ocp
from flax import nnx

# Init dropout without rng arg.
model = nnx.Dropout(0.5)

# Pass RNG at call time.
output = model(jnp.ones(()), rngs=nnx.Rngs(0))

# This now works.
ckpt_dir = ocp.test_utils.erase_and_create_empty("/tmp/my-checkpoints/")
checkpointer = ocp.StandardCheckpointer()
checkpointer.save(ckpt_dir / "state", nnx.split(model)[1])
checkpointer.wait_until_finished()  # StandardCheckpointer saves asynchronously.
```

By contrast, if the RNG is supplied at initialization, the save line raises:

```
TypeError: Cannot interpret 'key<fry>' as a data type
```

But this is only a workaround: the RNG state is still not serialized, and the call signature becomes more verbose.
