Hey! Great question. There are a few ways that you can go about saving models.
The simplest approach relies on the fact that NT params are made of standard Python data structures (tuples and lists) along with JAX arrays, which can be converted to standard NumPy arrays for serialization. Thus, one option is to use pickle to save the whole params tree. Another is to flatten the tree, save the leaves using numpy.save or numpy.savez, and then save the tree structure using pickle.
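Here is a minimal sketch of both options, assuming params is a pytree of lists/tuples and JAX arrays as described above. `make_example_params` is a hypothetical stand-in for what an NT `init_fn` returns (it is not NT's API), and the helper and file names are only for illustration. For the second option, this sketch pickles a "skeleton" copy of the tree (every leaf replaced by 0) instead of pickling JAX's treedef object directly. That is my own choice to keep the pickle limited to plain Python containers; it is not necessarily what the Haiku thread does.

```python
import os
import pickle
import tempfile

import jax
import jax.numpy as jnp
import numpy as np


def make_example_params(seed=0):
    """Hypothetical stand-in for NT params: a list of per-layer tuples.

    Dense-like layers hold (W, b); parameter-free layers (e.g. Relu) hold ().
    """
    k1, k2 = jax.random.split(jax.random.PRNGKey(seed))
    return [
        (jax.random.normal(k1, (3, 4)), jnp.zeros((4,))),
        (),
        (jax.random.normal(k2, (4, 1)), jnp.zeros((1,))),
    ]


# Option 1: pickle the whole tree. Leaves are converted to NumPy first,
# so the file can be unpickled without JAX installed.
def save_pickle(params, path):
    with open(path, "wb") as f:
        pickle.dump(jax.tree_util.tree_map(np.asarray, params), f)


def load_pickle(path):
    with open(path, "rb") as f:
        return pickle.load(f)


# Option 2: leaves go into an .npz file; the tree structure is pickled separately.
def save_npz(params, prefix):
    leaves = jax.tree_util.tree_leaves(params)
    np.savez(prefix + ".npz", *[np.asarray(x) for x in leaves])
    # Same tree with every leaf replaced by 0: only lists, tuples and ints.
    skeleton = jax.tree_util.tree_map(lambda _: 0, params)
    with open(prefix + ".structure.pkl", "wb") as f:
        pickle.dump(skeleton, f)


def load_npz(prefix):
    with np.load(prefix + ".npz") as data:
        # np.savez names positional arrays arr_0, arr_1, ... in order.
        leaves = [data[f"arr_{i}"] for i in range(len(data.files))]
    with open(prefix + ".structure.pkl", "rb") as f:
        skeleton = pickle.load(f)
    # Rebuild the treedef from the skeleton and put the saved leaves back in.
    treedef = jax.tree_util.tree_structure(skeleton)
    return jax.tree_util.tree_unflatten(treedef, leaves)


if __name__ == "__main__":
    params = make_example_params()
    with tempfile.TemporaryDirectory() as tmp:
        save_npz(params, os.path.join(tmp, "params"))
        restored = load_npz(os.path.join(tmp, "params"))
    print(jax.tree_util.tree_structure(restored) == jax.tree_util.tree_structure(params))
```

The restored leaves come back as NumPy arrays. JAX functions accept them directly, or you can call `jax.tree_util.tree_map(jnp.asarray, restored)` if you want JAX arrays again.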
For more details and sample code for this approach, check out the thread over on Haiku: google-deepmind/dm-haiku#18
Another, slightly more complicated option is to use jax2tf to convert the model to TensorFlow and then save it as a SavedModel. This has the advantage of being hermetic: you don't need to keep the code that constructs the model intact in order to load it later.
In general, I would probably save the params as NumPy arrays during training, and only look into the SavedModel pipeline if I wanted a longer-term storage option for using the model on downstream tasks.