Flax + Diffrax: Tracer error #2891
Replies: 10 comments 7 replies
-
My guess is that this code:
```py
# NOTE(review): `self.ode_term` is a bound Flax submodule. Calling it here only
# works inside Flax's own context; once Diffrax traces this function, Flax raises
# JaxTransformError ("Jax transforms and Flax models cannot be mixed").
def _odefun(t, y, args):
return self.ode_term(y)
```
only works when you're within a Flax context (which Diffrax doesn't do,
logically). Instead, you'll have to turn `self.ode_term` into a pure
function, e.g. `pure_fn = lambda y: self.ode_term.apply(self.variables, y)`,
before passing it in to Diffrax.
There may be better ways but I suspect this will work.
…On Mon, Feb 20, 2023 at 4:25 PM Filippo Vicentini ***@***.***> wrote:
I'm trying to use flax together with diffrax, but even for a simple
example I get a quite unreadable error.
Could you help me understand what is causing the error, and if it's
possible to make the two work together?
MWE:
import flax.linen as nn
import jax.numpy as jnp
import jax
import diffrax as dfx
from typing import Any
class NeuralODE(nn.Module):
    """Neural ODE whose vector field is a single ``nn.Dense`` layer.

    Minimal reproduction from the original report: the vector field calls the
    bound submodule from inside the Diffrax solve, which raises
    ``JaxTransformError`` during ``init``.
    """

    # NOTE(review): annotated as ``int`` but defaults to a tuple; unused below.
    hidden_dims: int = (4,)
    solver: Any = dfx.Euler()

    def setup(self):
        self.ode_term = nn.Dense(features=1)

    def __call__(self, x):
        def vector_field(t, y, args):
            # Bound-submodule call inside a JAX transform -> JaxTransformError.
            return self.ode_term(y)

        return dfx.diffeqsolve(
            dfx.ODETerm(vector_field),
            self.solver,
            t0=0,
            t1=1,
            dt0=0.1,
            y0=x,
        )


