
Flatten (and Unflatten) Pytree to 1D array #2056

Closed
Jakob-Unfried opened this issue Jan 23, 2020 · 10 comments
Labels: question (Questions for the JAX team)

Jakob-Unfried (Contributor) commented Jan 23, 2020

I am not sure if such a feature exists.
If so, could you just point me to it?
If not, I want to request it.

Some context:
I am currently implementing the L-BFGS algorithm for complex optimisation problems using gradients from jax.

I have an implementation that optimises cost-functions whose argument is a single 1D array.

I want to write a decorator that promotes this to an optimiser that can optimise cost-functions whose arguments are arbitrary pytrees.

I have looked at the @optimizer decorator in jax.experimental.optimizers for inspiration.
As far as I understand it, @optimizer unpacks the pytrees to a list of arrays (leaves) and then uses map to let the update function act on all the leaves.
This is not directly applicable to my case, since I need (for example) dot products of gradients with parameters (that's what I call the argument of the cost-function).

The feature I am looking for
What I am looking for is a pair of functions, much like tree_flatten and tree_unflatten from jax.tree_util, but instead of flattening to a list of nD arrays, I want the data to be flattened to a single 1D array. This, of course, assumes that there is a compatible datatype that can encapsulate all of the data.

A quick and dirty implementation

import jax.numpy as np
from jax.tree_util import tree_flatten, tree_unflatten

def full_flatten(pytree):
    flat_tree, tree = tree_flatten(pytree)
    shapes = [flat_tree[0].shape]
    last = flat_tree[0].size
    slices = [(0, last)]
    flat = flat_tree[0].flatten()
    for arr in flat_tree[1:]:  # the first leaf was already handled above
        shapes.append(arr.shape)
        slices.append((last, last + arr.size))
        last = last + arr.size
        flat = np.concatenate([flat, arr.flatten()])  # concatenate expects a sequence
    return flat, tree, shapes, slices


def full_unflatten(flat, tree, shapes, slices):
    flat_tree = []
    for shape, _slice in zip(shapes, slices):
        flat_tree.append(flat[_slice[0]:_slice[1]].reshape(shape))
    pytree = tree_unflatten(tree, flat_tree)
    return pytree
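
A minimal round-trip check of the snippet above (a sketch, assuming the definitions just given):

params = {'b': np.zeros(3), 'w': np.ones((2, 3))}
flat, tree, shapes, slices = full_flatten(params)
assert flat.shape == (9,)  # 3 + 6 entries, concatenated in leaf order
restored = full_unflatten(flat, tree, shapes, slices)  # same structure as params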

mattjj (Collaborator) commented Jan 23, 2020

Thanks for the question!

Check out #1928 and ravel_pytree in jax.flatten_util. Does that do what you want?
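
For reference, a minimal usage sketch of ravel_pytree (hedged; the import path is jax.flatten_util as noted above):

from jax.flatten_util import ravel_pytree
import jax.numpy as jnp

params = {'b': jnp.zeros(3), 'w': jnp.ones((2, 3))}
flat, unravel = ravel_pytree(params)  # flat is a 1D array with 9 entries
restored = unravel(flat)              # same pytree structure as params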

@mattjj mattjj self-assigned this Jan 23, 2020
@mattjj mattjj added the question Questions for the JAX team label Jan 23, 2020
mattjj (Collaborator) commented Jan 23, 2020

By the way, flattening down to a 1D array can in some cases build surprisingly large programs that ultimately lead to long compile times. It's often better to keep things as pytrees of arrays and work with those directly (i.e. generalize the optimizer code you have, perhaps by using jax.tree_util.tree_map and/or jax.tree_util.tree_multimap, or by flattening to lists of arrays and just using a regular Python map over those).
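
For instance, the dot products mentioned above can be computed per leaf without forming one big vector (a hedged sketch; tree_dot is illustrative, not a JAX function):

import jax.numpy as jnp
from jax import tree_util

def tree_dot(a, b):
    # vdot each pair of corresponding leaves, then sum the per-leaf results
    leaf_dots = tree_util.tree_multimap(lambda x, y: jnp.vdot(x, y), a, b)
    return tree_util.tree_reduce(lambda acc, d: acc + d, leaf_dots)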

Jakob-Unfried (Contributor, Author) commented Jan 23, 2020

@mattjj
Thanks for the pointer! ravel_pytree seems to be exactly what I want.

Also, thank you for the advice. I might eventually code the optimizer for a list of arrays and use tree_flatten in the decorator, but as long as flattening to 1D arrays doesn't cause any problems, that will remain quite low on the todo list...
It is great to be aware that this might be a problem, though. Thanks.

PhilipVinc (Contributor) commented

@hawkinsp would it be possible to export ravel_pytree? Right now the whole module jax.flatten_util is not exported.

mattjj (Collaborator) commented Feb 12, 2021

Sure! Can you open a feature request so we don't forget?

ravel_pytree may do weird things with integer dtype arrays, like lose precision or something. I'm not sure if it's ideal.
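
A hedged illustration of that promotion behavior (dtypes assume JAX's default 32-bit mode):

import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

flat, unravel = ravel_pytree({'i': jnp.arange(3), 'x': jnp.ones(2)})
# leaves are promoted to a common dtype before concatenation (float32 here);
# unravel casts back to int32, which is where round-trips can get lossy
print(flat.dtype)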

Goodbrake commented

I have a similar need for this feature: I have a boutique optimizer for my problem that involves constructing a matrix and using jnp.linalg.lstsq to approximately apply its inverse to precondition the gradient. I therefore need something to canonically order all the free parameters, flatten the gradient accordingly, order the rows and columns of the matrix, apply jnp.linalg.lstsq, and then convert the preconditioned gradient back to the proper pytree shape. ravel_pytree works perfectly for this, but the unravel function it returns is not jittable. Is there a way to either make the returned unravel function jittable, or modify the outputs of tree_flatten, etc. to recreate the behavior of ravel_pytree in a jittable way, or perhaps some better idea?
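
One hedged possibility (a sketch, not an official API): the treedef, shapes, and split indices produced from tree_flatten are static Python values, so an unravel function that closes over them can be wrapped in jax.jit, assuming all leaves share a dtype:

import numpy as np
import jax
import jax.numpy as jnp
from jax.tree_util import tree_flatten, tree_unflatten

def make_jittable_unravel(pytree):
    leaves, treedef = tree_flatten(pytree)
    shapes = [leaf.shape for leaf in leaves]
    sizes = np.array([int(np.prod(shape)) for shape in shapes])
    indices = np.cumsum(sizes)[:-1]  # static split points

    @jax.jit
    def unravel(flat):
        chunks = jnp.split(flat, indices)
        return tree_unflatten(treedef, [c.reshape(s) for c, s in zip(chunks, shapes)])

    return unravel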

jakevdp (Collaborator) commented Nov 30, 2022

Thanks for the question - that might be possible. So it doesn't get lost, would you mind opening a new feature request along with a minimal example of the kind of operation you'd like to see supported? Thanks!

amifalk commented Nov 13, 2023

Is there any way to support vmap for ravel_pytree? For context, I'm working on implementing ensemble MCMC samplers: parameters are stored in a pytree, and each chain's state update depends on the state of the other chains.

If this is the typical use:

long_param_vector, unravel_fn = ravel_pytree(pytree_of_params)

Ideally, I'd like to do something like this:

batched_param_vector, batch_unravel_fn = jax.vmap(ravel_pytree)(batched_pytree_of_params)

assert batched_param_vector.shape[0] == batch_size

batched_param_vector = batched_param_vector.at[:split].set(...) 
batched_param_vector = batched_param_vector.at[split:].set(...) 

batched_pytree_of_params = batch_unravel_fn(batched_param_vector)

amifalk commented Nov 13, 2023

Until then, here's a quick and dirty batched version for others who may have a similar use case:

import warnings

import numpy as np

from jax import lax, vmap
import jax.numpy as jnp

from jax._src import dtypes
from jax._src.tree_util import tree_flatten, tree_unflatten
from jax._src.util import safe_zip, unzip2, HashablePartial

zip = safe_zip


def batch_ravel_pytree(pytree):
    """Ravel (flatten) a pytree of arrays with leading batch dimension down to a (batch_size, 1D) array.   
    Args:
      pytree: a pytree of arrays and scalars to ravel.  
    Returns:
      A pair where the first element is a (batch_size, 1D) array representing the flattened and
      concatenated leaf values, with dtype determined by promoting the dtypes of
      leaf values, and the second element is a callable for unflattening a (batch_size, 1D)
      vector of the same length back to a pytree of of the same structure as the
      input ``pytree``. If the input pytree is empty (i.e. has no leaves) then as
      a convention a 1D empty array of dtype float32 is returned in the first
      component of the output.  
    For details on dtype promotion, see
    https://jax.readthedocs.io/en/latest/type_promotion.html.   
    """
    
    leaves, treedef = tree_flatten(pytree)
    flat, unravel_list = _ravel_list(leaves)
    return flat, HashablePartial(unravel_pytree, treedef, unravel_list)

def unravel_pytree(treedef, unravel_list, flat):
    return tree_unflatten(treedef, unravel_list(flat))

@vmap
def vmapped_ravel(a):
    return jnp.ravel(a)

def _ravel_list(lst):
    if not lst: return jnp.array([], jnp.float32), lambda _: []
    from_dtypes = tuple(dtypes.dtype(l) for l in lst)
    to_dtype = dtypes.result_type(*from_dtypes)
    
    # here 1 is n_leading_batch_dimensions    
    sizes, shapes = unzip2((np.prod(jnp.shape(x)[1:]), jnp.shape(x)[1:]) for x in lst)
    indices = tuple(np.cumsum(sizes))
    
    if all(dt == to_dtype for dt in from_dtypes):
        # Skip any dtype conversion, resulting in a dtype-polymorphic `unravel`.
        # See https://github.com/google/jax/issues/7809.
        del from_dtypes, to_dtype
        
        # axis = n_leading_batch_dimensions
        # vmap n_leading_batch_dimensions times
        raveled = jnp.concatenate([vmapped_ravel(e) for e in lst], axis=1)
        return raveled, HashablePartial(_unravel_list_single_dtype, indices, shapes)
    
    # When there is more than one distinct input dtype, we perform type
    # conversions and produce a dtype-specific unravel function.
    ravel = lambda e: jnp.ravel(lax.convert_element_type(e, to_dtype))
    # vmap the converting ravel over the batch axis, then join along axis=1
    raveled = jnp.concatenate([vmap(ravel)(e) for e in lst], axis=1)
    unrav = HashablePartial(_unravel_list, indices, shapes, from_dtypes, to_dtype)
    return raveled, unrav
    
    
def _unravel_list_single_dtype(indices, shapes, arr):
    # axis is n_leading_batch_dimensions
    chunks = jnp.split(arr, indices[:-1], axis=1)

    # the number of -1s is the number of leading batch dimensions
    return [chunk.reshape((-1, *shape)) for chunk, shape in zip(chunks, shapes)]


def _unravel_list(indices, shapes, from_dtypes, to_dtype, arr):
    arr_dtype = dtypes.dtype(arr)
    if arr_dtype != to_dtype:
      raise TypeError(f"unravel function given array of dtype {arr_dtype}, "
                      f"but expected dtype {to_dtype}")
    
    # axis is n_leading_batch_dimensions
    chunks = jnp.split(arr, indices[:-1], axis=1)
    with warnings.catch_warnings():
      warnings.simplefilter("ignore")  # ignore complex-to-real cast warning
      # the number of -1s is the number of leading batch dimensions
      return [lax.convert_element_type(chunk.reshape((-1, *shape)), dtype)
              for chunk, shape, dtype in zip(chunks, shapes, from_dtypes)]
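
A quick hedged usage check of the sketch above (shapes are illustrative):

batched_params = {'b': jnp.zeros((4, 3)), 'w': jnp.ones((4, 2, 3))}  # batch_size 4
flat, unravel = batch_ravel_pytree(batched_params)
assert flat.shape == (4, 9)  # per-example sizes 3 + 6, concatenated along axis=1
restored = unravel(flat)     # same structure and shapes as batched_params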

allen-adastra commented

I've got a decorator that tries to solve a similar problem by wrapping functions that accept arrays,
f: arr0 -> arr1,
so that they instead accept PyTrees:
tree0 -> arr0 -> f -> arr1 -> tree1.

import typing

import jax
import jax.numpy as jnp
import jax.flatten_util as jfu
from jaxtyping import PyTree  # assumed source of the PyTree annotation
from jax.typing import ArrayLike

ScalarLike = ArrayLike  # assumed alias; not defined in the original snippet

def transform_to_pytree_io(fun: typing.Callable, out_example: typing.Optional[PyTree[ArrayLike]] = None) -> typing.Callable:
    """Given a function that takes a 1D array and returns a 1D array, build a function that
    takes a PyTree and returns a PyTree. For example, one might have a NN with signature:
        arr_out = nn(arr_in)
    But one's data might be a PyTree:
        tree_example = {"a": jnp.array(1.0), "b": {"c": jnp.array(2.0), "d": jnp.array(3.0)}}
    This function allows one to use the NN with the PyTree.
    If an out_example is not provided, the function will attempt to output a PyTree with the same
    structure as the input. Note that, in this case, the input leaves must be scalars.
    If an out_example is provided, the function will attempt to output a PyTree with the same
    structure as the out_example.

    Args:
        fun (typing.Callable): A function that takes a 1D array and returns a 1D array.
        out_example (PyTree[ArrayLike]): An example of the output structure.

    Returns:
        typing.Callable: A function that takes a PyTree and returns a PyTree.
    """
    if out_example is not None:
        out_example_leaves, out_structure = jax.tree_util.tree_flatten(out_example)
        out_array_idxs = list(jnp.cumsum(jnp.array([x.size for x in out_example_leaves]))[:-1])
    else:
        out_structure, out_array_idxs = None, None

    def transformed_fun(inputs: PyTree[ScalarLike]) -> PyTree[ScalarLike]:
        in_leaves_1d, _ = jfu.ravel_pytree(inputs)
        out_array_1d = fun(in_leaves_1d)

        out_leaves_1d = jnp.split(out_array_1d, out_array_idxs) if out_array_idxs is not None else out_array_1d

        out_tree = jax.tree_util.tree_unflatten(
            out_structure if out_structure is not None else jax.tree_util.tree_structure(inputs), out_leaves_1d
        )
        return out_tree

    return transformed_fun
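
A hypothetical usage sketch (names and shapes are illustrative, not from the original comment):

tree_in = {"a": jnp.array(1.0), "b": {"c": jnp.array(2.0), "d": jnp.array(3.0)}}
out_example = (jnp.zeros(1), jnp.zeros(1))  # two single-element output leaves

def f(x):  # 1D array in, 1D array out
    return jnp.stack([x.sum(), x.prod()])

g = transform_to_pytree_io(f, out_example)
sums, prods = g(tree_in)  # (array([6.]), array([6.]))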
