Using precomputed mean and var for BatchNorm #673
matthias-wright asked this question in Q&A (answered by avital)
Hi all, I am trying to load precomputed batch stats into a BatchNorm layer. From the source code, I gather that both the mean and the variance are registered as state. Here is what I tried:

```python
import flax.nn as nn
from jax import random

key = random.PRNGKey(0)
scale = random.normal(key, shape=(3,))
bias = random.normal(key, shape=(3,))
mean = random.normal(key, shape=(3,))
var = random.normal(key, shape=(3,))
x = random.normal(key, shape=(1, 5, 5, 3))
batch_stats = nn.Collection({'mean': mean, 'var': var}).mutate()
bn = nn.normalization.BatchNorm.partial(bias_init=lambda *_: bias,
                                        scale_init=lambda *_: scale,
                                        use_running_average=True,
                                        batch_stats=batch_stats)
x_out = bn.init(key, x)
```

The error message I get is:

Has anyone here tried this before? Thanks!
Answered by avital on Nov 26, 2020
Answer selected by matthias-wright
Hi @matthias-wright, you're using the deprecated `flax.nn` API. Please try using `flax.linen` instead, where "batch_stats" is simply another variable collection, just like "params".