PyDags proposal: a mechanism to preserve relative identities for Pytrees #7919
Comments
I might be missing something in the proposal, but what if you unflatten twice with different elements? You can't reuse the same objects to construct two distinct pytrees. |
Hey @apaszke! Unflattening many times is not a problem. Implementation-wise you would probably create a new dictionary to keep track of the identities and store the first unflattened instance each time you call tree_unflatten:
m = Module()
tree = (m, m)
leaves, treedef = jax.tree_flatten(tree)
leaves2 = treedef.flatten_up_to(tree)
m11, m12 = jax.tree_unflatten(treedef, leaves)
m21, m22 = jax.tree_unflatten(treedef, leaves2)
assert m11 is m12
assert m21 is m22
assert m11 is not m21 and m12 is not m22
assert m11 is not m and m12 is not m |
It might be worth looking at how Python's pickle solves this problem. It does essentially the same thing (preserving identity through serialization) by default:
In principle, this might be a reasonable thing for JAX to do with pytrees. I haven't looked into how the implementations differ, though. |
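A minimal sketch of the pickle behavior referred to above. The empty `Module` class is a hypothetical stand-in, not a JAX type; the point is only that a round trip through `pickle` keeps two references to one object pointing at one (new) object:

```python
import pickle


class Module:
    """Hypothetical stand-in for a shared module object."""


m = Module()
tree = (m, m)

# pickle memoizes objects it has already written, so the second reference
# to `m` is stored as a back-reference instead of a second copy.
a, b = pickle.loads(pickle.dumps(tree))

assert a is b        # relative identity is preserved
assert a is not m    # but the result is a new object
```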
A terminological point: preserving object identity would mean violating referential transparency, and in particular these would be more like "pydags" than "pytrees".
I don't think we want to break referential transparency with existing pytree types. That would prevent us from processing them recursively in a functionally pure way (as @cgarciae already mentioned under "Implementation", which refers to basically a side-effecting memoization process). Moreover I wouldn't be surprised if we leverage the referential transparency assumption in lots of different places.
But just for pytree flattening/unflattening alone as in @cgarciae's most recent comment, the API is already general enough to handle DAGs in your own custom pytree types, so long as you are willing to break referential transparency in your own flattening functions. You just need to do the deduplication-by-python-object-id (and equality-up-to-alpha-renaming) yourself:
from typing import Any, NamedTuple
import itertools as it
from collections import defaultdict
import jax
from jax.util import unzip2
from jax.tree_util import register_pytree_node
class MyTuple:
    elts: tuple[Any]
    def __init__(self, *elts):
        self.elts = elts
    def __iter__(self):
        return iter(self.elts)
def flatten_mytuple(x):
    counts = it.count()
    id_to_name = defaultdict(lambda: next(counts))
    name_list = [id_to_name[id(e)] for e in x.elts]
    uniques = {id(e): e for e in x.elts}
    unique_names, unique_vals = unzip2((id_to_name[i], v) for i, v in uniques.items())
    return unique_vals, (unique_names, name_list)
def unflatten_mytuple(aux, unique_vals):
    unique_names, name_list = aux
    uniques = dict(zip(unique_names, unique_vals))
    elts = [uniques[name] for name in name_list]
    return MyTuple(*elts)
register_pytree_node(MyTuple, flatten_mytuple, unflatten_mytuple)
class Module(NamedTuple): pass # added this as trivial pytree
###
m = Module()
tree = MyTuple(m, m) # NOTE: changed from Python tuple to MyTuple!
leaves, treedef = jax.tree_flatten(tree)
leaves2 = treedef.flatten_up_to(tree)
m11, m12 = jax.tree_unflatten(treedef, leaves)
m21, m22 = jax.tree_unflatten(treedef, leaves2)
assert m11 is m12
assert m21 is m22
assert m11 is not m21 and m12 is not m22
assert m11 is not m and m12 is not m
I'm not sure of the limitations of this approach. For example, when Tracers are involved, this reliance on Python object identity might lead to surprising results (but maybe it'd be okay to rely on Python object identity just for values which can't be wrapped in Tracers, i.e. in your own custom pytree data types?).
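To make "deduplication-by-python-object-id (and equality-up-to-alpha-renaming)" concrete, here is a small pure-Python sketch. The helper `names_for` is hypothetical and mirrors only the naming step of the flattening function in the comment above: raw `id()` values differ from run to run, so each distinct object is renamed to a small integer in order of first appearance, and two containers with the same sharing pattern then produce equal auxiliary data.

```python
import itertools as it
from collections import defaultdict


def names_for(elts):
    """Rename each distinct object (by id) to 0, 1, 2, ... in first-seen order."""
    counts = it.count()
    id_to_name = defaultdict(lambda: next(counts))
    return [id_to_name[id(e)] for e in elts]


a, b = object(), object()
c, d = object(), object()

# Same sharing pattern, different objects: the names agree.
assert names_for([a, a, b]) == [0, 0, 1]
assert names_for([c, c, d]) == names_for([a, a, b])

# A different sharing pattern gives different names.
assert names_for([a, b, a]) == [0, 1, 0]
```

Because the names depend only on the sharing pattern, they can be compared for equality across calls, which raw object ids cannot.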
Hey @mattjj! Good point about this being more like a Pydag. I think the main problem with your suggestion is that when you have nested structures like this
class Child:
    shared: Module
class Parent:
    left: Child
    right: Child
shared = Module()
parent = Parent(left=Child(shared), right=Child(shared))
local information is not enough because
Implementation Proposal 2: Unflatten callbacks
A light-weight alternative closer to @mattjj's idea could be that JAX provided a function such as
# tree_util.py
def register_unflatten_callabacks(
    tree_unflatten_start_fn: Callable[[], Any],
    tree_unflatten_end_fn: Callable[[Any], None],
) -> None: ...
If
Module implementation
Based on this a Module library could define a global
import functools
import typing as tp
import jax
import jax.tree_util
# in reality use something based on threading.local
_IDENTITIES: tp.Optional[tp.Dict[int, tp.Any]] = None
def tree_unflatten_start_fn() -> tp.Any:
    global _IDENTITIES
    old_identities = _IDENTITIES
    _IDENTITIES = {}
    return old_identities
def tree_unflatten_end_fn(old_identities: tp.Any) -> None:
    global _IDENTITIES
    _IDENTITIES = old_identities
# register callbacks
jax.tree_util.register_unflatten_callabacks(tree_unflatten_start_fn, tree_unflatten_end_fn)
Now we can define
class Module:
    def tree_flatten(self):
        tree = vars(self)
        obj_id = id(self)
        return (tree,), (obj_id,)
    @classmethod
    def tree_unflatten(cls, aux, children):
        tree = children[0]
        obj_id = aux[0]
        if _IDENTITIES is not None and obj_id in _IDENTITIES:
            return _IDENTITIES[obj_id]
        obj = cls.__new__(cls)
        obj.__dict__.update(tree)
        if _IDENTITIES is not None:
            _IDENTITIES[obj_id] = obj
        return obj
    def __init_subclass__(cls) -> None:
        jax.tree_util.register_pytree_node_class(cls)
Example
Based on these definitions we can recreate the original example; all the following assertions should hold.
class Shared(Module):
    pass
class Child(Module):
    def __init__(self, shared: Shared) -> None:
        self.shared = shared
class Parent(Module):
    def __init__(self, left: Child, right: Child) -> None:
        self.left = left
        self.right = right
shared = Shared()
parent = Parent(Child(shared), Child(shared))
leaves, treedef = jax.tree_flatten(parent)
leaves2 = treedef.flatten_up_to(parent)
parent1: Parent = jax.tree_unflatten(treedef, leaves)
parent2: Parent = jax.tree_unflatten(treedef, leaves2)
# checks
assert parent1 is not parent
assert parent2 is not parent1
assert parent1.left.shared is parent1.right.shared
assert parent1.left.shared is not parent.left.shared
assert parent1.right.shared is not parent.right.shared
assert parent1.left is not parent1.right
assert parent2.left.shared is parent2.right.shared
assert parent2.left.shared is not parent1.left.shared
assert parent2.right.shared is not parent1.right.shared
assert parent2.left is not parent2.right
Prototype
I've created a prototype putting all of the above together, where I am just creating my own tree_unflatten wrapper
import functools
import typing as tp
import jax
import jax.tree_util
_IDENTITIES: tp.Optional[tp.Dict[int, tp.Any]] = None
# these 2 functions behave like a context manager around `jax.tree_unflatten`
def tree_unflatten_start() -> tp.Any:
    global _IDENTITIES
    old_identities = _IDENTITIES
    _IDENTITIES = {}
    return old_identities
def tree_unflatten_end(old_identities: tp.Any) -> None:
    global _IDENTITIES
    _IDENTITIES = old_identities
# in reality you would just call: jax.tree_util.register_unflatten_callabacks(tree_unflatten_start, tree_unflatten_end)
def tree_unflatten(*args, **kwargs):
    old_identities = tree_unflatten_start()
    try:
        return jax.tree_unflatten(*args, **kwargs)
    finally:
        tree_unflatten_end(old_identities)
class Module:
    def tree_flatten(self):
        tree = vars(self)
        obj_id = id(self)
        return (tree,), (obj_id,)
    @classmethod
    def tree_unflatten(cls, aux, children):
        tree = children[0]
        obj_id = aux[0]
        if _IDENTITIES is not None and obj_id in _IDENTITIES:
            return _IDENTITIES[obj_id]
        obj = cls.__new__(cls)
        obj.__dict__.update(tree)
        if _IDENTITIES is not None:
            _IDENTITIES[obj_id] = obj
        return obj
    def __init_subclass__(cls) -> None:
        jax.tree_util.register_pytree_node_class(cls)
class Shared(Module):
    pass
class Child(Module):
    def __init__(self, shared: Shared) -> None:
        self.shared = shared
class Parent(Module):
    def __init__(self, left: Child, right: Child) -> None:
        self.left = left
        self.right = right
shared = Shared()
parent = Parent(Child(shared), Child(shared))
leaves, treedef = jax.tree_flatten(parent)
leaves2 = treedef.flatten_up_to(parent)
parent1: Parent = tree_unflatten(treedef, leaves)
parent2: Parent = tree_unflatten(treedef, leaves2)
assert parent1 is not parent
assert parent2 is not parent1
assert parent1.left.shared is parent1.right.shared
assert parent1.left.shared is not parent.left.shared
assert parent1.right.shared is not parent.right.shared
assert parent1.left is not parent1.right
assert parent2.left.shared is parent2.right.shared
assert parent2.left.shared is not parent1.left.shared
assert parent2.right.shared is not parent1.right.shared
assert parent2.left is not parent2.right
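The code comment "in reality use something based on threading.local" could look roughly like the sketch below. The names `identity_scope` and `current_identities` are hypothetical and not part of JAX or of the prototype; this shows only the start/end callback pair rewritten as a thread-local context manager that restores the previous table on exit, so nested unflatten calls do not see each other's entries.

```python
import contextlib
import threading
from typing import Any, Dict, Iterator, Optional

_STATE = threading.local()  # one identity table per thread


def current_identities() -> Optional[Dict[int, Any]]:
    """Return the active identity table, or None outside any scope."""
    return getattr(_STATE, "identities", None)


@contextlib.contextmanager
def identity_scope() -> Iterator[Dict[int, Any]]:
    """Install a fresh identity table and restore the previous one on exit."""
    old = current_identities()
    _STATE.identities = {}
    try:
        yield _STATE.identities
    finally:
        _STATE.identities = old


with identity_scope() as outer:
    outer[1] = "first unflattened instance"
    with identity_scope() as inner:
        assert inner == {}            # nested scope starts empty
    assert current_identities() is outer  # outer table is restored
assert current_identities() is None
```

A `tree_unflatten` wrapper would then run the real unflatten inside `with identity_scope():`, and the node's unflatten method would consult `current_identities()` instead of a module-level global.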
You're right, I meant to mention that but I neglected to: in general the flattening function would be responsible for flattening its whole subtree, not just flattening one node as usual, by calling into your own set of stateful flatteners! That is, the pytree flattening function for
My example code did not do recursive DAG flattening. Here's a version that does! First, a general
from functools import partial
import itertools as it
from collections import defaultdict
from typing import (Callable, Type, Hashable, Dict, Any, NamedTuple, Tuple,
                    Sequence, List, Union)
from jax.util import unzip2
Name = int
Names = Sequence[int]
AuxData = Any
class NodeType(NamedTuple):
    name: str
    to_iterable: Callable[[Callable, Any], Tuple[AuxData, Names]]
    from_iterable: Callable[[Callable, AuxData, Names], Any]
    def __repr__(self): return f'DagNode[{self.name}]'
def register_pydag_node(ty: Type, to_iter: Callable, from_iter: Callable)-> None:
    node_types[ty] = NodeType(str(ty), to_iter, from_iter)
node_types: Dict[Type, NodeType] = {}
register_pydag_node(tuple,
                    lambda f, t: (None, [f(e) for e in t]),
                    lambda u, _, names: tuple(u(n) for n in names))
register_pydag_node(list,
                    lambda f, l: (None, [f(e) for e in l]),
                    lambda u, _, names: [u(n) for n in names])
register_pydag_node(dict,
                    lambda f, d: (tuple(sorted(d)), [f(d[k]) for k in sorted(d)]),
                    lambda u, keys, names: {k: u(n) for k, n in zip(keys, names)})
class PyDagNode(NamedTuple):
    node_type: NodeType
    node_auxdata: Hashable
    names: Names
class PyDagLeaf(NamedTuple):
    name: Name
PyDagDef = Union[PyDagNode, PyDagLeaf]
class FlattenState(NamedTuple):
    id_to_name: Dict[int, Name]
    name_to_obj: Dict[Name, Any]
def dag_flatten(x: Any) -> Tuple[List[Any], Tuple[List[Name], PyDagDef]]:
    names = it.count()
    state = FlattenState(defaultdict(lambda: next(names)), dict())
    dagdef = _dag_flatten(state, x)
    unique_names, unique_vals = unzip2(state.name_to_obj.items())
    return unique_vals, (unique_names, dagdef)
def _dag_flatten(state: FlattenState, x: Any) -> PyDagDef:
    node_type = node_types.get(type(x))
    if node_type:
        node_auxdata, names = node_type.to_iterable(partial(_dag_flatten, state), x)
        return PyDagNode(node_type, node_auxdata, names)
    else:
        name = state.id_to_name[id(x)]
        state.name_to_obj[name] = x
        return PyDagLeaf(name)
class UnflattenState(NamedTuple):
    name_to_obj: Dict[Name, Any]
def dag_unflatten(dagspec: Tuple[Names, PyDagDef], xs: List[Any]) -> Any:
    names, dagdef = dagspec
    state = UnflattenState(dict(zip(names, xs)))
    return _dag_unflatten(state, dagdef)
def _dag_unflatten(state: UnflattenState, dagdef: PyDagDef) -> Any:
    if type(dagdef) is PyDagLeaf:
        return state.name_to_obj[dagdef.name]
    else:
        u = partial(_dag_unflatten, state)
        return dagdef.node_type.from_iterable(u, dagdef.node_auxdata, dagdef.names)
This is probably "hella buggy", as we'd say where I come from, and I reserve the right to edit this github comment to fix embarrassing mistakes. But it passed literally one example I tried it on, so ship it!
Now, here's a MyTuple pytree which calls into that dag flattening (i.e. interfaces pydags with the existing pytree system):
class MyTuple:
    elts: tuple[Any]
    def __init__(self, *elts):
        self.elts = elts
    def __iter__(self):
        return iter(self.elts)
# register with our pydag system
register_pydag_node(MyTuple,
                    lambda f, t: (None, [f(e) for e in t]),
                    lambda u, _, names: MyTuple(*[u(n) for n in names]))
# register as a pytree with jax, but tell it to flatten like a dag
from jax.tree_util import register_pytree_node
register_pytree_node(MyTuple, dag_flatten, dag_unflatten)
###
import jax
class Module(NamedTuple): pass # added this as trivial pytree
# Test 1
m = Module()
tree = MyTuple(m, m)
leaves, treedef = jax.tree_flatten(tree)
leaves2 = treedef.flatten_up_to(tree)
m11, m12 = jax.tree_unflatten(treedef, leaves)
m21, m22 = jax.tree_unflatten(treedef, leaves2)
assert m11 is m12
assert m21 is m22
assert m11 is not m21 and m12 is not m22
assert m11 is not m and m12 is not m
# Test 2
m = Module()
tree = MyTuple(MyTuple(m, m), MyTuple(m, m))
leaves, treedef = jax.tree_flatten(tree)
((m11, m12), (m21, m22)) = jax.tree_unflatten(treedef, leaves)
assert m11 is m12 is m21 is m22
My main point is just that I think with the existing pytree system you can at least flatten subtrees of your custom pytree types however you'd like, including as dags-by-objectid. Maybe that can unblock you!
Of course, we could also consider building some pydag behavior into JAX, which would let us deduplicate across all argument lists (even if the top-level container is not a custom pytree type you control). It's worth considering! Like I said before, I'm a bit wary of where we might leverage the referential transparency assumption. But maybe it'd all work out... experimenting with the above pydag approach might help us learn things!
WDYT? Does this approach unblock you, without needing JAX-internal changes?
@mattjj
class Module:
    def __init__(self, x):
        self.x = x
register_pydag_node(
    Module,
    lambda f, t: (None, [f(t.x)]),
    lambda u, _, names: Module(u(names[0])),
)
# Test 1
m = Module(0)
tree = MyTuple(m, m)
leaves, treedef = jax.tree_flatten(tree)
leaves2 = treedef.flatten_up_to(tree)
assert len(leaves) == 1
m11, m12 = jax.tree_unflatten(treedef, leaves)
m21, m22 = jax.tree_unflatten(treedef, leaves2)
assert m11 is m12 # Fails here
Maybe I implemented pydag for Module wrong. I haven't been able to pinpoint the part where nodes are reused if they have already been unflattened in
A downside I see with having Pydags in a separate world from Pytrees is that custom Pytrees are just
Back to my original proposal, what if |
Update: My solution wasn't taking into account the |
Is there an easy way to illustrate the is_leaf issue in a toy example? (I believe you that it's not working, I just don't yet grok what you mean.) |
In the meantime, I found a relatively simple (yet highly inefficient) way to get some sort of identity preservation just using pytrees:
import dataclasses
from typing import Any, Dict, Tuple, TypeVar
import jax
import numpy as np
A = TypeVar("A")
#-------------------------------------------------------------------------------
class Nothing:
    pass
def _flatten_nothing(_: Nothing) -> Tuple[Tuple[()], None]:
    return (), None
def _unflatten_nothing(_: None, __: Tuple[()]) -> Nothing:
    return Nothing()
jax.tree_util.register_pytree_node(Nothing, _flatten_nothing, _unflatten_nothing)
#-------------------------------------------------------------------------------
@dataclasses.dataclass(frozen=True)
class ValueIndex:
    value: Any
    index: int
def _flatten_value_index(value_index: ValueIndex) -> Tuple[Tuple[(Any,)], int]:
    return (value_index.value,), value_index.index
def _unflatten_value_index(index: int, children: Tuple[Any]) -> ValueIndex:
    return ValueIndex(children[0], index)
jax.tree_util.register_pytree_node(
    ValueIndex, _flatten_value_index, _unflatten_value_index
)
#-------------------------------------------------------------------------------
def deref(pytree: A, is_leaf=None) -> A:
    id_to_index: Dict[int, int] = {}
    def deref_fn(leaf: Any) -> Any:
        leaf_id = id(leaf)
        if leaf_id not in id_to_index:
            id_to_index[leaf_id] = len(id_to_index)
            return ValueIndex(leaf, index=id_to_index[leaf_id])
        else:
            return ValueIndex(Nothing(), index=id_to_index[leaf_id])
    return jax.tree_map(deref_fn, pytree, is_leaf=is_leaf)
def reref(pytree: A) -> A:
    index_to_value: Dict[int, Any] = {}
    def reref_fn(value_index: ValueIndex) -> Any:
        if value_index.index not in index_to_value:
            index_to_value[value_index.index] = value_index.value
        value = index_to_value[value_index.index]
        return value
    return jax.tree_map(reref_fn, pytree, is_leaf=lambda x: isinstance(x, ValueIndex))
#-------------------------------------------------------------------------------
def dag_map(f, pytree, is_leaf=None):
    pytree = deref(pytree, is_leaf=is_leaf)
    pytree = jax.tree_map(f, pytree)
    return reref(pytree)
#-------------------------------------------------------------------------------
a = np.array(0)
b = np.array(0)
pytree = {"x": [a, a, b], "y": b}
def add_noise(x):
    return x + np.random.normal()
print("jax.tree_map")
print(jax.tree_map(add_noise, pytree))
print("dag_map")
print(dag_map(add_noise, pytree))
It doesn't preserve the identity of intermediate pytrees though. However, you can use
@jax.jit
def f(pytree):
    pytree = reref(pytree)
    assert pytree["x"][0] is pytree["x"][1]
    assert pytree["x"][2] is pytree["y"]
    return deref(pytree)
pytree = reref(f(deref(pytree)))
assert pytree["x"][0] is pytree["x"][1]
assert pytree["x"][2] is pytree["y"]
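One thing to keep in mind with any scheme keyed on `id()` of leaves, sketched here in plain Python. The helper `index_by_identity` is hypothetical and only mimics the index assignment done in the comment above: equal-but-distinct objects get separate indices, which is the intended behavior, but values that the interpreter happens to share (for example small integers on CPython, an implementation detail) would be treated as one object even if the user never meant them to be tied.

```python
from typing import Any, Dict, List


def index_by_identity(leaves: List[Any]) -> List[int]:
    """Give each leaf the index of the first leaf that is the same object."""
    id_to_index: Dict[int, int] = {}
    return [id_to_index.setdefault(id(leaf), len(id_to_index)) for leaf in leaves]


a = [0.0]
b = [0.0]  # equal to `a`, but a different object

# Sharing follows identity, not equality.
assert index_by_identity([a, a, b, b]) == [0, 0, 1, 1]

# Interpreter-level sharing is invisible to the caller: on CPython the two
# literals below are typically the same object and so receive one index.
print(index_by_identity([7, 7]))
```

Restricting identity tracking to array-like or user-defined leaf types is one way to sidestep this, at the cost of not deduplicating plain Python scalars.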
+1 for this - I immediately hit this issue when moving from MyGrad to JAX. Jax's autograd wasn't as performant as mygrad so I attempted to JIT - my code is fairly 'standard' object-oriented code with plenty of refs scattered throughout (I'm actually building a graph of interrelated objects). I'm doing physics simulation and using grad to solve nonlinear equations - would love to see something like this to make code more 'natural'... Thanks for the proposal - IIUC I could use this to solve the above problem? |
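@cgarciae Here's my attempt at an inheritable class that allows preserving object identities. It seems to be working for many cases, but I still haven't worked out what the correct semantics should be for UUIDs (should they be stored in
from __future__ import annotations
from abc import ABCMeta
from typing import Any, Dict, Iterable, cast, final
from uuid import UUID, uuid4
import jax
from jax.tree_util import register_pytree_node, register_pytree_node_class
class ReffableTreeNode(metaclass=ABCMeta):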
children_to_flatten_counter: dict[UUID, int] = {}
@classmethod
def add_child_to_flatten(cls, uuid: UUID) -> None:
if not uuid in cls.children_to_flatten_counter:
cls.children_to_flatten_counter[uuid] = 0
cls.children_to_flatten_counter[uuid] += 1
@classmethod
def remove_child_to_flatten(cls, uuid: UUID) -> None:
cls.children_to_flatten_counter[uuid] -= 1
if cls.children_to_flatten_counter[uuid] == 0:
del cls.children_to_flatten_counter[uuid]
@classmethod
def clear_children_to_flatten(cls) -> None:
cls.children_to_flatten_counter = {}
def is_top_level_flatten(self) -> bool:
return self.get_uuid() not in ReffableTreeNode.children_to_flatten_counter
def __init__(self) -> None:
self._uuid = uuid4()
self._did_set_refs = False
def get_uuid(self) -> UUID:
return self._uuid
@final
def tree_flatten(self) -> tuple[tuple[Any, ...], dict[str, Any]]:
is_top_level_flatten = self.is_top_level_flatten()
if not is_top_level_flatten:
ReffableTreeNode.remove_child_to_flatten(self.get_uuid())
self._assertion = False
children, aux_data = self.tree_flatten_adia()
assert self._assertion
assert "uuid" in aux_data
if is_top_level_flatten:
all_refs = self.get_refs_iterative()
for ref in all_refs.keys():
ReffableTreeNode.add_child_to_flatten(ref)
children["all_refs"] = all_refs
return (children,), aux_data
@final
@classmethod
def tree_unflatten(cls, aux_data: dict[str, Any], children: tuple[Any, ...]) -> Any:
dict_children = children[0]
instance = object.__new__(cls)
instance._assertion2 = False
instance.tree_unflatten_adia(aux_data, dict_children)
if "all_refs" in dict_children:
instance.set_refs_tracked(dict_children["all_refs"])
assert instance._assertion2
return instance
def tree_flatten_adia(self) -> tuple[dict[str, Any], dict[str, Any]]:
self._assertion = True
return {}, {"uuid": self._uuid}
def tree_unflatten_adia(
self, aux_data: dict[str, Any], children: dict[str, Any]
) -> None:
del children
self._assertion2 = True
self._uuid = aux_data["uuid"]
@final
def get_refs_iterative(self) -> dict[UUID, ReffableTreeNode]:
refs = self.get_refs()
full_refs: dict[UUID, ReffableTreeNode] = {}
refs_visited: set[UUID] = set()
refs_to_visit = list(refs.values())
while refs_to_visit:
ref = refs_to_visit.pop()
if ref.get_uuid() not in refs_visited:
full_refs[ref.get_uuid()] = ref
full_refs |= ref.get_refs()
refs_to_visit.extend(ref.get_refs().values())
refs_visited.add(ref.get_uuid())
return full_refs
def get_refs(self) -> dict[UUID, ReffableTreeNode]:
return {}
@final
def set_refs_tracked(self, refs: dict[UUID, ReffableTreeNode]) -> None:
for ref in refs.values():
ref._assertion = False
ref.set_refs(refs)
assert ref._assertion
self._assertion = False
self.set_refs(refs)
assert self._assertion
def set_refs(self, refs: dict[UUID, ReffableTreeNode]) -> None:
self._assertion = True
def uuid_flatten(uuid: UUID) -> tuple[tuple[Any, ...], Dict[str, Any]]:
children: tuple[Any, ...] = tuple()
aux_data = {"uuid": uuid.hex}
return children, aux_data
def uuid_unflatten(aux_data: Dict[str, Any], children: Iterable[Any]) -> UUID:
del children
uuid = UUID(hex=aux_data["uuid"])
return uuid
register_pytree_node(UUID, uuid_flatten, uuid_unflatten)
Toy usage:
@register_pytree_node_class
class GradientProvider(ReffableTreeNode):
def __init__(self, value: float) -> None:
super().__init__()
self.value = value
def tree_flatten_adia(self) -> tuple[dict[str, Any], dict[str, Any]]:
children, aux_data = super().tree_flatten_adia()
children["value"] = self.value
return children, aux_data
def tree_unflatten_adia(
self, aux_data: dict[str, Any], children: dict[str, Any]
) -> None:
super().tree_unflatten_adia(aux_data, children)
self.value = children["value"]
def __str__(self) -> str:
return f"GradientProvider(value={self.value})"
def __repr__(self) -> str:
return self.__str__()
@register_pytree_node_class
class Equation(ReffableTreeNode):
def __init__(self, constant: float, gradient: GradientProvider, x: float) -> None:
super().__init__()
self.constant = constant
self.gradient = gradient
self.x = x
def tree_flatten_adia(self) -> tuple[dict[str, Any], dict[str, Any]]:
children, aux_data = super().tree_flatten_adia()
children["constant"] = self.constant
children["gradient"] = self.gradient.get_uuid()
children["x"] = self.x
return children, aux_data
def tree_unflatten_adia(
self, aux_data: dict[str, Any], children: dict[str, Any]
) -> None:
super().tree_unflatten_adia(aux_data, children)
self.constant = children["constant"]
self.gradient = children["gradient"]
self.x = children["x"]
def get_refs(self) -> dict[UUID, ReffableTreeNode]:
return super().get_refs() | {self.gradient.get_uuid(): self.gradient}
def set_refs(self, refs: dict[UUID, ReffableTreeNode]) -> None:
super().set_refs(refs)
self.gradient = cast(GradientProvider, refs[cast(UUID, self.gradient)])
@staticmethod
def error_pure(x: float, gradient: GradientProvider, constant: float) -> float:
# Solve for mx=c
return (constant - gradient.value * x) ** 2
def error(self) -> float:
return Equation.error_pure(self.x, self.gradient, self.constant)
def get_error(self) -> float:
return jax.lax.cond(
True,
self.error,
self.error,
)
def __str__(self) -> str:
return f"Equation(constant={self.constant}, gradient={self.gradient.value}, x={self.x})"
def __repr__(self) -> str:
return self.__str__() |
Improvement to the above: UUIDs are now only locally consistent within a given flatten. This means similarly 'shaped' trees of objects have the same TreeDef, allowing jitting to work properly, e.g. when using vmap.
from __future__ import annotations
from typing import Any, final, NewType, ClassVar
from abc import ABCMeta
RefIdentifier = NewType("RefIdentifier", int)
class ReffableTreeNode(metaclass=ABCMeta):
_uuid_counter: ClassVar[RefIdentifier] = RefIdentifier(1)
_current_flatten_id: ClassVar[RefIdentifier] = RefIdentifier(0)
children_to_flatten_counter: dict[int, int] = {}
@classmethod
def add_child_to_flatten(cls, uuid: RefIdentifier) -> None:
if not uuid in cls.children_to_flatten_counter:
cls.children_to_flatten_counter[uuid] = 0
cls.children_to_flatten_counter[uuid] += 1
@classmethod
def remove_child_to_flatten(cls, uuid: RefIdentifier) -> None:
cls.children_to_flatten_counter[uuid] -= 1
if cls.children_to_flatten_counter[uuid] == 0:
del cls.children_to_flatten_counter[uuid]
@classmethod
def clear_children_to_flatten(cls) -> None:
cls.children_to_flatten_counter = {}
def is_top_level_flatten(self) -> bool:
return self.get_uuid() not in ReffableTreeNode.children_to_flatten_counter
def get_uuid(self) -> RefIdentifier:
if (
not hasattr(self, "_uuid_flatten_id")
) or self._uuid_flatten_id < ReffableTreeNode._current_flatten_id:
self._uuid = ReffableTreeNode._uuid_counter
ReffableTreeNode._uuid_counter = RefIdentifier(
ReffableTreeNode._uuid_counter + 1
)
self._uuid_flatten_id: RefIdentifier = ReffableTreeNode._current_flatten_id
return self._uuid
@final
def tree_flatten(self) -> tuple[tuple[Any, ...], dict[str, Any]]:
is_top_level_flatten = self.is_top_level_flatten()
if not is_top_level_flatten:
ReffableTreeNode.remove_child_to_flatten(self.get_uuid())
else:
ReffableTreeNode.clear_children_to_flatten()
ReffableTreeNode._current_flatten_id = RefIdentifier(
ReffableTreeNode._current_flatten_id + 1
)
ReffableTreeNode._uuid_counter = RefIdentifier(1)
self._did_call_super_flatten_assertion = False
children, aux_data = self.tree_flatten_adia()
assert self._did_call_super_flatten_assertion
assert "uuid" in aux_data
if is_top_level_flatten:
all_refs = self.get_refs_iterative()
for ref in sorted(all_refs.keys()):
ReffableTreeNode.add_child_to_flatten(ref)
all_refs = dict(sorted(all_refs.items()))
children["all_refs"] = all_refs
aux_data["uuid"] = 0
return (children,), aux_data
@final
@classmethod
def tree_unflatten(cls, aux_data: dict[str, Any], children: tuple[Any, ...]) -> Any:
dict_children = children[0]
instance = object.__new__(cls)
instance._did_call_super_unflatten_assertion = False
instance.tree_unflatten_adia(aux_data, dict_children)
if "all_refs" in dict_children:
instance.set_refs_tracked(dict_children["all_refs"])
assert instance._did_call_super_unflatten_assertion
return instance
def tree_flatten_adia(self) -> tuple[dict[str, Any], dict[str, Any]]:
self._did_call_super_flatten_assertion = True
return {}, {"uuid": self._uuid}
def tree_unflatten_adia(
self, aux_data: dict[str, Any], children: dict[str, Any]
) -> None:
del children
self._did_call_super_unflatten_assertion = True
self._uuid = aux_data["uuid"]
@final
def get_refs_iterative(self) -> dict[RefIdentifier, ReffableTreeNode]:
refs = self.get_refs()
full_refs: dict[RefIdentifier, ReffableTreeNode] = {}
refs_visited: set[RefIdentifier] = set()
refs_to_visit = list(refs.values())
while refs_to_visit:
ref = refs_to_visit.pop()
if ref.get_uuid() not in refs_visited:
full_refs[ref.get_uuid()] = ref
full_refs |= ref.get_refs()
refs_to_visit.extend(ref.get_refs().values())
refs_visited.add(ref.get_uuid())
return full_refs
def get_refs(self) -> dict[RefIdentifier, ReffableTreeNode]:
return {}
_did_call_super_set_refs_assertion: bool = False
@final
def set_refs_tracked(self, refs: dict[RefIdentifier, ReffableTreeNode]) -> None:
for ref in refs.values():
ref._did_call_super_set_refs_assertion = False
ref.set_refs(refs)
assert ref._did_call_super_set_refs_assertion
self._did_call_super_set_refs_assertion = False
self.set_refs(refs)
assert self._did_call_super_set_refs_assertion
def set_refs(self, refs: dict[RefIdentifier, ReffableTreeNode]) -> None:
del refs
self._did_call_super_set_refs_assertion = True |
More on the above here: I eventually worked out that the correct semantics for UUIDs is to store them in aux data, since we don't want them to be traced; we want them to help define the shape of the tree. |
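A small sketch of why the choice between children and aux data matters. The `Tagged` class is hypothetical and stands in for a node carrying an identifier: anything returned as aux data becomes part of the tree structure, so two trees with equal aux data compare as the same structure and trees with different aux data do not. This is also why identifiers that are only locally consistent within one flatten let similarly shaped trees share a structure.

```python
import jax
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Tagged:
    """Hypothetical node: `value` is a traced child, `tag` is static aux data."""

    def __init__(self, value, tag):
        self.value = value
        self.tag = tag

    def tree_flatten(self):
        return (self.value,), self.tag

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(children[0], aux)


leaves_a, def_a = jax.tree_util.tree_flatten(Tagged(1.0, tag=0))
leaves_b, def_b = jax.tree_util.tree_flatten(Tagged(2.0, tag=0))
leaves_c, def_c = jax.tree_util.tree_flatten(Tagged(1.0, tag=1))

assert leaves_a == [1.0]   # only the child shows up as a leaf
assert def_a == def_b      # same tag: same structure, different leaf values
assert def_a != def_c      # different tag: different structure
```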
This stuff grew into Flax NNX. I think we can close the issue thread here! |
@mattjj could you perhaps outline how NNX compares to what I used above - would be good to know if I can stop maintaining that |
Motivation
When jax.tree_unflatten is called, new instances of all pytree objects are created and their original identities are lost. While this is expected for basic container types like lists, dicts, and tuples, for other types such as Pytree Modules this can be inconvenient, as it makes tasks like parameter sharing difficult. Here is an example that currently doesn't work of trying to share 2 Child modules in the same Parent module:
Proposal
Enable preserving the relative identities of custom Pytree classes that opt in to this behaviour.
Relative identities means that if two objects in a Pytree have the same identity before flattening, they will share the same identity with each other after unflattening, but they won't have the same identity as the original object. This means the following assertions are true assuming Module opted in to this behavior:
Implementation
To achieve this, register_pytree_node could accept an optional preserve_relative_identities: bool flag (or something like it) indicating that objects of this class opt in to preserving their relative identities. When tree_flatten is called, each node's object id could be stored in the PyTreeDef, so that when tree_unflatten is unflattening a node whose class has preserve_relative_identities=True, it checks whether it has already unflattened that element based on the id and reuses that node if so. preserve_relative_identities should also be available for register_pytree_node_class.
Implications
All current code should run normally; only new code that opts in to this behavior will use this feature.
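The mechanism described under Implementation can be sketched on a toy tree type in plain Python. Nothing here is JAX API: `Node`, `toy_flatten`, and `toy_unflatten` are hypothetical, and the sketch records a per-flatten counter in the structure rather than the raw `id()`, which is an assumption made so that the structure does not depend on memory addresses. Cycles are not handled.

```python
from typing import Any, Dict, Iterator, List, Tuple


class Node:
    """Hypothetical opt-in container; fields are stored as attributes."""

    def __init__(self, **fields: Any) -> None:
        self.__dict__.update(fields)


def _flatten(x: Any, leaves: List[Any], names: Dict[int, int]) -> Tuple:
    if isinstance(x, Node):
        # Name each distinct node by order of first appearance.
        name = names.setdefault(id(x), len(names))
        children = tuple(
            (k, _flatten(v, leaves, names)) for k, v in sorted(vars(x).items())
        )
        return ("node", name, children)
    leaves.append(x)
    return ("leaf",)


def toy_flatten(x: Any) -> Tuple[List[Any], Tuple]:
    leaves: List[Any] = []
    return leaves, _flatten(x, leaves, {})


def _unflatten(spec: Tuple, leaves: Iterator[Any], built: Dict[int, Node]) -> Any:
    if spec[0] == "leaf":
        return next(leaves)
    _, name, children = spec
    # Always consume this node's leaves so later leaves stay aligned.
    fields = {k: _unflatten(child, leaves, built) for k, child in children}
    if name in built:
        return built[name]  # reuse the first instance built for this name
    built[name] = Node(**fields)
    return built[name]


def toy_unflatten(spec: Tuple, leaves: List[Any]) -> Any:
    return _unflatten(spec, iter(leaves), {})  # fresh table per call


shared = Node(w=1.0)
parent = Node(left=Node(shared=shared), right=Node(shared=shared))
leaves, spec = toy_flatten(parent)
parent1 = toy_unflatten(spec, leaves)
parent2 = toy_unflatten(spec, leaves)

assert parent1.left.shared is parent1.right.shared
assert parent1.left.shared is not shared
assert parent1.left is not parent1.right
assert parent2.left.shared is not parent1.left.shared
```

Because the table of already-built nodes is created per `toy_unflatten` call, unflattening twice yields two independent results, each with its own internal sharing.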