jax.Array must be fully replicated to be saved in aggregate file #3143
tatami-galaxy asked this question in Q&A
Hey @tatami-galaxy, can you give a minimal example?
I'm trying to save a checkpoint and get this error message. Saving code:
```python
from flax.training import orbax_utils

ckpt = {'state': state, 'config': model.config}
save_args = orbax_utils.save_args_from_target(ckpt)
checkpoint_manager.save(global_step + 1, ckpt, save_kwargs={'save_args': save_args})
```
`state` is an instance of `flax.training.train_state.TrainState`. What could be causing this? I tried disabling `jax.Array` with `jax.config.update('jax_array', False)`, but that does not work with jax and jaxlib 0.4.7.
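One way to narrow this down is to check which leaves of the checkpoint pytree are sharded, since the error message says the array must be fully replicated to go into the aggregate file. The sketch below uses only core JAX (`jax.tree_util.tree_flatten_with_path`, `jax.tree_util.keystr`, and the `jax.Array.is_fully_replicated` property). The helper names `non_replicated_leaves` and `aggregate_flags` are hypothetical, and the idea that the resulting flags could be fed into per-leaf Orbax `SaveArgs(aggregate=...)` is an assumption that depends on the Orbax version in use, not a confirmed fix.

```python
import jax
import jax.numpy as jnp


def non_replicated_leaves(tree):
    """Hypothetical helper: key paths of jax.Array leaves that are NOT fully replicated."""
    leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(tree)
    return [
        jax.tree_util.keystr(path)
        for path, leaf in leaves_with_paths
        if isinstance(leaf, jax.Array) and not leaf.is_fully_replicated
    ]


def aggregate_flags(tree):
    """Hypothetical helper: per-leaf booleans, False for sharded jax.Array leaves.

    Assumption: only fully replicated arrays (or non-array leaves such as
    Python ints) are safe to write into an aggregate file.
    """
    return jax.tree_util.tree_map(
        lambda x: not (isinstance(x, jax.Array) and not x.is_fully_replicated),
        tree,
    )


if __name__ == "__main__":
    # Toy checkpoint; on a single device every array is fully replicated.
    ckpt = {"params": {"w": jnp.ones((4, 4))}, "step": 0}
    print(non_replicated_leaves(ckpt))
    print(aggregate_flags(ckpt))
```

Running `non_replicated_leaves(ckpt)` on the real `{'state': state, 'config': model.config}` dict would list the exact parameter or optimizer-state paths that are sharded across devices, which is also a useful starting point for the minimal example requested above.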