
How to save/load model? #1876

Answered by matthias-wright
dev-sora asked this question in Q&A

Hi @dev-sora, here is an example:

import jax.numpy as jnp
import jax
import flax
import flax.linen as nn
from flax.training import train_state, checkpoints
import optax
import numpy as np


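# A small MLP: two stacked Dense layers, each of width `features`
class Net(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(self.features)(x)
        x = nn.Dense(self.features)(x)
        return x


model = Net(features=2)
# Initialize the parameters by running the model once on a dummy (1, 2) input
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 2)))

# TrainState bundles apply_fn, the params, the Adam optimizer and its state
tx = optax.adam(learning_rate=0.0001)
state = train_state.TrainState.create(apply_fn=model.apply,
                                      params=params,
                                      tx=tx)


# Directory the checkpoint is written to
CKPT_DIR = 'ckpts'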
checkpoints.save_checkp…

Answer selected by dev-sora