Improve usage of ravel_pytree for handling flatten view of PyTree internally #376

Closed
junpenglao opened this issue Oct 10, 2022 · 1 comment
Labels
refactoring Change that adds no functionality but improves code quality

Comments

@junpenglao
Member

junpenglao commented Oct 10, 2022

Currently, we use flattened arrays in multiple places. One common pattern is:

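```python
import jax.flatten_util

def fun(input0, *args, **kwargs):
    # Flatten the PyTree into a single 1-D array; keep the inverse for later.
    input0_flatten, unravel_fn = jax.flatten_util.ravel_pytree(input0)
    output0_flatten = _fun(input0_flatten, *args, **kwargs)  # `_fun` works on flat arrays
    return unravel_fn(output0_flatten)  # restore the original PyTree structure
```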
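To make the mechanics concrete, here is a small runnable illustration. The toy `position` PyTree and the scaling by 2.0 (standing in for `_fun`) are assumptions for the example. ravel_pytree concatenates the raveled leaves, in tree-flatten order, into one 1-D array and returns a function that maps a vector of the same length back to the original structure:

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Illustrative PyTree: a scalar leaf and a length-2 vector leaf.
position = {"sigma": jnp.array([2.0, 3.0]), "mu": jnp.array(1.0)}

# One 1-D vector holding every entry, plus the inverse mapping.
position_flat, unravel_fn = ravel_pytree(position)
# Dict leaves are visited in sorted-key order, so "mu" comes before "sigma".
print(position_flat)  # [1. 2. 3.]

# Scaling by 2.0 stands in for `_fun`: any length-preserving map on the flat vector works.
restored = unravel_fn(2.0 * position_flat)
print(restored["sigma"].shape)  # (2,)
```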

In general this is needed because:

  • *args and **kwargs already contain flattened arrays
  • at some point we need to generate random numbers and map them onto the input, which requires splitting the input key to match the PyTree
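For the second case, one way to use the flattened view is to draw a single vector of standard normals as long as the raveled input and unravel it back into the PyTree's structure. A minimal sketch, assuming a hypothetical helper name `normal_like_tree` (not an existing BlackJAX or JAX function):

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def normal_like_tree(rng_key, tree):
    # Hypothetical helper: N(0, 1) noise with the same structure as `tree`.
    # Ravel once to learn how many scalar entries the PyTree holds.
    tree_flat, unravel_fn = ravel_pytree(tree)
    # A single key and a single draw cover every leaf at once.
    noise_flat = jax.random.normal(rng_key, shape=tree_flat.shape, dtype=tree_flat.dtype)
    # Map the flat noise back onto the original leaves.
    return unravel_fn(noise_flat)


position = {"mu": jnp.array(1.0), "sigma": jnp.array([2.0, 3.0])}
noise = normal_like_tree(jax.random.PRNGKey(0), position)
```

Only one key is consumed, at the cost of materializing the concatenated vector.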

Per an offline conversation with @mattjj, here are some general guidelines on how we should improve this:

raveling pytrees may not be the most performant way to do things.
one thing you can ask yourself is: how would you do the right thing with plain old flat lists? if you know how to do something with flat lists, like splitting a key the right number of times, then just tree-flattening, working with the flat lists, and tree-unflattening is often the best approach

that said, we had this helper in Autograd: https://github.com/HIPS/autograd/blob/c6d81ce7eede6db801d4e9a92b27ec5d409d0eab/autograd/misc/flatten.py#L30
the venerable issue #190 has me saying it might be a good idea to add: jax-ml/jax#190
flattening seems like it could be a performance footgun because it really ties the compiler's hands to pack everything into one flat vector
so yeah unless you're sure you want to work with flat stuff, maybe avoid it and try to work with pytree mapping/flattening/unflattening
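Following the flat-list advice above, the same kind of noise can be drawn without raveling: tree-flatten to a list of leaves, split the key once per leaf, sample each leaf at its own shape, and tree-unflatten. A minimal sketch, where `normal_like_tree_leafwise` and the step size 0.1 are hypothetical; the last line shows the pure pytree-mapping style with jax.tree_util.tree_map:

```python
import jax
import jax.numpy as jnp


def normal_like_tree_leafwise(rng_key, tree):
    # Hypothetical helper: no raveling, each leaf keeps its own shape and dtype.
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    # Split the key exactly once per leaf ("the right number of times").
    keys = jax.random.split(rng_key, len(leaves))
    noise_leaves = [
        jax.random.normal(k, shape=leaf.shape, dtype=leaf.dtype)
        for k, leaf in zip(keys, leaves)
    ]
    # Rebuild a PyTree with the original structure.
    return jax.tree_util.tree_unflatten(treedef, noise_leaves)


position = {"mu": jnp.array(1.0), "sigma": jnp.array([2.0, 3.0])}
noise = normal_like_tree_leafwise(jax.random.PRNGKey(0), position)

# Leaf-wise update via tree_map; 0.1 is an arbitrary illustrative step size.
new_position = jax.tree_util.tree_map(lambda x, n: x + 0.1 * n, position, noise)
```

Each leaf keeps its own shape, so nothing forces the compiler to pack every entry into one vector.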

@junpenglao junpenglao added the refactoring Change that adds no functionality but improves code quality label Oct 10, 2022
@junpenglao junpenglao changed the title Refactor internal flatten array and usage of ravel_pytree Improve usage of ravel_pytree for handling flatten view of PyTree internally Oct 10, 2022
@junpenglao
Member Author

I did some light benchmarking, and using ravel_pytree doesn't seem too bad. Since there are quite a few places where a flattened view is unavoidable (mostly when we multiply by some dense matrix), we will need to keep working with the flattened PyTree output by ravel_pytree until we have a good solution for matrix operations on PyTrees (e.g., something better than what tree-math provides).
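The dense-matrix case is where the flat view is hard to avoid: a matrix acts on one vector, so the PyTree is raveled, multiplied, and unraveled. A minimal sketch, assuming a hypothetical 3x3 `inverse_mass_matrix` sized to the three scalar entries of a toy `momentum` PyTree; `matvec_tree` is likewise a hypothetical name:

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def matvec_tree(matrix, tree):
    # A dense (n, n) matrix can only act on the length-n flat view of the PyTree.
    tree_flat, unravel_fn = ravel_pytree(tree)
    return unravel_fn(matrix @ tree_flat)


momentum = {"mu": jnp.array(1.0), "sigma": jnp.array([2.0, 3.0])}
# Hypothetical dense inverse mass matrix. The 0.5 entries couple "mu" with
# sigma[0], which a per-leaf tree_map cannot express.
inverse_mass_matrix = jnp.array(
    [[2.0, 0.5, 0.0],
     [0.5, 1.0, 0.0],
     [0.0, 0.0, 1.0]]
)
velocity = matvec_tree(inverse_mass_matrix, momentum)
```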

Instead, I will send a PR that just refactors out some common patterns for using ravel_pytree.
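One possible shape for such a refactor, purely as an illustration and not necessarily what the PR does, is a decorator that wraps the ravel/unravel pattern so that flat-array functions accept PyTrees directly. `on_flat_view` and `shift` are hypothetical names:

```python
import functools

import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def on_flat_view(flat_fn):
    # Hypothetical decorator: lift a function on 1-D arrays to one that takes
    # a PyTree as its first argument and returns the same PyTree structure.
    @functools.wraps(flat_fn)
    def wrapped(tree, *args, **kwargs):
        tree_flat, unravel_fn = ravel_pytree(tree)
        return unravel_fn(flat_fn(tree_flat, *args, **kwargs))

    return wrapped


@on_flat_view
def shift(x_flat, offset):
    # Only sees the flat vector; the decorator handles the PyTree bookkeeping.
    return x_flat + offset


position = {"mu": jnp.array(1.0), "sigma": jnp.array([2.0, 3.0])}
shifted = shift(position, 1.0)
```

Keeping the flat kernel separate also makes it easy to later swap in a leaf-wise implementation without touching callers.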

junpenglao added a commit to junpenglao/blackjax that referenced this issue Oct 10, 2022
... that is the same PyTree structure as the input.

relate to blackjax-devs#376
junpenglao added a commit to junpenglao/blackjax that referenced this issue Oct 13, 2022
... that is the same PyTree structure as the input.

relate to blackjax-devs#376