I want to reinitialize the moments `mu` and `nu` of an Adam optimizer from time to time. The train state is stored using the `flax.training.train_state.TrainState` API. The important parts of the code look like this:
```python
import jax.numpy as jnp
import optax
from flax import traverse_util
from jax import tree_util

# This is inspired by the tutorial on
# [Surgery with Optimizers](https://flax.readthedocs.io/en/latest/guides/model_surgery.html#surgery-with-optimizers)
# which explains how to update the optimizer state.
def init_opt_params(params):
    # flatten
    flat_params = traverse_util.flatten_dict(params, sep="/")
    # modify: replace every leaf with zeros of the same shape and dtype
    flat_params = tree_util.tree_map(lambda x: jnp.zeros_like(x), flat_params)
    # unflatten (this returns a plain dict, not a FrozenDict)
    return traverse_util.unflatten_dict(flat_params, sep="/")


# TrainState is a subclass of flax.training.train_state.TrainState with an
# extra `metrics` field; Metrics is defined elsewhere.
def create_train_state(self, module, input_shape, rng, learning_rate):
    params = module.init(rng, jnp.ones(input_shape, jnp.float32))["params"]
    tx = optax.adam(learning_rate=learning_rate)
    opt_state = tx.init(params)
    return TrainState(
        apply_fn=module.apply,
        params=params,
        tx=tx,
        opt_state=opt_state,
        step=jnp.array(0),
        metrics=Metrics.empty(),
    )

...

# Then training goes like this:
for batch in dataloader:
    # does an Adam gradient descent update on the parameters of the model
    state = self.train_step(state, batch)
    if reinit_weight:
        # initialize moment estimates
        params_mu = init_opt_params(state.opt_state[0].mu)
        params_nu = init_opt_params(state.opt_state[0].nu)
        # The tuple built here is what later makes optax's chain fail (see EDIT).
        new_opt_state = (state.opt_state[0]._replace(
            mu=params_mu,
            nu=params_nu,
            )
        ) + state.opt_state[1:]
        state = state.replace(opt_state=new_opt_state)
```
I get the following error:

```
ValueError: The number of updates and states has to be the same in chain! Make sure you have called init first!
```
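The following sketch shows why `optax.chain` complains. It looks at what `+` does to a `NamedTuple`, using plain-Python stand-ins. `AdamState` and `EmptyState` are hypothetical mirrors of optax's `ScaleByAdamState` (fields `count`, `mu`, `nu`) and `EmptyState`, not the real classes. The sketch assumes, as the code above implies, that the `optax.adam` state is a 2-tuple whose first element holds the moments.

```python
from typing import Any, NamedTuple


# Hypothetical stand-ins for optax.ScaleByAdamState and optax.EmptyState.
class AdamState(NamedTuple):
    count: int
    mu: Any
    nu: Any


class EmptyState(NamedTuple):
    pass


# optax.adam chains two transformations, so its state holds two entries.
opt_state = (AdamState(count=3, mu="mu", nu="nu"), EmptyState())
new_adam = opt_state[0]._replace(mu="zeros", nu="zeros")

# Bug: a NamedTuple *is* a tuple, so `+` splices its three fields into the result.
wrong = new_adam + opt_state[1:]
# Fix: wrap the replaced state in a 1-tuple (note the trailing comma).
right = (new_adam,) + opt_state[1:]

print(len(wrong), len(right))  # 4 2
```

With the buggy version, the chain sees four "states" for two transformations, which is exactly the length check that raises the error above.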
In short: I want to update the module's parameters with the optimizer and then, based on that update, reset some entries of the two moment estimates in the optimizer state to zero.
SOLVED: See below
EDIT:
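- `new_opt_state` is not a tuple of states -> add a trailing comma to the first operand of the addition: `(state.opt_state[0]._replace(mu=params_mu, nu=params_nu),) + state.opt_state[1:]`
- The newly initialized moments need to be read-only, i.e. a `FrozenDict` like the original moments (`unflatten_dict` returns a plain `dict`) -> `flax.core.freeze(params_mu)`, and likewise for `params_nu`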