Skip to content

Commit

Permalink
Merge pull request #9569 from GJBoth:tree_flatten_docs
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 441878577
  • Loading branch information
jax authors committed Apr 14, 2022
2 parents 0443f5e + add6c82 commit 375777f
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions jax/_src/tree_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,9 @@ def tree_flatten(tree, is_leaf: Optional[Callable[[Any], bool]] = None):
Args:
tree: a pytree to flatten.
is_leaf: an optionally specified function that will be called at each
flattening step. It should return a boolean, which indicates whether
the flattening should traverse the current object, or if it should be
stopped immediately, with the whole subtree being treated as a leaf.
flattening step. It should return a boolean, with true stopping the
traversal and the whole subtree being treated as a leaf, and false
indicating the flattening should traverse the current object.
Returns:
A pair where the first element is a list of leaf values and the second
element is a treedef representing the structure of the flattened tree.
Expand Down

0 comments on commit 375777f

Please sign in to comment.