How to save/load a model? #1876
-
As the title says.
Beta Was this translation helpful? Give feedback.
Replies: 2 comments 3 replies
-
Hi @dev-sora, here is an example:

import jax.numpy as jnp
import jax
import flax
import flax.linen as nn
from flax.training import train_state, checkpoints
import optax
import numpy as np
class Net(nn.Module):
    """A stack of two Dense layers, each projecting to ``features`` units.

    No activation is applied between the layers, so the whole network is an
    affine map from the input's last axis to ``features`` outputs.
    """

    # Output width of both Dense layers.
    features: int

    @nn.compact
    def __call__(self, x):
        h = x
        # Under nn.compact, submodules created in a loop are auto-named in
        # creation order (Dense_0, Dense_1), exactly as two inline calls
        # would be, so the parameter tree layout is unchanged.
        for _ in range(2):
            h = nn.Dense(self.features)(h)
        return h
import os

model = Net(features=2)
# Initialise parameters from a dummy (1, 2) batch; its shape is what
# determines the shapes of the Dense kernels and biases.
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 2)))
tx = optax.adam(learning_rate=0.0001)
# Bundle the apply function, parameters and optimizer state into a single
# TrainState so the whole training state can be checkpointed as one pytree.
state = train_state.TrainState.create(apply_fn=model.apply,
                                      params=params,
                                      tx=tx)

# Newer Flax versions delegate checkpointing to Orbax, which rejects relative
# directories, so resolve the checkpoint directory to an absolute path.
CKPT_DIR = os.path.abspath('ckpts')
# overwrite=True keeps the example re-runnable: without it a second run fails
# because a checkpoint for step 0 already exists in CKPT_DIR.
checkpoints.save_checkpoint(ckpt_dir=CKPT_DIR, target=state, step=0,
                            overwrite=True)
# Passing `target` restores into the same TrainState structure instead of a
# raw nested state dict.
restored_state = checkpoints.restore_checkpoint(ckpt_dir=CKPT_DIR, target=state)

# Verify the round trip leaf by leaf. jax.tree_multimap was removed from JAX;
# jax.tree_util.tree_map accepts several trees and is the supported form.
assert jax.tree_util.tree_all(
    jax.tree_util.tree_map(lambda x, y: (x == y).all(),
                           state.params, restored_state.params))
Beta Was this translation helpful? Give feedback.
-
The |
Beta Was this translation helpful? Give feedback.
Hi @dev-sora, here is an example: