Skip to content

Commit

Permalink
Annotate tree_util
Browse files Browse the repository at this point in the history
  • Loading branch information
NeilGirdhar committed Sep 6, 2022
1 parent e9204e3 commit 60e92ee
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 33 deletions.
6 changes: 3 additions & 3 deletions jax/_src/ad_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,13 +365,13 @@ def _trace_to_jaxpr(fun, in_tree, in_avals):
### Utilities

def saved_residuals(f, *args, **kwargs) -> List[Tuple[core.AbstractValue, str]]:
args, in_tree = tree_flatten((args, kwargs))
in_leaves, in_tree = tree_flatten((args, kwargs))

def f_(*args):
args, kwargs = tree_unflatten(in_tree, args)
return f(*args, **kwargs)

jaxpr = jax.make_jaxpr(lambda *args: jax.linearize(f_, *args)[1])(*args).jaxpr
jaxpr = jax.make_jaxpr(lambda *args: jax.linearize(f_, *args)[1])(*in_leaves).jaxpr
res_lits = [x for x in jaxpr.outvars if isinstance(x, core.Literal)]
res_vars = {x for x in jaxpr.outvars if not isinstance(x, core.Literal)}

Expand All @@ -384,7 +384,7 @@ def f_(*args):
if v in res_vars:
results.append((v.aval, 'from a constant'))

assert len(jaxpr.invars) == len(args)
assert len(jaxpr.invars) == len(in_leaves)
for i, v in enumerate(jaxpr.invars):
if v in res_vars:
src = f'from {pe.arg_info_pytree(f, in_tree, True, [i])}'
Expand Down
6 changes: 3 additions & 3 deletions jax/_src/api_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
import inspect
import operator
from functools import partial
from typing import (Any, Dict, Iterable, Sequence, Set, Tuple, Union, Optional,
Callable)
from typing import (Any, Callable, Dict, Iterable, List, Optional, Sequence,
Set, Tuple, Union)
import warnings

import numpy as np
Expand Down Expand Up @@ -330,7 +330,7 @@ def _argnames_partial(fixed_kwargs: WrapKwArgs, *args, **dyn_kwargs):

