Slow compilation for functions acting on PyTrees #4667

Open
nrontsis opened this issue Oct 21, 2020 · 2 comments

nrontsis commented Oct 21, 2020

I am working with discrete dynamical systems that depend on some parameters that are to be trained. The key point is that I want to express the state of the dynamical system not as an array, but as a Dict (or more generally as a PyTree) so that I can write e.g. state["position"] instead of the much less readable state[idx].

In the following minimal example, I demonstrate the issues I am running into when trying to do this using generic tools from jax.tree_util that act on a state PyTree. The compilation takes a lot of time and the resulting (pre-optimised) HLO files are up to 26 MBytes!
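
(A note on methodology: the "Resulting XLA dump" links below can presumably be reproduced via XLA's dump facility, assuming the default XLA backend. A minimal sketch; the dump directory is an arbitrary choice.)

import os

# Ask XLA to write every HLO module it compiles (including the
# pre-optimisation versions) to the given directory. This must be set
# before jax/XLA is first imported.
os.environ["XLA_FLAGS"] = "--xla_dump_to=/tmp/xla_dump"

import jax  # noqa: E402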

Implementation with PyTrees
import operator
from functools import partial

import jax.numpy as np
from jax.tree_util import tree_multimap, tree_map, tree_reduce, tree_flatten
from jax import jit, grad
from jax.lax import scan


@partial(jit, static_argnums=(0,))
def sum_of_squares_loss(dynamics: callable, states, parameters):
    initial_states = tree_map(lambda s: s[0], states)
    horizon_length = tree_flatten(states)[0][0].shape[0]
    predictions = propagate_dynamics(dynamics, initial_states, horizon_length, parameters)
    errors = tree_multimap(lambda s, p: np.sum((s - p)**2, keepdims=True), states, predictions)
    return tree_reduce(operator.add, errors)[0]  # binary add, not the builtin sum


sum_of_squares_loss_gradient = jit(grad(sum_of_squares_loss, argnums=2), static_argnums=(0,))


def propagate_dynamics(dynamics: callable, initial_states, horizon_length, parameters):
    return scan(
        f=lambda state, _: (dynamics(state, parameters), state),
        init=initial_states,
        xs=None,
        length=horizon_length
    )[1]


### Example call
PROPAGATION_HORIZON_LENGTH = 200
STATE_DIMENSION = 100
STATE_NAMES = ["state_" + str(i) for i in range(STATE_DIMENSION)]


def example_dynamics(states, parameters):
    return tree_multimap(lambda s, p: 0.999*s + 1e-3*p, states, parameters)


states = {name: np.ones(PROPAGATION_HORIZON_LENGTH) for name in STATE_NAMES}
parameters = {name: 1.0 for name in STATE_NAMES}

# Compile functions
sum_of_squares_loss(example_dynamics, states, parameters)
sum_of_squares_loss_gradient(example_dynamics, states, parameters)

Resulting XLA dump for the above

For comparison, an equivalent version of the above example that acts only on arrays results in (pre-optimised) HLO files of up to 50 KBytes.

Fully vectorised code (no pytrees/dicts involved)
from functools import partial

import jax.numpy as np
from jax import jit, grad
from jax.lax import scan


@partial(jit, static_argnums=(0,))
def sum_of_squares_loss(dynamics: callable, states, parameters):
    initial_states = states[0]
    horizon_length = states.shape[0]
    predictions = propagate_dynamics(dynamics, initial_states, horizon_length, parameters)
    return np.sum((predictions - states)**2)


sum_of_squares_loss_gradient = jit(grad(sum_of_squares_loss, argnums=2), static_argnums=(0,))


def propagate_dynamics(dynamics: callable, initial_states, horizon_length, parameters):
    return scan(
        f=lambda state, _: (dynamics(state, parameters), state),
        init=initial_states,
        xs=None,
        length=horizon_length
    )[1]


### Example call
PROPAGATION_HORIZON_LENGTH = 200
STATE_DIMENSION = 100
STATE_NAMES = ["state_" + str(i) for i in range(STATE_DIMENSION)]


def example_dynamics(states, parameters):
    return 0.999*states + 1e-3*parameters

states = np.ones((PROPAGATION_HORIZON_LENGTH, STATE_DIMENSION))
parameters = np.ones((STATE_DIMENSION,))

# Compile functions
sum_of_squares_loss(example_dynamics, states, parameters)
sum_of_squares_loss_gradient(example_dynamics, states, parameters)

Resulting XLA dump for the above

Finally, for comparison, I considered converting the state from an array to a dict and back inside every call of the dynamics function. I thought that this might be inefficient due to this comment. However, it performs better than the PyTree version, resulting in (pre-optimised) HLO files of up to 1.1 MBytes.

Explicitly convert from arrays to dicts and back at the innermost function

from functools import partial

import jax.numpy as np
from jax import jit, grad
from jax.lax import scan


@partial(jit, static_argnums=(0,))
def sum_of_squares_loss(dynamics: callable, states, parameters):
    initial_states = states[0]
    horizon_length = states.shape[0]
    predictions = propagate_dynamics(dynamics, initial_states, horizon_length, parameters)
    return np.sum((predictions - states)**2)


sum_of_squares_loss_gradient = jit(grad(sum_of_squares_loss, argnums=2), static_argnums=(0,))


def propagate_dynamics(dynamics: callable, initial_states, horizon_length, parameters):
    return scan(
        f=lambda state, _: (dynamics(state, parameters), state),
        init=initial_states,
        xs=None,
        length=horizon_length
    )[1]


### Example call
PROPAGATION_HORIZON_LENGTH = 200
STATE_DIMENSION = 100
STATE_NAMES = ["state_" + str(i) for i in range(STATE_DIMENSION)]


def example_dynamics(states, parameters):
    # Convert array -> dict on entry (and dict -> array on exit) every call
    states_dict = dict(zip(STATE_NAMES, states))
    parameters_dict = dict(zip(STATE_NAMES, parameters))
    new_states_dict = {n: 0.999*states_dict[n] + 1e-3*parameters_dict[n] for n in STATE_NAMES}
    return np.array(list(new_states_dict.values()))


states = np.ones((PROPAGATION_HORIZON_LENGTH, STATE_DIMENSION))
parameters = np.ones((STATE_DIMENSION,))

# Compile functions
sum_of_squares_loss(example_dynamics, states, parameters)
sum_of_squares_loss_gradient(example_dynamics, states, parameters)

Resulting XLA dump for the above

So my question is: what is the best way to write/wrap a function that acts on a PyTree without incurring prohibitively large compile times?

@esbenscriver

Did you find a solution to this problem? And do you know why pytrees use so much memory?

jakevdp (Collaborator) commented Mar 5, 2024

I believe the issue is not with pytrees per se, but rather with defining hundreds of array arguments, and compiling functions which take hundreds of array arguments. Generally speaking, compilation costs should not scale with the size of the individual arrays being operated on, but we do expect them to scale with the number of array objects being passed to the function. In the best case, you might achieve linear scaling – but I suspect in reality you'll see closer to quadratic scaling with the number of array inputs. This is not unexpected, because it will generally lead to much larger programs which require much more logic to optimize.

Does that answer your question?
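
A minimal sketch of the workaround this observation suggests, reusing STATE_NAMES from the examples above: keep the readable dict view for bookkeeping, but hand jit a single stacked array so the compiled program sees one input rather than hundreds. pack and unpack are hypothetical helpers here, not JAX APIs.

import jax.numpy as np
from jax import jit

STATE_NAMES = ["state_" + str(i) for i in range(100)]


def pack(tree):
    # Stack the homogeneous leaves into a single array, in a fixed name order.
    return np.stack([tree[n] for n in STATE_NAMES])


def unpack(array):
    # Recover the readable dict view, e.g. for inspection outside of jit.
    return {n: array[i] for i, n in enumerate(STATE_NAMES)}


@jit
def step(packed_state, packed_parameters):
    # Fully vectorised update: the compiled program takes two array
    # arguments, not 200, so its size no longer grows with STATE_DIMENSION.
    return 0.999 * packed_state + 1e-3 * packed_parameters


packed_state = pack({name: 1.0 for name in STATE_NAMES})
packed_parameters = pack({name: 1.0 for name in STATE_NAMES})
print(unpack(step(packed_state, packed_parameters))["state_0"])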
