Just wanted to showcase a common gotcha (originally reported by @lucasb-eyer and answered by @hawkinsp). This should probably be documented somewhere else, and tooling that @jheek is working on can help with an explicit warning to users.
If you jit a function that closes over a jnp.array, then that array's value is compiled into the jitted function as a constant when the function is traced, and the link to the Python variable is lost. This implies two things:
- Changes to the variable (for example, rebinding it to a new array) will not be reflected in subsequent calls to the jitted function.
- Memory will be duplicated: the data lives both in the compiled program and in the original jnp.array.
In general, the workaround is to take in any arrays that are either large or may change as arguments to the jitted function, rather than closing over them.
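To make this concrete, here is a minimal sketch of the gotcha and the workaround. It is not the snippet from the original report; the names `scale`, `scaled_closure`, and `scaled_arg` are made up for illustration. The closure version is called once so that it gets traced and compiled, then the Python variable is rebound to a new array.

```python
import jax
import jax.numpy as jnp

# Hypothetical array that the jitted function closes over.
scale = jnp.array([1.0, 2.0, 3.0])

@jax.jit
def scaled_closure(x):
    # `scale` is looked up only while tracing; its value becomes a constant.
    return x * scale

@jax.jit
def scaled_arg(x, scale):
    # Workaround: `scale` is an explicit argument, so each call sees the
    # array that is actually passed in.
    return x * scale

x = jnp.ones(3)

first = scaled_closure(x)               # first call: trace and compile
scale = jnp.array([10.0, 20.0, 30.0])   # rebind the Python variable
stale = scaled_closure(x)               # cache hit: still uses the old values
fresh = scaled_arg(x, scale)            # uses the new values

print(stale)
print(fresh)
```

Since `x` has the same shape and dtype on both calls, the second call to `scaled_closure` reuses the cached program and never looks at the rebound `scale`, while `scaled_arg` picks up the new array because it arrives as an argument.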
(cross-posted as jax-ml/jax#6370)
For example, this small code snippet:
will print: