PyDags proposal: a mechanism to preserve relative identities for Pytrees #7919

Closed
cgarciae opened this issue Sep 15, 2021 · 18 comments
Labels: enhancement (New feature or request)

Comments

@cgarciae
Collaborator

cgarciae commented Sep 15, 2021

Motivation

When jax.tree_unflatten is called, new instances of all pytree objects are created and their original identities are lost. While this is expected for basic container types like lists, dicts, and tuples, for other types such as Pytree Modules it can be inconvenient, as it makes tasks like parameter sharing difficult. Here is an example, which currently doesn't work, of sharing a single Child module between two fields of the same Parent module:

class Child(Module):
    x: jnp.ndarray # assume this is a leaf
    ...

class Parent(Module):
    left: Child # assume these are subtrees
    right: Child 
    ...

child = Child(x=jnp.array(1))
parent = Parent(left=child, right=child)  # <<<< child is shared

@jax.jit
def f(parent):
    assert parent.left is parent.right  # Bad
    return parent

parent2 = f(parent)
assert parent2.left is parent2.right  # Bad
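
The identity loss can also be reproduced without jit, using only flatten and unflatten. The sketch below is illustrative: `Box` is a hypothetical stand-in class registered through `jax.tree_util.register_pytree_node_class`, not the `Module` type from the example above.

```python
import jax


@jax.tree_util.register_pytree_node_class
class Box:
    """Hypothetical stand-in for a pytree Module with a single leaf."""

    def __init__(self, x):
        self.x = x

    def tree_flatten(self):
        # children, aux_data
        return (self.x,), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(children[0])


box = Box(1.0)
pair = (box, box)  # the same object appears in both positions

leaves, treedef = jax.tree_util.tree_flatten(pair)
left, right = jax.tree_util.tree_unflatten(treedef, leaves)

print(len(leaves))    # the leaf is listed once per occurrence of box
print(left is right)  # each occurrence is rebuilt as a separate Box
```

Because the shared object is flattened once per occurrence, nothing in the flattened form records that both positions held the same `Box`.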

Proposal

Enable preserving the relative identities of custom Pytree classes that opt in to this behaviour.

Relative identity means that if two positions in a Pytree hold the same object before flattening, they will hold one shared object after unflattening, although that object will not be the original one. The following assertions therefore hold, assuming Module has opted in to this behavior:

m = Module()

@jax.jit
def f(m1, m2):
    assert m1 is m2 
    return m1, m2

m1, m2 = f(m, m)

assert m1 is m2
assert m is not m1 and m is not m2

Implementation

To achieve this, register_pytree_node could accept an optional preserve_relative_identities: bool flag (or something similar) indicating that objects of this class opt in to preserving their relative identities. When tree_flatten is called, each node's object id could be stored in the PyTreeDef. Then, when tree_unflatten unflattens a node whose class was registered with preserve_relative_identities=True, it checks by id whether it has already unflattened that element and, if so, reuses the existing node.

preserve_relative_identities should also be available for register_pytree_node_class.
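
To make the mechanism concrete, here is a toy sketch in plain Python. It does not touch JAX internals: `register`, `toy_flatten`, `toy_unflatten`, `Entry` and `Node` are all hypothetical names invented for this illustration, and the real PyTreeDef is implemented differently.

```python
from typing import Any, Callable, Dict, NamedTuple


class Entry(NamedTuple):
    flatten: Callable      # obj -> (children, aux)
    unflatten: Callable    # (aux, children) -> obj
    preserve_relative_identities: bool


class Node(NamedTuple):    # toy treedef node; None stands for a leaf
    cls: type
    aux: Any
    obj_id: int
    children: tuple


_registry: Dict[type, Entry] = {}


def register(cls, flatten, unflatten, preserve_relative_identities=False):
    _registry[cls] = Entry(flatten, unflatten, preserve_relative_identities)


def toy_flatten(tree):
    leaves = []

    def go(x):
        entry = _registry.get(type(x))
        if entry is None:
            leaves.append(x)
            return None
        children, aux = entry.flatten(x)
        return Node(type(x), aux, id(x), tuple(go(c) for c in children))

    return leaves, go(tree)


def toy_unflatten(treedef, leaves):
    leaf_iter = iter(leaves)
    memo = {}  # fresh for every call, so separate calls never share objects

    def go(d):
        if d is None:
            return next(leaf_iter)
        entry = _registry[d.cls]
        # Always rebuild children so the leaf iterator stays aligned.
        children = tuple(go(c) for c in d.children)
        if entry.preserve_relative_identities and d.obj_id in memo:
            return memo[d.obj_id]
        obj = entry.unflatten(d.aux, children)
        if entry.preserve_relative_identities:
            memo[d.obj_id] = obj
        return obj

    return go(treedef)


class Module:
    def __init__(self, x):
        self.x = x


register(tuple, lambda t: (t, None), lambda _, cs: tuple(cs))
register(Module, lambda m: ((m.x,), None), lambda _, cs: Module(cs[0]),
         preserve_relative_identities=True)

m = Module(1)
leaves, treedef = toy_flatten((m, m))
m1, m2 = toy_unflatten(treedef, leaves)
n1, n2 = toy_unflatten(treedef, leaves)
```

In this toy, `m1 is m2` and `n1 is n2`, while `m1`, `n1` and `m` are three distinct objects. Note that the leaves are still duplicated, and that storing raw ids in the treedef makes two structurally identical trees compare unequal, which a real implementation would need to address.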

Implications

All current code should run as before; only new code that opts in to this behavior will use this feature.

@cgarciae cgarciae added the enhancement New feature or request label Sep 15, 2021
@cgarciae cgarciae changed the title Mechanism to preserve relative identities for custom Pytrees Proposal: mechanism to preserve relative identities for custom Pytrees Sep 15, 2021
@zhangqiaorjc
Collaborator

@hawkinsp

@apaszke
Collaborator

apaszke commented Sep 15, 2021

I might be missing something in the proposal, but what if you unflatten twice with different elements? You can't reuse the same objects to construct two distinct pytrees.

@cgarciae
Collaborator Author

cgarciae commented Sep 15, 2021

Hey @apaszke! Unflattening many times is not a problem. Implementation-wise, you would probably create a new dictionary on each call to tree_unflatten to keep track of identities, storing the first unflattened instance of each object. For example, the following should hold:

m = Module()
tree = (m, m)

leaves, treedef = jax.tree_flatten(tree)
leaves2 = treedef.flatten_up_to(tree)

m11, m12 = jax.tree_unflatten(treedef, leaves)
m21, m22 = jax.tree_unflatten(treedef, leaves2)

assert m11 is m12
assert m21 is m22
assert m11 is not m21 and m12 is not m22
assert m11 is not m and m12 is not m

@shoyer
Collaborator

shoyer commented Sep 15, 2021

It might be worth looking at how Python's pickle solves this problem. It does essentially the same thing (preserving identity through serialization) by default:

>>> import pickle
>>> x = (1, 2, 3)
>>> y1, y2 = pickle.loads(pickle.dumps((x, x)))
>>> y1 is y2
True

In principle, this might be a reasonable thing for JAX to do with pytrees. I haven't looked into how the implementations differ, though.
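
For reference, pickle keeps track of objects it has already serialized within a single dump, so repeated references come back as one object. The sketch below uses only built-in types and shows that this also holds when the shared object sits in separate branches; the variable names are illustrative.

```python
import pickle

shared = [1, 2, 3]
tree = {"left": {"shared": shared}, "right": {"shared": shared}}

# Two independent round trips of the same structure.
restored = pickle.loads(pickle.dumps(tree))
again = pickle.loads(pickle.dumps(tree))

# Identity is preserved within one round trip...
same_within = restored["left"]["shared"] is restored["right"]["shared"]
# ...but not with respect to the original object...
same_as_original = restored["left"]["shared"] is shared
# ...and not across separate round trips.
same_across_loads = restored["left"]["shared"] is again["left"]["shared"]

print(same_within, same_as_original, same_across_loads)
```

This is the same "relative identity" behavior described in the proposal: sharing survives inside one load, while each load produces fresh objects.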

@mattjj
Collaborator

mattjj commented Sep 16, 2021

A terminological point: preserving object identity would mean violating referential transparency, and in particular these would be more like "pydags" than "pytrees".

I don't think we want to break referential transparency with existing pytree types. That would prevent us from processing them recursively in a functionally pure way (as @cgarciae already mentioned under "Implementation", which refers to basically a side-effecting memoization process). Moreover I wouldn't be surprised if we leverage the referential transparency assumption in lots of different places.

But just for pytree flattening/unflattening alone as in @cgarciae's most recent comment, the API is already general enough to handle DAGs in your own custom pytree types, so long as you are willing to break referential transparency in your own flattening functions. You just need to do the deduplication-by-python-object-id (and equality-up-to-alpha-renaming) yourself:

from typing import Any, NamedTuple
import itertools as it
from collections import defaultdict

import jax
from jax.util import unzip2
from jax.tree_util import register_pytree_node

class MyTuple:
  elts: tuple[Any]
  def __init__(self, *elts):
    self.elts = elts
  def __iter__(self):
    return iter(self.elts)

def flatten_mytuple(x):
  counts = it.count()
  id_to_name = defaultdict(lambda: next(counts))
  name_list = [id_to_name[id(e)] for e in x.elts]
  uniques = {id(e): e for e in x.elts}
  unique_names, unique_vals = unzip2((id_to_name[i], v) for i, v in uniques.items())
  return unique_vals, (unique_names, name_list)

def unflatten_mytuple(aux, unique_vals):
  unique_names, name_list = aux
  uniques = dict(zip(unique_names, unique_vals))
  elts = [uniques[name] for name in name_list]
  return MyTuple(*elts)

register_pytree_node(MyTuple, flatten_mytuple, unflatten_mytuple)

class Module(NamedTuple): pass  # added this as trivial pytree

###

m = Module()
tree = MyTuple(m, m)  # NOTE: changed from Python tuple to MyTuple!

leaves, treedef = jax.tree_flatten(tree)
leaves2 = treedef.flatten_up_to(tree)

m11, m12 = jax.tree_unflatten(treedef, leaves)
m21, m22 = jax.tree_unflatten(treedef, leaves2)

assert m11 is m12
assert m21 is m22
assert m11 is not m21 and m12 is not m22
assert m11 is not m and m12 is not m

I'm not sure of the limitations of this approach. For example, when Tracers are involved, this reliance on Python object identity might lead to surprising results (but maybe it'd be okay to rely on Python object identity just for values which can't be wrapped in Tracers, i.e. in your own custom pytree data types?).

@cgarciae
Collaborator Author

cgarciae commented Sep 16, 2021

Hey @mattjj! Good point about this being more like a Pydag.

I think the main problem with your suggestion is that when you have nested structures like this

class Child:
    shared: Module

class Parent:
    left: Child
    right: Child

shared = Module()
parent = Parent(left=Child(shared), right=Child(shared))

local information is not enough, because shared here sits in separate branches. To solve this you need to keep track of all objects within the pytree during unflatten.

Implementation Proposal 2: Unflatten callbacks

A lightweight alternative closer to @mattjj's idea could be for JAX to provide e.g. a jax.tree_util.register_unflatten_callabacks function that lets frameworks register 2 callbacks, to be called at the beginning and end of PyTreeDef.unflatten:

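# tree_util.py (proposed signature; this function does not exist in JAX)
from typing import Any, Callable

def register_unflatten_callabacks(
    tree_unflatten_start_fn: Callable[[], Any],
    tree_unflatten_end_fn: Callable[[Any], None],
) -> None:
    ...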

If PyTreeDef.unflatten calls other PyTreeDef.unflatten methods recursively, only the first (top-level) call should trigger the callbacks.

Module implementation

Based on this, a Module library could define a global _IDENTITIES dictionary to keep track of object identities during unflattening. Together, tree_unflatten_start_fn and tree_unflatten_end_fn would behave like a context manager, giving _IDENTITIES a fresh value every time tree_unflatten is called and restoring its original value at the end:

import functools
import typing as tp

import jax
import jax.tree_util

# in reality use something based on threading.local
_IDENTITIES: tp.Optional[tp.Dict[int, tp.Any]] = None  # keys are id() values

def tree_unflatten_start_fn() -> tp.Any:
    global _IDENTITIES
    old_identities = _IDENTITIES
    _IDENTITIES = {}
    return old_identities

def tree_unflatten_end_fn(old_identities: tp.Any) -> None:
    global _IDENTITIES
    _IDENTITIES = old_identities

# register callbacks (proposed API, not part of JAX)
jax.tree_util.register_unflatten_callabacks(tree_unflatten_start_fn, tree_unflatten_end_fn)

Now we can define Module such that it leverages _IDENTITIES to keep track of objects with the same identity that were already unflattened:

class Module:
    def tree_flatten(self):
        tree = vars(self)
        obj_id = id(self)
        return (tree,), (obj_id,)

    @classmethod
    def tree_unflatten(cls, aux, children):
        tree = children[0]
        obj_id = aux[0]

        if _IDENTITIES is not None and obj_id in _IDENTITIES:
            return _IDENTITIES[obj_id]

        obj = cls.__new__(cls)
        obj.__dict__.update(tree)

        if _IDENTITIES is not None:
            _IDENTITIES[obj_id] = obj

        return obj

    def __init_subclass__(cls) -> None:
        jax.tree_util.register_pytree_node_class(cls)
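
The comment in the earlier snippet suggests basing the state on threading.local. One way that could look, as a sketch only: the names `identity_scope`, `current_memo` and `Box` are hypothetical, and the scope is entered by hand here since the proposed callbacks do not exist in JAX.

```python
import contextlib
import threading

import jax

_local = threading.local()  # per-thread state instead of a module global


def current_memo():
    return getattr(_local, "memo", None)


@contextlib.contextmanager
def identity_scope():
    """Activate a fresh identity memo; nested scopes reuse the outer one."""
    if current_memo() is not None:
        yield
        return
    _local.memo = {}
    try:
        yield
    finally:
        _local.memo = None


@jax.tree_util.register_pytree_node_class
class Box:
    def __init__(self, value):
        self.value = value

    def tree_flatten(self):
        return (self.value,), id(self)  # aux_data carries the object id

    @classmethod
    def tree_unflatten(cls, obj_id, children):
        memo = current_memo()
        if memo is not None and obj_id in memo:
            return memo[obj_id]
        obj = cls(children[0])
        if memo is not None:
            memo[obj_id] = obj
        return obj


box = Box(1.0)
leaves, treedef = jax.tree_util.tree_flatten((box, box))

with identity_scope():
    a, b = jax.tree_util.tree_unflatten(treedef, leaves)
same_inside = a is b

c, d = jax.tree_util.tree_unflatten(treedef, leaves)
same_outside = c is d
```

Inside the scope both positions come back as one object; outside it the class falls back to ordinary pytree behavior. Letting nested scopes reuse the outer memo mirrors the requirement that only the top-level unflatten triggers the callbacks.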

Example

Based on these definitions we can recreate the original example. All of the following assertions should hold.

class Shared(Module):
    pass

class Child(Module):
    def __init__(self, shared: Shared) -> None:
        self.shared = shared

class Parent(Module):
    def __init__(self, left: Child, right: Child) -> None:
        self.left = left
        self.right = right

shared = Shared()
parent = Parent(Child(shared), Child(shared))

leaves, treedef = jax.tree_flatten(parent)
leaves2 = treedef.flatten_up_to(parent)

parent1: Parent = jax.tree_unflatten(treedef, leaves)
parent2: Parent = jax.tree_unflatten(treedef, leaves2)

# checks
assert parent1 is not parent
assert parent2 is not parent1

assert parent1.left.shared is parent1.right.shared
assert parent1.left.shared is not parent.left.shared
assert parent1.right.shared is not parent.right.shared
assert parent1.left is not parent1.right

assert parent2.left.shared is parent2.right.shared
assert parent2.left.shared is not parent1.left.shared
assert parent2.right.shared is not parent1.right.shared
assert parent2.left is not parent2.right

Prototype

I've created a prototype putting all of the above together, where I just create my own tree_unflatten function that follows the behavior described above. Here is a full version of the code:

import functools
import typing as tp

import jax
import jax.tree_util

_IDENTITIES: tp.Optional[tp.Dict[int, tp.Any]] = None  # keys are id() values

# these 2 functions behave like a context manager around `jax.tree_unflatten`
def tree_unflatten_start() -> tp.Any:
    global _IDENTITIES
    old_identities = _IDENTITIES
    _IDENTITIES = {}
    return old_identities


def tree_unflatten_end(old_identities: tp.Any) -> None:
    global _IDENTITIES
    _IDENTITIES = old_identities


# in reality you would just call: jax.tree_util.register_unflatten_callabacks(tree_unflatten_start, tree_unflatten_end)
def tree_unflatten(*args, **kwargs):
    old_identities = tree_unflatten_start()
    try:
        return jax.tree_unflatten(*args, **kwargs)
    finally:
        tree_unflatten_end(old_identities)

class Module:
    def tree_flatten(self):
        tree = vars(self)
        obj_id = id(self)

        return (tree,), (obj_id,)

    @classmethod
    def tree_unflatten(cls, aux, children):

        tree = children[0]
        obj_id = aux[0]

        if _IDENTITIES is not None and obj_id in _IDENTITIES:
            return _IDENTITIES[obj_id]

        obj = cls.__new__(cls)
        obj.__dict__.update(tree)

        if _IDENTITIES is not None:
            _IDENTITIES[obj_id] = obj

        return obj

    def __init_subclass__(cls) -> None:
        jax.tree_util.register_pytree_node_class(cls)


class Shared(Module):
    pass

class Child(Module):
    def __init__(self, shared: Shared) -> None:
        self.shared = shared

class Parent(Module):
    def __init__(self, left: Child, right: Child) -> None:
        self.left = left
        self.right = right

shared = Shared()
parent = Parent(Child(shared), Child(shared))


leaves, treedef = jax.tree_flatten(parent)
leaves2 = treedef.flatten_up_to(parent)

parent1: Parent = tree_unflatten(treedef, leaves)
parent2: Parent = tree_unflatten(treedef, leaves2)


assert parent1 is not parent
assert parent2 is not parent1

assert parent1.left.shared is parent1.right.shared
assert parent1.left.shared is not parent.left.shared
assert parent1.right.shared is not parent.right.shared
assert parent1.left is not parent1.right

assert parent2.left.shared is parent2.right.shared
assert parent2.left.shared is not parent1.left.shared
assert parent2.right.shared is not parent1.right.shared
assert parent2.left is not parent2.right

@mattjj
Collaborator

mattjj commented Sep 17, 2021

I think the main problem with your suggestion is that when you have nested structures like this [...] local information is not enough because shared here is in separate branches, to solve this you need to keep track of all objects within the pytree during unflatten.

You're right, I meant to mention that but neglected to: in general the flattening function would be responsible for flattening its whole subtree, not just one node as usual, by calling into your own set of stateful flatteners! That is, the pytree flattening function for MyTuple would recursively call into stateful flatteners for its children, basically the kind of generalized flattener you outlined (or alternatively these could just thread through a reference to a mutable object, like a dict, which would be a thread-safe alternative to global state). That way, you could deduplicate within any subtree under one of your pytree classes.

My example code did not do recursive DAG flattening. Here's a version that does! First, a general PyDag system (based on the Python pytree implementation in Autodidax):

from functools import partial
import itertools as it
from collections import defaultdict
from typing import (Callable, Type, Hashable, Dict, Any, NamedTuple, Tuple,
                    Sequence, List, Union)
from jax.util import unzip2

Name = int
Names = Sequence[int]
AuxData = Any

class NodeType(NamedTuple):
  name: str
  to_iterable: Callable[[Callable, Any], Tuple[AuxData, Names]]
  from_iterable: Callable[[Callable, AuxData, Names], Any]
  def __repr__(self): return f'DagNode[{self.name}]'

def register_pydag_node(ty: Type, to_iter: Callable, from_iter: Callable)-> None:
  node_types[ty] = NodeType(str(ty), to_iter, from_iter)

node_types: Dict[Type, NodeType] = {}
register_pydag_node(tuple,
                    lambda f, t: (None, [f(e) for e in t]),
                    lambda u, _, names: tuple(u(n) for n in names))
register_pydag_node(list,
                    lambda f, l: (None, [f(e) for e in l]),
                    lambda u, _, names: [u(n) for n in names])
register_pydag_node(dict,
                    lambda f, d: (tuple(sorted(d)), [f(d[k]) for k in sorted(d)]),
                    lambda u, keys, names: {k: u(n) for k, n in zip(keys, names)})

class PyDagNode(NamedTuple):
  node_type: NodeType
  node_auxdata: Hashable
  names: Names

class PyDagLeaf(NamedTuple):
  name: Name

PyDagDef = Union[PyDagNode, PyDagLeaf]

class FlattenState(NamedTuple):
  id_to_name: Dict[int, Name]
  name_to_obj: Dict[Name, Any]

def dag_flatten(x: Any) -> Tuple[List[Any], Tuple[List[Name], PyDagDef]]:
  names = it.count()
  state = FlattenState(defaultdict(lambda: next(names)), dict())
  dagdef = _dag_flatten(state, x)
  unique_names, unique_vals = unzip2(state.name_to_obj.items())
  return unique_vals, (unique_names, dagdef)

def _dag_flatten(state: FlattenState, x: Any) -> PyDagDef:
  node_type = node_types.get(type(x))
  if node_type:
    node_auxdata, names = node_type.to_iterable(partial(_dag_flatten, state), x)
    return PyDagNode(node_type, node_auxdata, names)
  else:
    name = state.id_to_name[id(x)]
    state.name_to_obj[name] = x
    return PyDagLeaf(name)

class UnflattenState(NamedTuple):
  name_to_obj: Dict[Name, Any]

def dag_unflatten(dagspec: Tuple[Names, PyDagDef], xs: List[Any]) -> Any:
  names, dagdef = dagspec
  state = UnflattenState(dict(zip(names, xs)))
  return _dag_unflatten(state, dagdef)

def _dag_unflatten(state: UnflattenState, dagdef: PyDagDef) -> Any:
  if type(dagdef) is PyDagLeaf:
    return state.name_to_obj[dagdef.name]
  else:
    u = partial(_dag_unflatten, state)
    return dagdef.node_type.from_iterable(u, dagdef.node_auxdata, dagdef.names)

This is probably "hella buggy", as we'd say where I come from, and I reserve the right to edit this github comment to fix embarrassing mistakes. But it passed literally one example I tried it on, so ship it!

Now, here's a MyTuple pytree which calls into that dag flattening (i.e. interfaces pydags with the existing pytree system):

class MyTuple:
  elts: tuple[Any]
  def __init__(self, *elts):
    self.elts = elts
  def __iter__(self):
    return iter(self.elts)

# register with our pydag system
register_pydag_node(MyTuple,
                    lambda f, t: (None, [f(e) for e in t]),
                    lambda u, _, names: MyTuple(*[u(n) for n in names]))

# register as a pytree with jax, but tell it to flatten like a dag
from jax.tree_util import register_pytree_node
register_pytree_node(MyTuple, dag_flatten, dag_unflatten)


###

import jax

class Module(NamedTuple): pass  # added this as trivial pytree

# Test 1

m = Module()
tree = MyTuple(m, m)

leaves, treedef = jax.tree_flatten(tree)
leaves2 = treedef.flatten_up_to(tree)

m11, m12 = jax.tree_unflatten(treedef, leaves)
m21, m22 = jax.tree_unflatten(treedef, leaves2)

assert m11 is m12
assert m21 is m22
assert m11 is not m21 and m12 is not m22
assert m11 is not m and m12 is not m

# Test 2

m = Module()
tree = MyTuple(MyTuple(m, m), MyTuple(m, m))
leaves, treedef = jax.tree_flatten(tree)
((m11, m12), (m21, m22)) = jax.tree_unflatten(treedef, leaves)
assert m11 is m12 is m21 is m22

My main point is just that I think with the existing pytree system you can at least flatten subtrees of your custom pytree types however you'd like, including as dags-by-objectid. Maybe that can unblock you!

Of course, we could also consider building some pydag behavior into JAX, which would let us deduplicate across all argument lists (even if the top-level container is not a custom pytree type you control). It's worth considering! Like I said before, I'm a bit wary of where we might leverage the referential transparency assumption. But maybe it'd all work out... experimenting with the above pydag approach might help us learn things!

WDYT? Does this approach unblock you, without needing JAX-internal changes?

@mattjj mattjj self-assigned this Sep 17, 2021
@cgarciae
Collaborator Author

cgarciae commented Sep 17, 2021

@mattjj I got a solution working for a Module but it's messy, so I started playing with your code. I modified Module to be a Pydag but now it's failing:

class Module:
    def __init__(self, x):
        self.x = x

register_pydag_node(
    Module,
    lambda f, t: (None, [f(t.x)]),
    lambda u, _, names: Module(u(names[0])),
)

# Test 1

m = Module(0)
tree = MyTuple(m, m)

leaves, treedef = jax.tree_flatten(tree)
leaves2 = treedef.flatten_up_to(tree)

assert len(leaves) == 1

m11, m12 = jax.tree_unflatten(treedef, leaves)
m21, m22 = jax.tree_unflatten(treedef, leaves2)

assert m11 is m12 # Fails here

Maybe I implemented pydag for Module wrong. I haven't been able to pinpoint the part of _dag_unflatten where nodes are reused if they have already been unflattened, if that is indeed happening; I have a hunch that it's missing something.
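
One reading of the code above, offered as a guess rather than a confirmed diagnosis: _dag_unflatten looks objects up by name only for leaves and calls from_iterable again for every node occurrence, so a type registered as a node (like Module here) is rebuilt once per occurrence. The self-contained sketch below gives nodes a name too and memoizes them during unflattening. It is a simplified stand-in, not a patch to the PyDag code; `Leaf`, `Node` and the two-function registry are hypothetical.

```python
import itertools
from typing import NamedTuple


class Leaf(NamedTuple):
    name: int


class Node(NamedTuple):
    name: int        # one name per distinct node object
    ty: type
    children: tuple  # tuple of Leaf / Node


# type -> (to_children, from_children)
node_types = {}


def dag_flatten(x):
    counter = itertools.count()
    id_to_name = {}
    values = {}  # leaf name -> leaf value

    def name_of(obj):
        if id(obj) not in id_to_name:
            id_to_name[id(obj)] = next(counter)
        return id_to_name[id(obj)]

    def go(obj):
        name = name_of(obj)
        if type(obj) in node_types:
            to_children, _ = node_types[type(obj)]
            return Node(name, type(obj), tuple(go(c) for c in to_children(obj)))
        values[name] = obj
        return Leaf(name)

    dagdef = go(x)
    return values, dagdef


def dag_unflatten(dagdef, values):
    built = {}  # node name -> already rebuilt node

    def go(d):
        if isinstance(d, Leaf):
            return values[d.name]
        if d.name not in built:
            _, from_children = node_types[d.ty]
            built[d.name] = from_children(tuple(go(c) for c in d.children))
        return built[d.name]

    return go(dagdef)


class Module:
    def __init__(self, x):
        self.x = x


node_types[tuple] = (lambda t: t, lambda cs: tuple(cs))
node_types[Module] = (lambda m: (m.x,), lambda cs: Module(cs[0]))

m = Module(0)
values, dagdef = dag_flatten((m, m))
m1, m2 = dag_unflatten(dagdef, values)
```

With node memoization, `m1 is m2` holds even though Module is a node rather than a leaf, and `m1` is still a new object distinct from `m`.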

A downside I see with having Pydags in a separate world from Pytrees is that custom Pytrees are just DagLeafs; it would be nice if anything that is not a Pydag were treated as a Pytree. This would probably require making state a global variable, as the chain is broken. I will try to make sense of what is happening to see if I can get it to work. I might need to read Autodidax 😅.

@cgarciae
Collaborator Author

Back to my original proposal: what if register_pytree_node had an is_pydag: bool = False argument that would let users decide if that node was dag-like (relative identities preserved, values de-duplicated) or tree-like (relative identities not preserved, values duplicated)?

@cgarciae
Collaborator Author

Update: My solution wasn't taking the is_leaf functionality into account. Given the restrictions imposed by is_leaf, I believe there is no way to implement this without JAX supporting this use case.

@mattjj
Collaborator

mattjj commented Sep 21, 2021

Is there an easy way to illustrate the is_leaf issue in a toy example? (I believe you that it's not working, I just don't yet grok what you mean.)

@cgarciae
Collaborator Author

cgarciae commented Mar 16, 2023

In the meantime, I found a relatively simple (yet highly inefficient) way to get some sort of identity preservation using just pytrees:

import dataclasses
from typing import Any, Dict, Tuple, TypeVar
import jax
import numpy as np

A = TypeVar("A")
#-------------------------------------------------------------------------------
class Nothing:
    pass
def _flatten_nothing(_: Nothing) -> Tuple[Tuple[()], None]:
    return (), None
def _unflatten_nothing(_: None, __: Tuple[()]) -> Nothing:
    return Nothing()
jax.tree_util.register_pytree_node(Nothing, _flatten_nothing, _unflatten_nothing)
#-------------------------------------------------------------------------------
@dataclasses.dataclass(frozen=True)
class ValueIndex:
    value: Any
    index: int
def _flatten_value_index(value_index: ValueIndex) -> Tuple[Tuple[(Any,)], int]:
    return (value_index.value,), value_index.index
def _unflatten_value_index(index: int, children: Tuple[Any]) -> ValueIndex:
    return ValueIndex(children[0], index)
jax.tree_util.register_pytree_node(
    ValueIndex, _flatten_value_index, _unflatten_value_index
)
#-------------------------------------------------------------------------------
def deref(pytree: A, is_leaf=None) -> A:
    id_to_index: Dict[int, int] = {}
    def deref_fn(leaf: Any) -> Any:
        leaf_id = id(leaf)
        if leaf_id not in id_to_index: 
            id_to_index[leaf_id] = len(id_to_index)
            return ValueIndex(leaf, index=id_to_index[leaf_id])
        else:
            return ValueIndex(Nothing(), index=id_to_index[leaf_id])
    return jax.tree_map(deref_fn, pytree, is_leaf=is_leaf)

def reref(pytree: A) -> A:
    index_to_value: Dict[int, Any] = {}
    def reref_fn(value_index: ValueIndex) -> Any:
        if value_index.index not in index_to_value:
            index_to_value[value_index.index] = value_index.value
        value = index_to_value[value_index.index]
        return value
    return jax.tree_map(reref_fn, pytree, is_leaf=lambda x: isinstance(x, ValueIndex))
#-------------------------------------------------------------------------------
def dag_map(f, pytree, is_leaf=None):
    pytree = deref(pytree, is_leaf=is_leaf)
    pytree = jax.tree_map(f, pytree)
    return reref(pytree)
#-------------------------------------------------------------------------------
a = np.array(0)
b = np.array(0)

pytree = {"x": [a, a, b], "y": b}

def add_noise(x):
    return x + np.random.normal()

print("jax.tree_map")
print(jax.tree_map(add_noise, pytree))
print("dag_map")
print(dag_map(add_noise, pytree))

Output:

jax.tree_map
{'x': [0.3064, -0.4374, 0.2485], 'y': -0.8665}
dag_map
{'x': [-0.8106, -0.8106, 0.6139], 'y': 0.6139}

It doesn't preserve the identity of intermediate pytrees though.

However, you can use deref and reref to pass through jit and other JAX transformations while preserving leaf identities:

@jax.jit
def f(pytree):
    pytree = reref(pytree)
    assert pytree["x"][0] is pytree["x"][1]
    assert pytree["x"][2] is pytree["y"]
    return deref(pytree)

pytree = reref(f(deref(pytree)))
assert pytree["x"][0] is pytree["x"][1]
assert pytree["x"][2] is pytree["y"]

@cgarciae cgarciae changed the title Proposal: mechanism to preserve relative identities for custom Pytrees PyDags proposal: a mechanism to preserve relative identities for Pytrees Mar 16, 2023
@samskiter

+1 for this. I hit this issue immediately when moving from MyGrad to JAX. JAX's autograd wasn't as performant as MyGrad's, so I attempted to JIT. My code is fairly 'standard' object-oriented code with plenty of refs scattered throughout (I'm actually building a graph of interrelated objects).

I'm doing physics simulation and using grad to solve nonlinear equations - would love to see something like this to make code more 'natural'...

Thanks for the proposal - IIUC I could use this to solve the above problem?

@samskiter

samskiter commented Sep 6, 2023

@cgarciae Here's my attempt at an inheritable class that allows preserving object identities....

It seems to be working for many cases, but I still haven't worked out what the correct semantics should be for UUIDs (should they be stored in aux_data or children?).

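from __future__ import annotations

from abc import ABCMeta
from typing import Any, Dict, Iterable, cast, final
from uuid import UUID, uuid4

import jax
from jax.tree_util import register_pytree_node, register_pytree_node_class


class ReffableTreeNode(metaclass=ABCMeta):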
    children_to_flatten_counter: dict[UUID, int] = {}

    @classmethod
    def add_child_to_flatten(cls, uuid: UUID) -> None:
        if uuid not in cls.children_to_flatten_counter:
            cls.children_to_flatten_counter[uuid] = 0
        cls.children_to_flatten_counter[uuid] += 1

    @classmethod
    def remove_child_to_flatten(cls, uuid: UUID) -> None:
        cls.children_to_flatten_counter[uuid] -= 1
        if cls.children_to_flatten_counter[uuid] == 0:
            del cls.children_to_flatten_counter[uuid]

    @classmethod
    def clear_children_to_flatten(cls) -> None:
        cls.children_to_flatten_counter = {}

    def is_top_level_flatten(self) -> bool:
        return self.get_uuid() not in ReffableTreeNode.children_to_flatten_counter

    def __init__(self) -> None:
        self._uuid = uuid4()
        self._did_set_refs = False

    def get_uuid(self) -> UUID:
        return self._uuid

    @final
    def tree_flatten(self) -> tuple[tuple[Any, ...], dict[str, Any]]:
        is_top_level_flatten = self.is_top_level_flatten()
        if not is_top_level_flatten:
            ReffableTreeNode.remove_child_to_flatten(self.get_uuid())

        self._assertion = False
        children, aux_data = self.tree_flatten_adia()
        assert self._assertion
        assert "uuid" in aux_data

        if is_top_level_flatten:
            all_refs = self.get_refs_iterative()
            for ref in all_refs.keys():
                ReffableTreeNode.add_child_to_flatten(ref)
            children["all_refs"] = all_refs
        return (children,), aux_data

    @final
    @classmethod
    def tree_unflatten(cls, aux_data: dict[str, Any], children: tuple[Any, ...]) -> Any:
        dict_children = children[0]
        instance = object.__new__(cls)
        instance._assertion2 = False
        instance.tree_unflatten_adia(aux_data, dict_children)
        if "all_refs" in dict_children:
            instance.set_refs_tracked(dict_children["all_refs"])
        assert instance._assertion2
        return instance

    def tree_flatten_adia(self) -> tuple[dict[str, Any], dict[str, Any]]:
        self._assertion = True
        return {}, {"uuid": self._uuid}

    def tree_unflatten_adia(
        self, aux_data: dict[str, Any], children: dict[str, Any]
    ) -> None:
        del children
        self._assertion2 = True
        self._uuid = aux_data["uuid"]

    @final
    def get_refs_iterative(self) -> dict[UUID, ReffableTreeNode]:
        refs = self.get_refs()
        full_refs: dict[UUID, ReffableTreeNode] = {}
        refs_visited: set[UUID] = set()
        refs_to_visit = list(refs.values())
        while refs_to_visit:
            ref = refs_to_visit.pop()
            if ref.get_uuid() not in refs_visited:
                full_refs[ref.get_uuid()] = ref
                full_refs |= ref.get_refs()
                refs_to_visit.extend(ref.get_refs().values())
                refs_visited.add(ref.get_uuid())
        return full_refs

    def get_refs(self) -> dict[UUID, ReffableTreeNode]:
        return {}

    @final
    def set_refs_tracked(self, refs: dict[UUID, ReffableTreeNode]) -> None:
        for ref in refs.values():
            ref._assertion = False
            ref.set_refs(refs)
            assert ref._assertion

        self._assertion = False
        self.set_refs(refs)
        assert self._assertion

    def set_refs(self, refs: dict[UUID, ReffableTreeNode]) -> None:
        self._assertion = True


def uuid_flatten(uuid: UUID) -> tuple[tuple[Any, ...], Dict[str, Any]]:
    children: tuple[Any, ...] = tuple()
    aux_data = {"uuid": uuid.hex}
    return children, aux_data


def uuid_unflatten(aux_data: Dict[str, Any], children: Iterable[Any]) -> UUID:
    del children
    uuid = UUID(hex=aux_data["uuid"])
    return uuid


register_pytree_node(UUID, uuid_flatten, uuid_unflatten)

Toy usage:

@register_pytree_node_class
class GradientProvider(ReffableTreeNode):
    def __init__(self, value: float) -> None:
        super().__init__()
        self.value = value

    def tree_flatten_adia(self) -> tuple[dict[str, Any], dict[str, Any]]:
        children, aux_data = super().tree_flatten_adia()
        children["value"] = self.value
        return children, aux_data

    def tree_unflatten_adia(
        self, aux_data: dict[str, Any], children: dict[str, Any]
    ) -> None:
        super().tree_unflatten_adia(aux_data, children)
        self.value = children["value"]

    def __str__(self) -> str:
        return f"GradientProvider(value={self.value})"

    def __repr__(self) -> str:
        return self.__str__()


@register_pytree_node_class 
class Equation(ReffableTreeNode):
    def __init__(self, constant: float, gradient: GradientProvider, x: float) -> None:
        super().__init__()
        self.constant = constant
        self.gradient = gradient
        self.x = x

    def tree_flatten_adia(self) -> tuple[dict[str, Any], dict[str, Any]]:
        children, aux_data = super().tree_flatten_adia()
        children["constant"] = self.constant
        children["gradient"] = self.gradient.get_uuid()
        children["x"] = self.x
        return children, aux_data

    def tree_unflatten_adia(
        self, aux_data: dict[str, Any], children: dict[str, Any]
    ) -> None:
        super().tree_unflatten_adia(aux_data, children)
        self.constant = children["constant"]
        self.gradient = children["gradient"]
        self.x = children["x"]

    def get_refs(self) -> dict[UUID, ReffableTreeNode]:
        return super().get_refs() | {self.gradient.get_uuid(): self.gradient}

    def set_refs(self, refs: dict[UUID, ReffableTreeNode]) -> None:
        super().set_refs(refs)
        self.gradient = cast(GradientProvider, refs[cast(UUID, self.gradient)])

    @staticmethod
    def error_pure(x: float, gradient: GradientProvider, constant: float) -> float:
        # Squared residual of the linear equation m * x = c.
        return (constant - gradient.value * x) ** 2

    def error(self) -> float:
        return Equation.error_pure(self.x, self.gradient, self.constant)

    def get_error(self) -> float:
        # Both branches call the same bound method, so the result does not
        # depend on the predicate.
        return jax.lax.cond(
            True,
            self.error,
            self.error,
        )

    def __str__(self) -> str:
        return f"Equation(constant={self.constant}, gradient={self.gradient.value}, x={self.x})"

    def __repr__(self) -> str:
        return self.__str__()

@samskiter

An improvement on the above: UUIDs are now only locally consistent within a given flatten. This means similarly 'shaped' trees of objects get the same TreeDef, so jitting works properly, e.g. when using vmap.

from __future__ import annotations
from typing import Any, final, NewType, ClassVar
from abc import ABCMeta


RefIdentifier = NewType("RefIdentifier", int)


class ReffableTreeNode(metaclass=ABCMeta):
    _uuid_counter: ClassVar[RefIdentifier] = RefIdentifier(1)
    _current_flatten_id: ClassVar[RefIdentifier] = RefIdentifier(0)
    # Shared across all instances: how many pending flattens reference each UUID.
    children_to_flatten_counter: ClassVar[dict[RefIdentifier, int]] = {}

    @classmethod
    def add_child_to_flatten(cls, uuid: RefIdentifier) -> None:
        if uuid not in cls.children_to_flatten_counter:
            cls.children_to_flatten_counter[uuid] = 0
        cls.children_to_flatten_counter[uuid] += 1

    @classmethod
    def remove_child_to_flatten(cls, uuid: RefIdentifier) -> None:
        cls.children_to_flatten_counter[uuid] -= 1
        if cls.children_to_flatten_counter[uuid] == 0:
            del cls.children_to_flatten_counter[uuid]

    @classmethod
    def clear_children_to_flatten(cls) -> None:
        cls.children_to_flatten_counter = {}

    def is_top_level_flatten(self) -> bool:
        return self.get_uuid() not in ReffableTreeNode.children_to_flatten_counter

    def get_uuid(self) -> RefIdentifier:
        if (
            not hasattr(self, "_uuid_flatten_id")
        ) or self._uuid_flatten_id < ReffableTreeNode._current_flatten_id:
            self._uuid = ReffableTreeNode._uuid_counter
            ReffableTreeNode._uuid_counter = RefIdentifier(
                ReffableTreeNode._uuid_counter + 1
            )
            self._uuid_flatten_id: RefIdentifier = ReffableTreeNode._current_flatten_id
        return self._uuid

    @final
    def tree_flatten(self) -> tuple[tuple[Any, ...], dict[str, Any]]:
        is_top_level_flatten = self.is_top_level_flatten()
        if not is_top_level_flatten:
            ReffableTreeNode.remove_child_to_flatten(self.get_uuid())
        else:
            ReffableTreeNode.clear_children_to_flatten()
            ReffableTreeNode._current_flatten_id = RefIdentifier(
                ReffableTreeNode._current_flatten_id + 1
            )
            ReffableTreeNode._uuid_counter = RefIdentifier(1)

        self._did_call_super_flatten_assertion = False
        children, aux_data = self.tree_flatten_adia()
        assert self._did_call_super_flatten_assertion
        assert "uuid" in aux_data

        if is_top_level_flatten:
            all_refs = self.get_refs_iterative()
            for ref in sorted(all_refs.keys()):
                ReffableTreeNode.add_child_to_flatten(ref)
            all_refs = dict(sorted(all_refs.items()))
            children["all_refs"] = all_refs
            aux_data["uuid"] = 0
        return (children,), aux_data

    @final
    @classmethod
    def tree_unflatten(cls, aux_data: dict[str, Any], children: tuple[Any, ...]) -> Any:
        dict_children = children[0]
        instance = object.__new__(cls)
        instance._did_call_super_unflatten_assertion = False
        instance.tree_unflatten_adia(aux_data, dict_children)
        if "all_refs" in dict_children:
            instance.set_refs_tracked(dict_children["all_refs"])
        assert instance._did_call_super_unflatten_assertion
        return instance

    def tree_flatten_adia(self) -> tuple[dict[str, Any], dict[str, Any]]:
        self._did_call_super_flatten_assertion = True
        return {}, {"uuid": self._uuid}

    def tree_unflatten_adia(
        self, aux_data: dict[str, Any], children: dict[str, Any]
    ) -> None:
        del children
        self._did_call_super_unflatten_assertion = True
        self._uuid = aux_data["uuid"]

    @final
    def get_refs_iterative(self) -> dict[RefIdentifier, ReffableTreeNode]:
        refs = self.get_refs()
        full_refs: dict[RefIdentifier, ReffableTreeNode] = {}
        refs_visited: set[RefIdentifier] = set()
        refs_to_visit = list(refs.values())
        while refs_to_visit:
            ref = refs_to_visit.pop()
            if ref.get_uuid() not in refs_visited:
                full_refs[ref.get_uuid()] = ref
                full_refs |= ref.get_refs()
                refs_to_visit.extend(ref.get_refs().values())
                refs_visited.add(ref.get_uuid())
        return full_refs

    def get_refs(self) -> dict[RefIdentifier, ReffableTreeNode]:
        return {}

    _did_call_super_set_refs_assertion: bool = False

    @final
    def set_refs_tracked(self, refs: dict[RefIdentifier, ReffableTreeNode]) -> None:
        for ref in refs.values():
            ref._did_call_super_set_refs_assertion = False
            ref.set_refs(refs)
            assert ref._did_call_super_set_refs_assertion

        self._did_call_super_set_refs_assertion = False
        self.set_refs(refs)
        assert self._did_call_super_set_refs_assertion

    def set_refs(self, refs: dict[RefIdentifier, ReffableTreeNode]) -> None:
        del refs
        self._did_call_super_set_refs_assertion = True
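
To illustrate in isolation why flatten-local identifiers give similarly shaped object graphs the same structure, here is a minimal pure-Python sketch. It is not the code above: `Node` and `local_shape` are hypothetical names, and the numbering rule (first-visit order starting from the root) is an assumption chosen for the example.

```python
from __future__ import annotations


class Node:
    """Hypothetical graph node that may reference other nodes more than once."""

    def __init__(self, name: str, refs: list[Node] | None = None) -> None:
        self.name = name
        self.refs = refs if refs is not None else []


def local_shape(root: Node) -> tuple[tuple[int, ...], ...]:
    """Describe the graph using ids that are only valid for this traversal.

    Nodes are numbered in first-visit order (root is 0), so the result
    depends on how nodes are shared, not on which objects they are.
    """
    ids: dict[int, int] = {}  # id(obj) -> local id
    order: list[Node] = []
    stack = [root]
    while stack:
        node = stack.pop()
        if id(node) in ids:
            continue  # already numbered: this is a shared reference
        ids[id(node)] = len(ids)
        order.append(node)
        stack.extend(reversed(node.refs))
    # For each node, in visit order, list the local ids it points to.
    return tuple(tuple(ids[id(r)] for r in n.refs) for n in order)


shared_a = Node("gradient")
shared_b = Node("gradient")
tree_a = Node("parent", [shared_a, shared_a])  # one child, referenced twice
tree_b = Node("parent", [shared_b, shared_b])  # different objects, same sharing
tree_c = Node("parent", [Node("left"), Node("right")])  # no sharing

shape_a = local_shape(tree_a)
shape_b = local_shape(tree_b)
shape_c = local_shape(tree_c)
```

Here `shape_a` and `shape_b` compare equal even though they are built from different objects, while `shape_c` differs because its two children are distinct. With globally unique identifiers, `tree_a` and `tree_b` would never compare equal.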

@samskiter

More on the above here:

#17341 (comment)

I eventually worked out that the correct semantics for UUIDs is to store them in aux data: we don't want them to be traced, we want them to help define the shape of the tree.
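
A minimal sketch of that distinction, assuming a hypothetical `Tagged` node that is not part of the code above: the array goes in `children` and is traced, while the integer `ref_id` goes in aux data and becomes part of the `PyTreeDef`. The trace counter is only there to make retracing visible.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class, tree_structure


@register_pytree_node_class
class Tagged:
    """Hypothetical node: `value` is a traced child, `ref_id` is static aux data."""

    def __init__(self, value, ref_id):
        self.value = value
        self.ref_id = ref_id

    def tree_flatten(self):
        # children (traced), aux_data (static, must be hashable and comparable)
        return (self.value,), self.ref_id

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(children[0], aux_data)


trace_count = 0


@jax.jit
def double(node):
    global trace_count
    trace_count += 1  # Python side effect: runs only while tracing
    assert isinstance(node.ref_id, int)  # aux data stays a plain Python int
    return Tagged(node.value * 2, node.ref_id)


a = Tagged(jnp.ones(3), ref_id=1)
b = Tagged(jnp.zeros(3), ref_id=1)  # different values, same aux data
c = Tagged(jnp.zeros(3), ref_id=2)  # different aux data

same_def = tree_structure(a) == tree_structure(b)
diff_def = not (tree_structure(a) == tree_structure(c))

double(a)
double(b)
traces_same_id = trace_count  # b reuses the trace made for a
double(c)
traces_new_id = trace_count  # c has a different PyTreeDef, so it retraces
```

This is also why identifiers in aux data need to be stable across similarly shaped trees: each distinct aux value produces a distinct `PyTreeDef` and therefore a fresh trace.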

@mattjj
Collaborator

mattjj commented Jul 24, 2024

This stuff grew into Flax NNX. I think we can close the issue thread here!

@mattjj mattjj closed this as completed Jul 24, 2024
@samskiter

@mattjj could you perhaps outline how NNX compares to what I used above? It would be good to know whether I can stop maintaining that.
