Replies: 2 comments 2 replies
-
If you are new to JAX and Flax, I recommend starting with some beginner material before diving into transforms. The JAX 101 tutorial is pretty great, and for Flax the Flax Basics guide is a good place to start. In your case, as the error message says, you need to pass the keyword-only argument `train` to `init` as well, e.g. `LinenStatefulVmapMLP().init(random.PRNGKey(0), xs, train=True)`.
|
Beta Was this translation helpful? Give feedback.
-
There is a bit of an issue with this example:
import jax
from flax import linen as nn
from jax import lax, random, numpy as jnp
import flax
class StatefulMLP(nn.Module):
    """Two-layer MLP with batch normalization after the hidden Dense layer.

    Attributes:
        train: When True, BatchNorm normalizes with statistics computed from
            the current inputs (use_running_average=False); when False, it
            uses the stored running averages.
    """

    train: bool

    @nn.compact
    def __call__(self, x):
        h = nn.Dense(4, name='hidden')(x)
        # axis_name='batch' makes BatchNorm combine its statistics over a
        # named axis called 'batch'. Whatever transform wraps this module
        # must bind that name (e.g. nn.vmap(..., axis_name='batch')).
        h = nn.BatchNorm(axis_name='batch')(h, use_running_average=not self.train)
        h = nn.relu(h)
        return nn.Dense(1, name='out')(h)
class LinenStatefulVmapMLP(nn.Module):
    """Applies StatefulMLP to every row of ``xs`` using the lifted ``nn.vmap``.

    Both the ``params`` and ``batch_stats`` collections are vectorized along
    axis 0, so every mapped row gets its own copy of those variables. The
    ``params`` RNG is split so those copies are initialized independently.
    """

    @nn.compact
    def __call__(self, xs, *, train):
        VmapMLP = nn.vmap(
            StatefulMLP,
            variable_axes={'params': 0, 'batch_stats': 0},
            split_rngs={'params': True},
            in_axes=0,
            # Bind the axis name that BatchNorm(axis_name='batch') reduces
            # over. Without it, the name is unbound. Recent Flax versions
            # skip that reduction during init, so the problem would only
            # surface later in apply(..., train=True).
            axis_name='batch',
        )
        return VmapMLP(train=train, name='mlp')(xs)
# 32 rows of 4 values; nn.vmap(in_axes=0) maps over the leading axis.
xs = jnp.ones((32, 4))
# `train` is keyword-only in LinenStatefulVmapMLP.__call__, so it must be
# forwarded through init; omitting it raises
# "TypeError: __call__() missing 1 required keyword-only argument: 'train'".
variables = LinenStatefulVmapMLP().init(random.PRNGKey(0), xs, train=True)
Beta Was this translation helpful? Give feedback.
-
I am trying to build on the lifting example reported here, in particular the stateful one. However, when I run the example:
I get the following error:
Traceback (most recent call last):
File "/Users/cdalmaso/Library/Caches/pypoetry/virtualenvs/multi-task-YYPKsy_k-py3.8/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3378, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "", line 24, in <module>
variables = LinenStatefulVmapMLP().init(random.PRNGKey(0), xs)
File "/Users/cdalmaso/Library/Caches/pypoetry/virtualenvs/multi-task-YYPKsy_k-py3.8/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/Users/cdalmaso/Library/Caches/pypoetry/virtualenvs/multi-task-YYPKsy_k-py3.8/lib/python3.8/site-packages/flax/linen/module.py", line 1381, in init
_, v_out = self.init_with_output(
File "/Users/cdalmaso/Library/Caches/pypoetry/virtualenvs/multi-task-YYPKsy_k-py3.8/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/Users/cdalmaso/Library/Caches/pypoetry/virtualenvs/multi-task-YYPKsy_k-py3.8/lib/python3.8/site-packages/flax/linen/module.py", line 1335, in init_with_output
return init_with_output(
File "/Users/cdalmaso/Library/Caches/pypoetry/virtualenvs/multi-task-YYPKsy_k-py3.8/lib/python3.8/site-packages/flax/core/scope.py", line 897, in wrapper
return apply(fn, mutable=mutable, flags=init_flags)({}, *args, rngs=rngs,
File "/Users/cdalmaso/Library/Caches/pypoetry/virtualenvs/multi-task-YYPKsy_k-py3.8/lib/python3.8/site-packages/flax/core/scope.py", line 865, in wrapper
y = fn(root, *args, **kwargs)
File "/Users/cdalmaso/Library/Caches/pypoetry/virtualenvs/multi-task-YYPKsy_k-py3.8/lib/python3.8/site-packages/flax/linen/module.py", line 1798, in scope_fn
return fn(module.clone(parent=scope), *args, **kwargs)
File "/Users/cdalmaso/Library/Caches/pypoetry/virtualenvs/multi-task-YYPKsy_k-py3.8/lib/python3.8/site-packages/flax/linen/module.py", line 411, in wrapped_module_method
return self._call_wrapped_method(fun, args, kwargs)
File "/Users/cdalmaso/Library/Caches/pypoetry/virtualenvs/multi-task-YYPKsy_k-py3.8/lib/python3.8/site-packages/flax/linen/module.py", line 735, in _call_wrapped_method
y = fun(self, *args, **kwargs)
TypeError: __call__() missing 1 required keyword-only argument: 'train'
I am new to JAX, and this is an advanced topic, so I might be doing something stupid.
System information
Beta Was this translation helpful? Give feedback.
All reactions