node = NeuralODE()
x = jnp.ones(1)
key = jax.random.PRNGKey(1)
pars = node.init(key, x)
Error:
---------------------------------------------------------------------------
JaxTransformError Traceback (most recent call last)
Input In [4], in <cell line: 26>()
24 x = jnp.ones(1)
25 key = jax.random.PRNGKey(1)
---> 26 pars = node.init(key, x)
[... skipping hidden 9 frame]
Input In [4], in NeuralODE.__call__(self, x)
16 return self.ode_term(y)
18 term = dfx.ODETerm(_odefun)
---> 19 solution = dfx.diffeqsolve(term, self.solver, t0=0, t1=1, dt0=0.1, y0=x)
20 return solution
File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/equinox/jit.py:82, in _JitWrapper.__call__(_JitWrapper__self, *args, **kwargs)
81 def __call__(__self, *args, **kwargs):
---> 82 return __self._fun_wrapper(False, args, kwargs)
File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/equinox/jit.py:78, in _JitWrapper._fun_wrapper(self, is_lower, args, kwargs)
76 return self._cached.lower(dynamic, static)
77 else:
---> 78 dynamic_out, static_out = self._cached(dynamic, static)
79 return combine(dynamic_out, static_out.value)
[... skipping hidden 11 frame]
File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/equinox/jit.py:30, in _filter_jit_cache.<locals>.fun_wrapped(dynamic, static)
28 fun = hashable_combine(dynamic_fun, static_fun)
29 args, kwargs = hashable_combine(dynamic_spec, static_spec)
---> 30 out = fun(*args, **kwargs)
31 dynamic_out, static_out = partition(out, filter_out)
32 return dynamic_out, Static(static_out)
File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/diffrax/integrate.py:858, in diffeqsolve(terms, solver, t0, t1, dt0, y0, args, saveat, stepsize_controller, adjoint, discrete_terminating_event, max_steps, throw, solver_state, controller_state, made_jump)
834 init_state = _State(
835 y=y0,
836 tprev=tprev,
(...)
851 dense_save_index=dense_save_index,
852 )
854 #
855 # Main loop
856 #
--> 858 final_state, aux_stats = adjoint.loop(
859 args=args,
860 terms=terms,
861 solver=solver,
862 stepsize_controller=stepsize_controller,
863 discrete_terminating_event=discrete_terminating_event,
864 saveat=saveat,
865 t0=t0,
866 t1=t1,
867 dt0=dt0,
868 max_steps=max_steps,
869 init_state=init_state,
870 throw=throw,
871 passed_solver_state=passed_solver_state,
872 passed_controller_state=passed_controller_state,
873 )
875 #
876 # Finish up
877 #
879 if saveat.t0 or saveat.t1 or saveat.steps or (saveat.ts is not None):
File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/diffrax/adjoint.py:132, in RecursiveCheckpointAdjoint.loop(***failed resolving arguments***)
130 def loop(self, *, throw, passed_solver_state, passed_controller_state, **kwargs):
131 del throw, passed_solver_state, passed_controller_state
--> 132 return self._loop_fn(**kwargs, is_bounded=True)
File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/diffrax/integrate.py:484, in loop(solver, stepsize_controller, discrete_terminating_event, saveat, t0, t1, dt0, max_steps, terms, args, init_state, is_bounded)
481 base = compiled_num_steps
482 max_steps = min(max_steps, compiled_num_steps)
--> 484 final_state = bounded_while_loop(
485 cond_fun, body_fun, init_state, max_steps, base=base
486 )
487 else:
488 compiled_num_steps = None
File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/diffrax/misc/bounded_while_loop.py:137, in bounded_while_loop(cond_fun, body_fun, init_val, max_steps, base)
135 init_data = (cond_fun(init_val), init_val, 0)
136 rounded_max_steps = base ** int(math.ceil(math.log(max_steps, base)))
--> 137 _, val, _ = _while_loop(_cond_fun, body_fun, init_data, rounded_max_steps, base)
138 return val
File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/diffrax/misc/bounded_while_loop.py:241, in _while_loop(cond_fun, body_fun, data, max_steps, base)
238 if max_steps != base:
239 _scan_fn = jax.checkpoint(_scan_fn, prevent_cse=False)
--> 241 return lax.scan(_scan_fn, data, xs=None, length=base)[0]
[... skipping hidden 16 frame]
File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/diffrax/misc/bounded_while_loop.py:235, in _while_loop.<locals>._scan_fn(_data, _)
233 _pred, _, _ = _data
234 _unvmap_pred = eqxi.unvmap_any(_pred)
--> 235 return lax.cond(_unvmap_pred, _call, lambda x: x, _data), None
[... skipping hidden 13 frame]
File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/diffrax/misc/bounded_while_loop.py:230, in _while_loop.<locals>._call(_data)
229 def _call(_data):
--> 230 return _while_loop(cond_fun, body_fun, _data, max_steps // base, base)
File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/diffrax/misc/bounded_while_loop.py:241, in _while_loop(cond_fun, body_fun, data, max_steps, base)
238 if max_steps != base:
239 _scan_fn = jax.checkpoint(_scan_fn, prevent_cse=False)
--> 241 return lax.scan(_scan_fn, data, xs=None, length=base)[0]
[... skipping hidden 16 frame]
File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/diffrax/misc/bounded_while_loop.py:235, in _while_loop.<locals>._scan_fn(_data, _)
233 _pred, _, _ = _data
234 _unvmap_pred = eqxi.unvmap_any(_pred)
--> 235 return lax.cond(_unvmap_pred, _call, lambda x: x, _data), None
[... skipping hidden 13 frame]
File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/diffrax/misc/bounded_while_loop.py:230, in _while_loop.<locals>._call(_data)
229 def _call(_data):
--> 230 return _while_loop(cond_fun, body_fun, _data, max_steps // base, base)
File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/diffrax/misc/bounded_while_loop.py:241, in _while_loop(cond_fun, body_fun, data, max_steps, base)
238 if max_steps != base:
239 _scan_fn = jax.checkpoint(_scan_fn, prevent_cse=False)
--> 241 return lax.scan(_scan_fn, data, xs=None, length=base)[0]
[... skipping hidden 9 frame]
File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/diffrax/misc/bounded_while_loop.py:235, in _while_loop.<locals>._scan_fn(_data, _)
233 _pred, _, _ = _data
234 _unvmap_pred = eqxi.unvmap_any(_pred)
--> 235 return lax.cond(_unvmap_pred, _call, lambda x: x, _data), None
[... skipping hidden 13 frame]
File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/diffrax/misc/bounded_while_loop.py:230, in _while_loop.<locals>._call(_data)
229 def _call(_data):
--> 230 return _while_loop(cond_fun, body_fun, _data, max_steps // base, base)
File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/diffrax/misc/bounded_while_loop.py:211, in _while_loop(cond_fun, body_fun, data, max_steps, base)
208 pred, val, step = data
210 inplace_update = _InplaceUpdate(pred)
--> 211 new_val = body_fun(val, inplace_update)
213 def _make_update(_new_val, _val):
214 if isinstance(_new_val, HadInplaceUpdate):
File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/diffrax/integrate.py:137, in loop.<locals>.body_fun(state, inplace)
130 def body_fun(state, inplace):
131
132 #
133 # Actually do some differential equation solving! Make numerical steps, adapt
134 # step sizes, all that jazz.
135 #
--> 137 (y, y_error, dense_info, solver_state, solver_result) = solver.step(
138 terms,
139 state.tprev,
140 state.tnext,
141 state.y,
142 args,
143 state.solver_state,
144 False if cannot_make_jump else state.made_jump,
145 )
147 # e.g. if someone has a sqrt(y) in the vector field, and dt0 is so large that
148 # we get a negative value for y, and then get a NaN vector field. (And then
149 # everything breaks.) See #143.
150 y_error = jtu.tree_map(lambda x: jnp.where(jnp.isnan(x), jnp.inf, x), y_error)
File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/diffrax/solver/euler.py:46, in Euler.step(***failed resolving arguments***)
44 del solver_state, made_jump
45 control = terms.contr(t0, t1)
---> 46 y1 = (y0**ω + terms.vf_prod(t0, y0, args, control) ** ω).ω
47 dense_info = dict(y0=y0, y1=y1)
48 return y1, None, dense_info, None, RESULTS.successful
File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/diffrax/term.py:138, in AbstractTerm.vf_prod(self, t, y, args, control)
94 def vf_prod(self, t: Scalar, y: PyTree, args: PyTree, control: PyTree) -> PyTree:
95 r"""The composition of [`diffrax.AbstractTerm.vf`][] and
96 [`diffrax.AbstractTerm.prod`][].
97
(...)
136 This function must be linear in `control`.
137 """
--> 138 return self.prod(self.vf(t, y, args), control)
File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/diffrax/term.py:364, in WrapTerm.vf(self, t, y, args)
362 def vf(self, t: Scalar, y: PyTree, args: PyTree) -> PyTree:
363 t = t * self.direction
--> 364 return self.term.vf(t, y, args)
File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/diffrax/term.py:173, in ODETerm.vf(self, t, y, args)
172 def vf(self, t: Scalar, y: PyTree, args: PyTree) -> PyTree:
--> 173 return self.vector_field(t, y, args)
Input In [4], in NeuralODE.__call__.<locals>._odefun(t, y, args)
15 def _odefun(t, y, args):
---> 16 return self.ode_term(y)
[... skipping hidden 2 frame]
File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/flax/linen/linear.py:187, in Dense.__call__(self, inputs)
177 @compact
178 def __call__(self, inputs: Array) -> Array:
179 """Applies a linear transformation to the inputs along the last dimension.
180
181 Args:
(...)
185 The transformed input.
186 """
--> 187 kernel = self.param('kernel',
188 self.kernel_init,
189 (jnp.shape(inputs)[-1], self.features),
190 self.param_dtype)
191 if self.use_bias:
192 bias = self.param('bias', self.bias_init, (self.features,),
193 self.param_dtype)
[... skipping hidden 4 frame]
File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/flax/core/tracers.py:36, in check_trace_level(base_level)
34 level = trace_level(current_trace())
35 if level != base_level:
---> 36 raise errors.JaxTransformError()
JaxTransformError: Jax transforms and Flax models cannot be mixed. (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.JaxTransformError)
—
Reply to this email directly, view it on GitHub
<#2887>, or unsubscribe
<https://github.com/notifications/unsubscribe-auth/AAAJFUWVOABL6ABAAC7JSQ3WYOEHXANCNFSM6AAAAAAVCAJE7U>
.
You are receiving this because you are subscribed to this thread. Message
ID: ***@***.***>
|
Beta Was this translation helpful? Give feedback.
-
Thank you.
Do you have some ideas... def __call__(self, x):
# NOTE(review): during `init` the parent's `self.variables` has no "params" yet,
# so `apply` fails with ScopeCollectionNotFound ("the collection is empty").
# `self.variables` is also the parent's variable tree rather than the Dense
# layer's own; presumably `self.ode_term.variables` was meant — TODO confirm.
_odefun = lambda t, y, args: self.ode_term.apply(self.variables, y)
term = dfx.ODETerm(_odefun)
solution = dfx.diffeqsolve(term, self.solver, t0=0, t1=1, dt0=0.1, y0=x)
return solution
node = NeuralODE()
x = jnp.ones(1)
key = jax.random.PRNGKey(1)
pars = node.init(key, x) Error: Show```python --------------------------------------------------------------------------- ScopeCollectionNotFound Traceback (most recent call last) Input In [63], in () 28 x = jnp.ones(1) 29 key = jax.random.PRNGKey(1) ---> 30 pars = node.init(key, x)
Input In [63], in NeuralODE.call(self, x) File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/equinox/jit.py:82, in _JitWrapper.call(_JitWrapper__self, *args, **kwargs) File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/equinox/jit.py:78, in _JitWrapper._fun_wrapper(self, is_lower, args, kwargs)
File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/equinox/jit.py:30, in _filter_jit_cache..fun_wrapped(dynamic, static) File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/diffrax/integrate.py:858, in diffeqsolve(terms, solver, t0, t1, dt0, y0, args, saveat, stepsize_controller, adjoint, discrete_terminating_event, max_steps, throw, solver_state, controller_state, made_jump) File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/diffrax/adjoint.py:132, in RecursiveCheckpointAdjoint.loop(failed resolving arguments) File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/diffrax/integrate.py:484, in loop(solver, stepsize_controller, discrete_terminating_event, saveat, t0, t1, dt0, max_steps, terms, args, init_state, is_bounded) File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/diffrax/misc/bounded_while_loop.py:137, in bounded_while_loop(cond_fun, body_fun, init_val, max_steps, base) File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/diffrax/misc/bounded_while_loop.py:241, in _while_loop(cond_fun, body_fun, data, max_steps, base)
File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/diffrax/misc/bounded_while_loop.py:235, in _while_loop.._scan_fn(_data, _)
File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/diffrax/misc/bounded_while_loop.py:230, in _while_loop.._call(_data) File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/diffrax/misc/bounded_while_loop.py:241, in _while_loop(cond_fun, body_fun, data, max_steps, base)
File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/diffrax/misc/bounded_while_loop.py:235, in _while_loop.._scan_fn(_data, _)
File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/diffrax/misc/bounded_while_loop.py:230, in _while_loop.._call(_data) File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/diffrax/misc/bounded_while_loop.py:241, in _while_loop(cond_fun, body_fun, data, max_steps, base)
File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/diffrax/misc/bounded_while_loop.py:235, in _while_loop.._scan_fn(_data, _)
File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/diffrax/misc/bounded_while_loop.py:230, in _while_loop.._call(_data) File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/diffrax/misc/bounded_while_loop.py:211, in _while_loop(cond_fun, body_fun, data, max_steps, base) File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/diffrax/integrate.py:137, in loop..body_fun(state, inplace) File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/diffrax/solver/euler.py:46, in Euler.step(failed resolving arguments) File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/diffrax/term.py:138, in AbstractTerm.vf_prod(self, t, y, args, control) File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/diffrax/term.py:364, in WrapTerm.vf(self, t, y, args) File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/diffrax/term.py:173, in ODETerm.vf(self, t, y, args) Input In [63], in NeuralODE.call..(t, y, args)
File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/flax/linen/linear.py:187, in Dense.call(self, inputs)
File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/flax/core/scope.py:815, in Scope.param(self, name, init_fn, unbox, *init_args) ScopeCollectionNotFound: Tried to access "kernel" from collection "params" in "/" but the collection is empty. (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ScopeCollectionNotFound)
|
Beta Was this translation helpful? Give feedback.
-
I understand why this is happening. |
Beta Was this translation helpful? Give feedback.
-
I'm trying to unbind the import flax.linen as nn
import jax.numpy as jnp
import jax
import diffrax as dfx
from typing import Any
class NeuralODE(nn.Module):
    """Attempt to pass the ``Dense`` variables to Diffrax through ``args``.

    NOTE(review): even with ``o`` replaced by ``ode_term_vars`` below, ``init``
    still fails with ``ScopeCollectionNotFound``. During ``init`` the variables
    returned by ``unbind()`` are empty (they print as ``FrozenDict({})``), so no
    ``kernel`` exists when the vector field runs.
    """

    solver: Any = dfx.Euler()

    def setup(self):
        self.ode_term = nn.Dense(features=1)

    def __call__(self, x):
        ma_unbinded, ode_term_vars = self.ode_term.unbind()
        print("pars are", ma_unbinded)
        # Print the variables returned by unbind(); the snippet previously
        # referenced an undefined name `o` here (NameError).
        print("o are", ode_term_vars)

        def _odefun(t, y, args):
            print("args is", args)
            # Re-attach the pure module to the variables Diffrax threads in.
            ode_term = ma_unbinded.bind(args)
            return ode_term.apply(self.variables, y)

        term = dfx.ODETerm(_odefun)
        solution = dfx.diffeqsolve(
            term, self.solver, t0=0, t1=1, dt0=0.1, y0=x, args=ode_term_vars
        )
        return solution


node = NeuralODE()
print(hash(node))
x = jnp.ones(1)
key = jax.random.PRNGKey(1)
pars = node.init(key, x) gives the error: Show
pars are Dense(
# attributes
features = 1
use_bias = True
dtype = None
param_dtype = float32
precision = None
kernel_init = init
bias_init = zeros
)
o are FrozenDict({})
args is FrozenDict({})
---------------------------------------------------------------------------
ScopeCollectionNotFound Traceback (most recent call last)
Input In [94], in <cell line: 31>()
29 x = jnp.ones(1)
30 key = jax.random.PRNGKey(1)
---> 31 pars = node.init(key, x)
[... skipping hidden 9 frame]
Input In [94], in NeuralODE.__call__(self, x)
20 return ode_term.apply(self.variables, y)
22 term = dfx.ODETerm(_odefun)
---> 23 solution = dfx.diffeqsolve(term, self.solver, t0=0, t1=1, dt0=0.1, y0=x, args=ode_term_vars)
25 return solution
File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/equinox/jit.py:82, in _JitWrapper.__call__(_JitWrapper__self, *args, **kwargs)
81 def __call__(__self, *args, **kwargs):
---> 82 return __self._fun_wrapper(False, args, kwargs)
File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/equinox/jit.py:78, in _JitWrapper._fun_wrapper(self, is_lower, args, kwargs)
76 return self._cached.lower(dynamic, static)
77 else:
---> 78 dynamic_out, static_out = self._cached(dynamic, static)
79 return combine(dynamic_out, static_out.value)
[... skipping hidden 11 frame]
File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/equinox/jit.py:30, in _filter_jit_cache.<locals>.fun_wrapped(dynamic, static)
28 fun = hashable_combine(dynamic_fun, static_fun)
29 args, kwargs = hashable_combine(dynamic_spec, static_spec)
---> 30 out = fun(*args, **kwargs)
31 dynamic_out, static_out = partition(out, filter_out)
32 return dynamic_out, Static(static_out)
File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/diffrax/integrate.py:858, in diffeqsolve(terms, solver, t0, t1, dt0, y0, args, saveat, stepsize_controller, adjoint, discrete_terminating_event, max_steps, throw, solver_state, controller_state, made_jump)
834 init_state = _State(
835 y=y0,
836 tprev=tprev,
(...)
851 dense_save_index=dense_save_index,
852 )
854 #
855 # Main loop
856 #
--> 858 final_state, aux_stats = adjoint.loop(
859 args=args,
860 terms=terms,
861 solver=solver,
862 stepsize_controller=stepsize_controller,
863 discrete_terminating_event=discrete_terminating_event,
864 saveat=saveat,
865 t0=t0,
866 t1=t1,
867 dt0=dt0,
868 max_steps=max_steps,
869 init_state=init_state,
870 throw=throw,
871 passed_solver_state=passed_solver_state,
872 passed_controller_state=passed_controller_state,
873 )
875 #
876 # Finish up
877 #
879 if saveat.t0 or saveat.t1 or saveat.steps or (saveat.ts is not None):
File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/diffrax/adjoint.py:132, in RecursiveCheckpointAdjoint.loop(***failed resolving arguments***)
130 def loop(self, *, throw, passed_solver_state, passed_controller_state, **kwargs):
131 del throw, passed_solver_state, passed_controller_state
--> 132 return self._loop_fn(**kwargs, is_bounded=True)
File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/diffrax/integrate.py:484, in loop(solver, stepsize_controller, discrete_terminating_event, saveat, t0, t1, dt0, max_steps, terms, args, init_state, is_bounded)
481 base = compiled_num_steps
482 max_steps = min(max_steps, compiled_num_steps)
--> 484 final_state = bounded_while_loop(
485 cond_fun, body_fun, init_state, max_steps, base=base
486 )
487 else:
488 compiled_num_steps = None
File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/diffrax/misc/bounded_while_loop.py:137, in bounded_while_loop(cond_fun, body_fun, init_val, max_steps, base)
135 init_data = (cond_fun(init_val), init_val, 0)
136 rounded_max_steps = base ** int(math.ceil(math.log(max_steps, base)))
--> 137 _, val, _ = _while_loop(_cond_fun, body_fun, init_data, rounded_max_steps, base)
138 return val
File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/diffrax/misc/bounded_while_loop.py:241, in _while_loop(cond_fun, body_fun, data, max_steps, base)
238 if max_steps != base:
239 _scan_fn = jax.checkpoint(_scan_fn, prevent_cse=False)
--> 241 return lax.scan(_scan_fn, data, xs=None, length=base)[0]
[... skipping hidden 16 frame]
File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/diffrax/misc/bounded_while_loop.py:235, in _while_loop.<locals>._scan_fn(_data, _)
233 _pred, _, _ = _data
234 _unvmap_pred = eqxi.unvmap_any(_pred)
--> 235 return lax.cond(_unvmap_pred, _call, lambda x: x, _data), None
[... skipping hidden 13 frame]
File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/diffrax/misc/bounded_while_loop.py:230, in _while_loop.<locals>._call(_data)
229 def _call(_data):
--> 230 return _while_loop(cond_fun, body_fun, _data, max_steps // base, base)
File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/diffrax/misc/bounded_while_loop.py:241, in _while_loop(cond_fun, body_fun, data, max_steps, base)
238 if max_steps != base:
239 _scan_fn = jax.checkpoint(_scan_fn, prevent_cse=False)
--> 241 return lax.scan(_scan_fn, data, xs=None, length=base)[0]
[... skipping hidden 16 frame]
File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/diffrax/misc/bounded_while_loop.py:235, in _while_loop.<locals>._scan_fn(_data, _)
233 _pred, _, _ = _data
234 _unvmap_pred = eqxi.unvmap_any(_pred)
--> 235 return lax.cond(_unvmap_pred, _call, lambda x: x, _data), None
[... skipping hidden 13 frame]
File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/diffrax/misc/bounded_while_loop.py:230, in _while_loop.<locals>._call(_data)
229 def _call(_data):
--> 230 return _while_loop(cond_fun, body_fun, _data, max_steps // base, base)
File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/diffrax/misc/bounded_while_loop.py:241, in _while_loop(cond_fun, body_fun, data, max_steps, base)
238 if max_steps != base:
239 _scan_fn = jax.checkpoint(_scan_fn, prevent_cse=False)
--> 241 return lax.scan(_scan_fn, data, xs=None, length=base)[0]
[... skipping hidden 9 frame]
File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/diffrax/misc/bounded_while_loop.py:235, in _while_loop.<locals>._scan_fn(_data, _)
233 _pred, _, _ = _data
234 _unvmap_pred = eqxi.unvmap_any(_pred)
--> 235 return lax.cond(_unvmap_pred, _call, lambda x: x, _data), None
[... skipping hidden 13 frame]
File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/diffrax/misc/bounded_while_loop.py:230, in _while_loop.<locals>._call(_data)
229 def _call(_data):
--> 230 return _while_loop(cond_fun, body_fun, _data, max_steps // base, base)
File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/diffrax/misc/bounded_while_loop.py:211, in _while_loop(cond_fun, body_fun, data, max_steps, base)
208 pred, val, step = data
210 inplace_update = _InplaceUpdate(pred)
--> 211 new_val = body_fun(val, inplace_update)
213 def _make_update(_new_val, _val):
214 if isinstance(_new_val, HadInplaceUpdate):
File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/diffrax/integrate.py:137, in loop.<locals>.body_fun(state, inplace)
130 def body_fun(state, inplace):
131
132 #
133 # Actually do some differential equation solving! Make numerical steps, adapt
134 # step sizes, all that jazz.
135 #
--> 137 (y, y_error, dense_info, solver_state, solver_result) = solver.step(
138 terms,
139 state.tprev,
140 state.tnext,
141 state.y,
142 args,
143 state.solver_state,
144 False if cannot_make_jump else state.made_jump,
145 )
147 # e.g. if someone has a sqrt(y) in the vector field, and dt0 is so large that
148 # we get a negative value for y, and then get a NaN vector field. (And then
149 # everything breaks.) See #143.
150 y_error = jtu.tree_map(lambda x: jnp.where(jnp.isnan(x), jnp.inf, x), y_error)
File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/diffrax/solver/euler.py:46, in Euler.step(***failed resolving arguments***)
44 del solver_state, made_jump
45 control = terms.contr(t0, t1)
---> 46 y1 = (y0**ω + terms.vf_prod(t0, y0, args, control) ** ω).ω
47 dense_info = dict(y0=y0, y1=y1)
48 return y1, None, dense_info, None, RESULTS.successful
File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/diffrax/term.py:138, in AbstractTerm.vf_prod(self, t, y, args, control)
94 def vf_prod(self, t: Scalar, y: PyTree, args: PyTree, control: PyTree) -> PyTree:
95 r"""The composition of [`diffrax.AbstractTerm.vf`][] and
96 [`diffrax.AbstractTerm.prod`][].
97
(...)
136 This function must be linear in `control`.
137 """
--> 138 return self.prod(self.vf(t, y, args), control)
File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/diffrax/term.py:364, in WrapTerm.vf(self, t, y, args)
362 def vf(self, t: Scalar, y: PyTree, args: PyTree) -> PyTree:
363 t = t * self.direction
--> 364 return self.term.vf(t, y, args)
File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/diffrax/term.py:173, in ODETerm.vf(self, t, y, args)
172 def vf(self, t: Scalar, y: PyTree, args: PyTree) -> PyTree:
--> 173 return self.vector_field(t, y, args)
Input In [94], in NeuralODE.__call__.<locals>._odefun(t, y, args)
18 print("args is", args)
19 ode_term = ma_unbinded.bind(args)
---> 20 return ode_term.apply(self.variables, y)
[... skipping hidden 6 frame]
File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/flax/linen/linear.py:187, in Dense.__call__(self, inputs)
177 @compact
178 def __call__(self, inputs: Array) -> Array:
179 """Applies a linear transformation to the inputs along the last dimension.
180
181 Args:
(...)
185 The transformed input.
186 """
--> 187 kernel = self.param('kernel',
188 self.kernel_init,
189 (jnp.shape(inputs)[-1], self.features),
190 self.param_dtype)
191 if self.use_bias:
192 bias = self.param('bias', self.bias_init, (self.features,),
193 self.param_dtype)
[... skipping hidden 1 frame]
File ~/Documents/pythonenvs/netket/python-3.10.6/lib/python3.10/site-packages/flax/core/scope.py:815, in Scope.param(self, name, init_fn, unbox, *init_args)
813 if not self.is_mutable_collection('params'):
814 if self.is_collection_empty('params'):
--> 815 raise errors.ScopeCollectionNotFound('params', name, self.path_text)
816 raise errors.ScopeParamNotFoundError(name, self.path_text)
817 value = init_fn(self.make_rng('params'), *init_args)
ScopeCollectionNotFound: Tried to access "kernel" from collection "params" in "/" but the collection is empty. (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ScopeCollectionNotFound)
</p>
</details> |
Beta Was this translation helpful? Give feedback.
-
AHAH! Found it. I have a MWE not depending on Diffrax import flax.linen as nn
import jax.numpy as jnp
import jax
import diffrax as dfx
from typing import Any
class Test(nn.Module):
    """Diffrax-free reproduction: unbind a submodule, re-bind it, then call it.

    The reporter attributes the failure to the initialisation step; presumably
    the variables returned by ``unbind()`` are still empty during ``init`` —
    TODO confirm.
    """

    def setup(self):
        self.ode_term = nn.Dense(features=1)

    def __call__(self, x):
        unbound, variables = self.ode_term.unbind()
        print("pars are", unbound)
        print("o are", variables)
        rebound = unbound.bind(variables)
        print("ma_binded is", rebound)
        return rebound(x)


node = Test()
x = jnp.ones(1)
key = jax.random.PRNGKey(1)
pars = node.init(key, x) |
Beta Was this translation helpful? Give feedback.
-
Ah, yes. Then just check whether you're in init or not:
https://flax.readthedocs.io/en/latest/_modules/flax/linen/module.html#Module.is_initializing
…On Mon, Feb 20, 2023 at 5:04 PM Filippo Vicentini ***@***.***> wrote:
AHAH! Found it.
It's an error in the initialisation.
I have a MWE not depending on Diffrax
import flax.linen as nn
import jax.numpy as jnp
import jax
import diffrax as dfx
from typing import Any


class Test(nn.Module):
    """Diffrax-free reproduction: unbind a submodule, re-bind it, then call it.

    The reporter attributes the failure to the initialisation step; presumably
    the variables returned by ``unbind()`` are still empty during ``init`` —
    TODO confirm.
    """

    def setup(self):
        self.ode_term = nn.Dense(features=1)

    def __call__(self, x):
        ma_unbinded, ode_term_vars = self.ode_term.unbind()
        print("pars are", ma_unbinded)
        print("o are", ode_term_vars)
        ma_binded = ma_unbinded.bind(ode_term_vars)
        print("ma_binded is", ma_binded)
        solution = ma_binded(x)
        return solution


# The class defined in this snippet is `Test`; `NeuralODE` is not defined here.
node = Test()
x = jnp.ones(1)
key = jax.random.PRNGKey(1)
pars = node.init(key, x)
—
Reply to this email directly, view it on GitHub
<#2887 (comment)>, or
unsubscribe
<https://github.com/notifications/unsubscribe-auth/AAAJFUXUZ2MOYTOIVQZZBBDWYOI2DANCNFSM6AAAAAAVCAJE7U>
.
You are receiving this because you commented.Message ID:
***@***.***>
|
Beta Was this translation helpful? Give feedback.
-
Yep, i figured that... I managed to make this work, but the solution is somewhat 'ugly' in my opinion because you have to "break through" the flax abstraction and manually add the parameters using import flax.linen as nn
import jax.numpy as jnp
import jax
from typing import Any
class Test(nn.Module):
    """Workaround: during ``init``, create the submodule's params manually.

    The ``Dense`` layer is initialised explicitly and its params are written
    into the parent's scope under ``"ode_term"``. The (possibly freshly
    created) variables are then bound to the unbound module copy.
    """

    def setup(self):
        self.ode_term = nn.Dense(features=1)

    def __call__(self, x):
        unbound, variables = self.ode_term.unbind()
        if self.is_initializing():
            variables = self.ode_term.init(self.make_rng("params"), x)
            self.scope.put_variable("params", "ode_term", variables["params"])
        return unbound.bind(variables)(x)


ma = Test()
x = jnp.ones(1)
key = jax.random.PRNGKey(1)
pars = ma.init(key, x)
print("::::::::::::::::::::::::::::::")
print("pars are:", pars)
res = ma.apply(pars, x) |
Beta Was this translation helpful? Give feedback.
-
Oh no that shouldn't be necessary, I think!
…On Mon, Feb 20, 2023 at 8:06 PM Filippo Vicentini ***@***.***> wrote:
Yep, I figured that... I managed to make this work, but the solution is
somewhat 'ugly' in my opinion because you have to "break through" the Flax
abstraction and manually add the parameters using `put_variable`.
import flax.linen as nn
import jax.numpy as jnp
import jax
from typing import Any


class Test(nn.Module):
    """Workaround: during ``init``, create the submodule's params manually."""

    def setup(self):
        self.ode_term = nn.Dense(features=1)

    def __call__(self, x):
        ma_unbinded, ode_term_vars = self.ode_term.unbind()
        if self.is_initializing():
            # Initialise the Dense layer explicitly and store its params under
            # the parent's "ode_term" entry so they end up in the param tree.
            ode_term_vars = self.ode_term.init(self.make_rng("params"), x)
            self.scope.put_variable("params", "ode_term", ode_term_vars["params"])
        ma_binded = ma_unbinded.bind(ode_term_vars)
        solution = ma_binded(x)
        return solution


ma = Test()
x = jnp.ones(1)
key = jax.random.PRNGKey(1)
pars = ma.init(key, x)
print("::::::::::::::::::::::::::::::")
print("pars are:", pars)
res = ma.apply(pars, x)
—
Reply to this email directly, view it on GitHub
<#2887 (comment)>, or
unsubscribe
<https://github.com/notifications/unsubscribe-auth/AAAJFUUY7OHVI3SNJZ2IAR3WYO6FDANCNFSM6AAAAAAVCAJE7U>
.
You are receiving this because you commented.Message ID:
***@***.***>
|
Beta Was this translation helpful? Give feedback.
-
Then... do you have a better idea? import flax.linen as nn
import jax.numpy as jnp
import jax
from typing import Any
import diffrax as dfx
class Test(nn.Module):
    """Diffrax attempt: thread the ``Dense`` variables through ``args``.

    During ``init`` the layer's params are created by hand and stored in the
    parent's scope. Inside the vector field the unbound module copy is re-bound
    to whatever variables Diffrax passes back in as ``args``.
    """

    def setup(self):
        self.ode_term = nn.Dense(features=1)

    def __call__(self, x):
        unbound, variables = self.ode_term.unbind()
        if self.is_initializing():
            variables = self.ode_term.init(self.make_rng("params"), x)
            self.scope.put_variable("params", "ode_term", variables["params"])

        def odefun(t, y, args):
            return unbound.bind(args)(y)

        return dfx.diffeqsolve(
            dfx.ODETerm(odefun),
            dfx.Euler(),
            t0=0,
            t1=1,
            dt0=0.1,
            y0=x,
            args=variables,
        )


ma = Test()
x = jnp.ones(1)
key = jax.random.PRNGKey(1)
pars = ma.init(key, x)
print("::::::::::::::::::::::::::::::")
print("pars are:", pars)
res = ma.apply(pars, x) What am I missing? |
Beta Was this translation helpful? Give feedback.
-
Hey @PhilipVinc, I cleaned your example a bit: import flax.linen as nn
import jax.numpy as jnp
import jax
from typing import Any
import diffrax as dfx
class Test(nn.Module):
    """Integrate a ``Dense`` vector field with Diffrax from inside a module.

    The submodule's ``params`` are read once and handed to Diffrax as ``args``.
    The vector field evaluates the submodule via ``apply`` with those params
    passed explicitly, rather than reading them from the enclosing scope.
    """

    def setup(self):
        self.ode_term = nn.Dense(features=1)

    def __call__(self, x):
        # One call during init is what creates the Dense parameters.
        if self.is_initializing():
            self.ode_term(x)
        dense_params = self.ode_term.variables['params']

        def vector_field(t, y, p):
            return self.ode_term.apply({'params': p}, y)

        term = dfx.ODETerm(vector_field)
        return dfx.diffeqsolve(
            term, dfx.Euler(), t0=0, t1=1, dt0=0.1, y0=x, args=dense_params
        )


ma = Test()
x = jnp.ones(1)
key = jax.random.PRNGKey(1)
pars = ma.init(key, x)
print("::::::::::::::::::::::::::::::")
print("pars are:", pars)
res = ma.apply(pars, x) That said, if possible consider performing import flax.linen as nn
import jax.numpy as jnp
import jax
from typing import Any
import diffrax as dfx
class Test(nn.Module):
    """Plain Flax module (a single ``Dense`` layer) with no Diffrax inside."""

    def setup(self):
        self.ode_term = nn.Dense(features=1)

    def __call__(self, x):
        return self.ode_term(x)


ma = Test()
x = jnp.ones(1)
key = jax.random.PRNGKey(1)
pars = ma.init(key, x)


def odefun(t, y, params):
    """Diffrax vector field: evaluate ``ma`` with explicitly supplied params."""
    return ma.apply({'params': params}, y)


# The ODE solve happens outside the module, on already-initialised params.
solution = dfx.diffeqsolve(
    dfx.ODETerm(odefun),
    dfx.Euler(),
    t0=0,
    t1=1,
    dt0=0.1,
    y0=x,
    args=pars['params'],
)
print("::::::::::::::::::::::::::::::")
print("pars are:", pars)
print(f"Solution is: {solution}") |
Beta Was this translation helpful? Give feedback.
-
I'm trying to use `flax` together with `diffrax`, but even for a simple example I get a rather unreadable error. Could you help me understand what is causing the error, and if it's possible to make the two work together?
MWE:
Error:
Show
Beta Was this translation helpful? Give feedback.
All reactions