def donation_vector(donate_argnums, args, kwargs) -> Tuple[bool, ...]:
"""Returns a tuple with a boolean value for each leaf in args."""
res = []
res: List[bool] = []
for i, arg in enumerate(args):
donate = bool(i in donate_argnums)
res.extend((donate,) * tree_structure(arg).num_leaves)
Expand Down
4 changes: 2 additions & 2 deletions jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -992,8 +992,8 @@ def reduce(operands: Any,
flat_init_avals = safe_map(_abstractify, flat_init_values)
jaxpr, consts, out_tree = _variadic_reduction_jaxpr(
computation, tuple(flat_init_avals), init_value_tree)
out = reduce_p.bind(*(flat_operands + flat_init_values), computation=computation,
jaxpr=jaxpr, consts=consts, dimensions=tuple(dimensions))
out = reduce_p.bind(*flat_operands, *flat_init_values, computation=computation,
jaxpr=jaxpr, consts=consts, dimensions=tuple(dimensions))
return tree_util.tree_unflatten(out_tree, out)

@cache()
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/lax/windowed_reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def reduce_window(operand, init_value, computation: Callable,
'reduce_window output must have the same tree structure as the operands'
f' {operand_tree} vs. {out_tree}')
out_flat = reduce_window_p.bind(
*(flat_operands + flat_init_values), jaxpr=jaxpr, consts=consts,
*flat_operands, *flat_init_values, jaxpr=jaxpr, consts=consts,
window_dimensions=tuple(window_dimensions),
window_strides=tuple(window_strides), padding=padding,
base_dilation=tuple(base_dilation),
Expand Down
58 changes: 34 additions & 24 deletions jax/_src/tree_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
import functools
from functools import partial
import operator as op
from typing import (Any, Callable, Hashable, Iterable, Optional, Tuple, List,
Dict, Type, TypeVar, overload, TYPE_CHECKING, NamedTuple)
from typing import (Any, Callable, Dict, Hashable, Iterable, List, NamedTuple,
Optional, Sequence, Tuple, Type, TypeVar, overload,
TYPE_CHECKING)
import textwrap
import warnings

Expand All @@ -32,15 +33,18 @@
traceback_util.register_exclusion(__file__)

T = TypeVar("T")
U = TypeVar("U")
U = TypeVar("U", bound=Type[Any])

Array = Any
if TYPE_CHECKING or xla_extension_version >= 78:
PyTreeDef = pytree.PyTreeDef
else:
PyTreeDef = xla_extension.PyTreeDef # pytype: disable=module-attr


def tree_flatten(tree, is_leaf: Optional[Callable[[Any], bool]] = None):
def tree_flatten(tree: Any,
is_leaf: Optional[Callable[[Any], bool]] = None
) -> Tuple[Sequence[Array], PyTreeDef]:
"""Flattens a pytree.
The flattening order (i.e. the order of elements in the output list)
Expand All @@ -54,59 +58,63 @@ def tree_flatten(tree, is_leaf: Optional[Callable[[Any], bool]] = None):
traversal and the whole subtree being treated as a leaf, and false
indicating the flattening should traverse the current object.
Returns:
A pair where the first element is a list of leaf values and the second
A pair where the first element is a sequence of leaf values and the second
element is a treedef representing the structure of the flattened tree.
"""
return pytree.flatten(tree, is_leaf)


def tree_unflatten(treedef, leaves):
def tree_unflatten(treedef: PyTreeDef, leaves: Iterable[Array]) -> Any:
"""Reconstructs a pytree from the treedef and the leaves.
The inverse of :func:`tree_flatten`.
Args:
treedef: the treedef to reconstruct
leaves: the list of leaves to use for reconstruction. The list must match
the leaves of the treedef.
leaves: the iterable of leaves to use for reconstruction. The iterable
must match the leaves of the treedef.
Returns:
The reconstructed pytree, containing the ``leaves`` placed in the structure
described by ``treedef``.
"""
return treedef.unflatten(leaves)

def tree_leaves(tree, is_leaf: Optional[Callable[[Any], bool]] = None):
def tree_leaves(tree: Any,
is_leaf: Optional[Callable[[Any], bool]] = None
) -> Sequence[Array]:
"""Gets the leaves of a pytree."""
return pytree.flatten(tree, is_leaf)[0]

def tree_structure(tree, is_leaf: Optional[Callable[[Any], bool]] = None):
def tree_structure(tree: Any,
is_leaf: Optional[Callable[[Any], bool]] = None) -> PyTreeDef:
"""Gets the treedef for a pytree."""
return pytree.flatten(tree, is_leaf)[1]

def treedef_tuple(treedefs):
"""Makes a tuple treedef from a list of child treedefs."""
def treedef_tuple(treedefs: Iterable[PyTreeDef]) -> PyTreeDef:
"""Makes a tuple treedef from an iterable of child treedefs."""
return pytree.tuple(list(treedefs))

def treedef_children(treedef):
def treedef_children(treedef: PyTreeDef) -> Sequence[PyTreeDef]:
return treedef.children()

def treedef_is_leaf(treedef):
def treedef_is_leaf(treedef: PyTreeDef) -> bool:
return treedef.num_nodes == 1

def treedef_is_strict_leaf(treedef):
def treedef_is_strict_leaf(treedef: PyTreeDef) -> bool:
return treedef.num_nodes == 1 and treedef.num_leaves == 1

def all_leaves(iterable, is_leaf: Optional[Callable[[Any], bool]] = None):
def all_leaves(iterable: Iterable[Any],
is_leaf: Optional[Callable[[Any], bool]]) -> bool:
"""Tests whether all elements in the given iterable are all leaves.
>>> tree = {"a": [1, 2, 3]}
>>> assert all_leaves(jax.tree_util.tree_leaves(tree))
>>> assert not all_leaves([tree])
This function is useful in advanced cases, for example if a library allows
arbitrary map operations on a flat list of leaves it may want to check if
the result is still a flat list of leaves.
arbitrary map operations on a flat iterable of leaves it may want to check
if the result is still a flat iterable of leaves.
Args:
iterable: Iterable of leaves.
Expand Down Expand Up @@ -146,7 +154,7 @@ def register_pytree_node(nodetype: Type[T],
pytree.register_node(nodetype, flatten_func, unflatten_func)
_registry[nodetype] = _RegistryEntry(flatten_func, unflatten_func)

def register_pytree_node_class(cls):
def register_pytree_node_class(cls: U) -> U:
"""Extends the set of types that are considered internal nodes in pytrees.
This function is a thin wrapper around ``register_pytree_node``, and provides
Expand Down Expand Up @@ -204,10 +212,12 @@ def tree_map(f: Callable[..., Any], tree: Any, *rest: Any,
all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))

def build_tree(treedef, xs):
def build_tree(treedef: PyTreeDef, xs: Any) -> Any:
return treedef.from_iterable_tree(xs)

def tree_transpose(outer_treedef, inner_treedef, pytree_to_transpose):
def tree_transpose(outer_treedef: PyTreeDef,
inner_treedef: PyTreeDef,
pytree_to_transpose: Any) -> Any:
"""Transform a tree having tree structure (outer, inner) into one having structure
(inner, outer).
"""
Expand All @@ -217,8 +227,8 @@ def tree_transpose(outer_treedef, inner_treedef, pytree_to_transpose):
if treedef.num_leaves != (inner_size * outer_size):
expected_treedef = outer_treedef.compose(inner_treedef)
raise TypeError(f"Mismatch\n{treedef}\n != \n{expected_treedef}")
flat = iter(flat)
lol = [[next(flat) for _ in range(inner_size)] for __ in range(outer_size)]
iter_flat = iter(flat)
lol = [[next(iter_flat) for _ in range(inner_size)] for __ in range(outer_size)]
transposed_lol = zip(*lol)
subtrees = map(partial(tree_unflatten, outer_treedef), transposed_lol)
return tree_unflatten(inner_treedef, subtrees)
Expand Down Expand Up @@ -273,7 +283,7 @@ def tree_reduce(function: Callable[[T, Any], T],
else:
return functools.reduce(function, tree_leaves(tree), initializer)

def tree_all(tree):
def tree_all(tree: Any) -> bool:
return all(tree_leaves(tree))

register_pytree_node(
Expand Down

0 comments on commit 60e92ee

Please sign in to comment.