Flatten (and Unflatten) Pytree to 1D array #2056
Thanks for the question! Check out #1928 and `ravel_pytree`.
By the way, flattening down to a 1D array can in some cases build surprisingly large programs that ultimately lead to long compile times. It's often better to keep things as pytrees of arrays and work with those directly (i.e. generalize the optimizer code you have to operate on pytrees).
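For instance, an SGD-style step can act on the pytrees directly via `jax.tree_util.tree_map` (a minimal sketch, not the poster's actual optimizer; `params` and `grads` are made-up placeholders):

```python
import jax
import jax.numpy as jnp

def sgd_step(params, grads, lr=1e-3):
    # Update every leaf in lockstep while preserving the pytree structure.
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

params = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}
grads = {"w": jnp.full((2, 3), 0.1), "b": jnp.full(3, 0.2)}
params = sgd_step(params, grads)
```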
@mattjj Also thank you for the advice. I might eventually code the optimizer for a list of arrays and use `tree_flatten` in the decorator, but as long as flattening to 1D arrays doesn't cause any problems, that will remain quite low on the todo list...
@hawkinsp would it be possible to export this?
Sure! Can you open a feature request so we don't forget?
I have a similar need for this feature, as I have a boutique optimizer for my problem that involves constructing a matrix and using `jnp.linalg.lstsq` to approximately apply its inverse to precondition the gradient. Therefore, I need something to canonically order all the free parameters, flatten the gradient accordingly, order the rows and columns of the matrix, apply `jnp.linalg.lstsq`, and then convert the preconditioned gradient back to the proper pytree shape.
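A sketch of that preconditioning workflow in terms of `ravel_pytree` (assuming it is importable from `jax.flatten_util`; the matrix `A` and the `grads` pytree are hypothetical placeholders):

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

def precondition(grads, A):
    # Canonically order and flatten the gradient pytree to a 1D vector.
    g, unravel = ravel_pytree(grads)
    # Approximately apply A^{-1} to the gradient via least squares.
    g_pre, *_ = jnp.linalg.lstsq(A, g)
    # Restore the original pytree structure.
    return unravel(g_pre)
```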
Thanks for the question - that might be possible. So it doesn't get lost, would you mind opening a new feature request along with a minimal example of the kind of operation you'd like to see supported? Thanks!
Is there any way to support `vmap`? If this is the typical use:

```python
long_param_vector, unravel_fn = ravel_pytree(pytree_of_params)
```

Ideally, I'd like to do something like this:

```python
batched_param_vector, batch_unravel_fn = jax.vmap(ravel_pytree)(batched_pytree_of_params)
assert batched_param_vector.shape[0] == batch_size
batched_param_vector = batched_param_vector.at[:split].set(...)
batched_param_vector = batched_param_vector.at[split:].set(...)
batched_pytree_of_params = batch_unravel_fn(batched_param_vector)
```
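Note that `jax.vmap(ravel_pytree)` can't work exactly as written, since `vmap` can only return arrays (or pytrees of arrays), not the unravel callable. A workaround sketch under that constraint: vmap only the flattening, and recover an unravel function from a single unbatched element (the `batched_pytree` here is a made-up example):

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

batched_pytree = {"w": jnp.ones((4, 2, 3)), "b": jnp.zeros((4, 5))}

# vmap only the array output; the per-element unravel closure is discarded.
batched_vector = jax.vmap(lambda p: ravel_pytree(p)[0])(batched_pytree)
assert batched_vector.shape == (4, 11)

# Build the unravel function once from a single (unbatched) element,
# then vmap it to undo the batched flattening.
_, unravel_one = ravel_pytree(jax.tree_util.tree_map(lambda x: x[0], batched_pytree))
restored = jax.vmap(unravel_one)(batched_vector)
```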
Until then, here's a quick and dirty batched version for others who may have a similar use case:

```python
import warnings

import numpy as np
from jax import lax, vmap
import jax.numpy as jnp
from jax._src import dtypes
from jax._src.tree_util import tree_flatten, tree_unflatten
from jax._src.util import safe_zip, unzip2, HashablePartial

zip = safe_zip


def batch_ravel_pytree(pytree):
  """Ravel (flatten) a pytree of arrays with a leading batch dimension down to
  a (batch_size, num_features) array.

  Args:
    pytree: a pytree of arrays and scalars to ravel, where every leaf shares
      the same leading batch dimension.

  Returns:
    A pair where the first element is a (batch_size, num_features) array
    representing the flattened and concatenated leaf values, with dtype
    determined by promoting the dtypes of leaf values, and the second element
    is a callable for unflattening a (batch_size, num_features) array back to
    a pytree of the same structure as the input ``pytree``. If the input
    pytree is empty (i.e. has no leaves) then as a convention a 1D empty array
    of dtype float32 is returned in the first component of the output.

  For details on dtype promotion, see
  https://jax.readthedocs.io/en/latest/type_promotion.html.
  """
  leaves, treedef = tree_flatten(pytree)
  flat, unravel_list = _ravel_list(leaves)
  return flat, HashablePartial(unravel_pytree, treedef, unravel_list)


def unravel_pytree(treedef, unravel_list, flat):
  return tree_unflatten(treedef, unravel_list(flat))


@vmap
def vmapped_ravel(a):
  # Under vmap this ravels everything except the leading batch dimension.
  return jnp.ravel(a)


def _ravel_list(lst):
  if not lst: return jnp.array([], jnp.float32), lambda _: []
  from_dtypes = tuple(dtypes.dtype(l) for l in lst)
  to_dtype = dtypes.result_type(*from_dtypes)
  # Sizes and shapes of each leaf, excluding the single leading batch dimension.
  sizes, shapes = unzip2((int(np.prod(jnp.shape(x)[1:])), jnp.shape(x)[1:])
                         for x in lst)
  indices = tuple(np.cumsum(sizes))

  if all(dt == to_dtype for dt in from_dtypes):
    # Skip any dtype conversion, resulting in a dtype-polymorphic `unravel`.
    # See https://github.com/google/jax/issues/7809.
    del from_dtypes, to_dtype
    # Concatenate along axis 1, i.e. after the leading batch dimension.
    raveled = jnp.concatenate([vmapped_ravel(e) for e in lst], axis=1)
    return raveled, HashablePartial(_unravel_list_single_dtype, indices, shapes)

  # When there is more than one distinct input dtype, we perform type
  # conversions and produce a dtype-specific unravel function.
  raveled = jnp.concatenate(
      [vmapped_ravel(lax.convert_element_type(e, to_dtype)) for e in lst],
      axis=1)
  unrav = HashablePartial(_unravel_list, indices, shapes, from_dtypes, to_dtype)
  return raveled, unrav


def _unravel_list_single_dtype(indices, shapes, arr):
  # Split along the feature axis, i.e. after the leading batch dimension.
  chunks = jnp.split(arr, indices[:-1], axis=1)
  # The -1 preserves the leading batch dimension.
  return [chunk.reshape((-1, *shape)) for chunk, shape in zip(chunks, shapes)]


def _unravel_list(indices, shapes, from_dtypes, to_dtype, arr):
  arr_dtype = dtypes.dtype(arr)
  if arr_dtype != to_dtype:
    raise TypeError(f"unravel function given array of dtype {arr_dtype}, "
                    f"but expected dtype {to_dtype}")
  chunks = jnp.split(arr, indices[:-1], axis=1)
  with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # ignore complex-to-real cast warning
    # The -1 preserves the leading batch dimension.
    return [lax.convert_element_type(chunk.reshape((-1, *shape)), dtype)
            for chunk, shape, dtype in zip(chunks, shapes, from_dtypes)]
```
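A quick usage sketch of the helper above (the shapes and pytree here are made up for illustration):

```python
import jax.numpy as jnp

# A batch of 4 parameter pytrees: every leaf has leading batch dimension 4.
params = {"w": jnp.ones((4, 2, 3)), "b": jnp.zeros((4, 5))}

flat, unravel = batch_ravel_pytree(params)
assert flat.shape == (4, 11)  # 2*3 + 5 features per batch element

restored = unravel(flat)
assert restored["w"].shape == (4, 2, 3)
assert restored["b"].shape == (4, 5)
```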
I've got a decorator that tries to solve a similar problem by wrapping functions that accept arrays.
I am not sure if such a feature exists.
If so, could you just point me to it?
If not, I want to request it.
Some context:
I am currently implementing the L-BFGS algorithm for complex optimisation problems using gradients from jax.
I have an implementation that optimises cost-functions whose argument is a single 1D array.
I want to write a decorator that promotes this to an optimiser that can optimise cost-functions whose arguments are arbitrary pytrees.
I have looked at the `@optimizer` decorator in `jax.experimental.optimizers` for inspiration. As far as I understand it, `@optimizer` unpacks the pytrees to a list of arrays (leaves) and then uses `map` to let the `update` function act on all the leaves. This is not directly applicable to my case, since I need (for example) dot products of gradients with parameters (that's what I call the argument of the cost-function); see the sketch below.
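For instance, such a dot product can be computed directly on the pytrees, without flattening (a minimal sketch; `tree_vdot` is a hypothetical helper, not an existing JAX function):

```python
import jax
import jax.numpy as jnp

def tree_vdot(a, b):
    # Inner product of two pytrees with matching structure:
    # per-leaf vdots, then summed with a tree reduction.
    leaf_dots = jax.tree_util.tree_map(jnp.vdot, a, b)
    return jax.tree_util.tree_reduce(lambda x, y: x + y, leaf_dots)
```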
The feature I am looking for
What I am looking for is a pair of functions, much like `tree_flatten` and `tree_unflatten` from `jax.tree_util`, but instead of flattening to a list of nD arrays, I want the data to be flattened to a single 1D array. This, of course, assumes that there is a compatible datatype that can encapsulate all of the data.
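This is essentially what `ravel_pytree`, mentioned earlier in the thread, provides. A minimal round-trip sketch, assuming it can be imported from `jax.flatten_util` (the `params` pytree is a made-up example):

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

params = {"w": jnp.ones((2, 3)), "b": jnp.zeros(5)}

flat, unravel = ravel_pytree(params)
assert flat.shape == (11,)   # 2*3 + 5 elements in a single 1D array
restored = unravel(flat)     # same structure and shapes as params
```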
A quick and dirty implementation