I believe the issue is not with pytrees per se, but rather with defining hundreds of array arguments, and compiling functions which take hundreds of array arguments. Generally speaking, compilation costs should not scale with the size of the individual arrays being operated on, but we do expect them to scale with the number of array objects being passed to the function. In the best case, you might achieve linear scaling – but I suspect in reality you'll see closer to quadratic scaling with the number of array inputs. This is not unexpected, because it will generally lead to much larger programs which require much more logic to optimize.
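One way to see the effect described above is to count how many array inputs a traced function receives. The sketch below assumes a hypothetical state of 100 scalar fields stored in a dict. Every leaf becomes a separate input of the jaxpr, whereas packing the leaves into one vector with `jax.flatten_util.ravel_pytree` leaves a single input. The names `state`, `step` and `step_flat` are illustrative only.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical state: 100 scalar fields in a dict, i.e. 100 separate leaves.
state = {f"x{i}": jnp.asarray(i, dtype=jnp.float32) for i in range(100)}

def step(s):
    # Element-wise update applied to every leaf of the pytree.
    return jax.tree_util.tree_map(lambda v: 0.9 * v + 1.0, s)

# Each leaf becomes its own input variable of the traced program.
n_inputs_tree = len(jax.make_jaxpr(step)(state).jaxpr.invars)

# Pack all leaves into one flat array; `unravel` rebuilds the dict from it.
flat, unravel = ravel_pytree(state)

def step_flat(x):
    # Same update, but the function boundary only sees one array.
    return ravel_pytree(step(unravel(x)))[0]

n_inputs_flat = len(jax.make_jaxpr(step_flat)(flat).jaxpr.invars)
print(n_inputs_tree, n_inputs_flat)  # 100 1
```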
I am working with discrete dynamical systems that depend on some parameters that are to be trained. The key point is that I want to express the state of the dynamical system not as an array, but as a `Dict` (or more generally as a PyTree) so that I can write e.g. `state["position"]` instead of the much less readable `state[idx]`.
In the following minimal example, I demonstrate the issues I am running into when trying to do this using generic tools from `jax.tree_util` that act on a `state` PyTree. The compilation takes a lot of time and the resulting (pre-optimised) HLO files are up to 26 MBytes!

Implementation with PyTrees
Resulting XLA dump for the above
For comparison, an equivalent version of the above example that only acts on arrays results in (pre-optimised) HLO files of up to 50 KBytes.
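For contrast, hypothetical dynamics written in the array-only style keep every field in one flat array and address it through index constants. `POS`, `VEL` and `dynamics_array` are illustrative names, assuming three position and three velocity components:

```python
import jax
import jax.numpy as jnp

# Hypothetical index layout: each field lives in a fixed slice of one array.
POS = slice(0, 3)
VEL = slice(3, 6)

def dynamics_array(x, dt=0.1, damping=0.5):
    # Fields are read by position, which is compact but less readable.
    pos, vel = x[POS], x[VEL]
    new_vel = vel - damping * vel * dt
    new_pos = pos + new_vel * dt
    # Reassemble the single state array in the same layout.
    return jnp.concatenate([new_pos, new_vel])

x0 = jnp.concatenate([jnp.zeros(3), jnp.ones(3)])
x1 = jax.jit(dynamics_array)(x0)
```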
Fully vectorised code (no pytrees/dicts involved)
Resulting XLA dump for the above
Finally, as a comparison, I considered unflattening the state arrays into a dict and flattening it back inside every call of the dynamics function. I thought that this might be inefficient due to this comment. However, it performs better than the PyTree version, resulting in (pre-optimised) HLO files of up to 1.1 MBytes.
Explicitly convert from arrays to dicts and back at the innermost function
Resulting XLA dump for the above
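One way to package this approach, sketched under assumptions: a hypothetical helper `as_array_fn` wraps a dict-based dynamics function so that the outer loop only sees one flat array, using `jax.flatten_util.ravel_pytree` to convert at the innermost call. The state fields, `params` and `rollout` are illustrative, and the sketch makes no claim about the resulting HLO sizes.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

def as_array_fn(dict_fn, template):
    """Hypothetical helper: wrap a (dict, params) -> dict function as array -> array."""
    _, unravel = ravel_pytree(template)
    def array_fn(x, params):
        new_state = dict_fn(unravel(x), params)  # readable dict code inside
        return ravel_pytree(new_state)[0]        # back to one flat array
    return array_fn

def dynamics(state, params):
    pos, vel = state["position"], state["velocity"]
    new_vel = vel * (1.0 - params["damping"] * params["dt"])
    return {"position": pos + new_vel * params["dt"], "velocity": new_vel}

state0 = {"position": jnp.zeros(3), "velocity": jnp.ones(3)}
params = {"dt": 0.1, "damping": 0.5}
step = as_array_fn(dynamics, state0)

@jax.jit
def rollout(x, params):
    # The scan carry is a single array, whatever the number of dict fields.
    def body(x, _):
        return step(x, params), None
    x_final, _ = jax.lax.scan(body, x, None, length=10)
    return x_final

x0, unravel = ravel_pytree(state0)
final_state = unravel(rollout(x0, params))
```

The dict only exists inside `step`; the loop carry and the jitted function's state argument are each one array, so the number of array inputs at those boundaries does not grow with the number of fields.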
So my question is: what is the best way to write or wrap a function that acts on a PyTree without incurring prohibitively large compile times?