Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
.
Browse files Browse the repository at this point in the history
NeilGirdhar committed Nov 24, 2023
1 parent f4a9d6e commit acc0718
Showing 2 changed files with 41 additions and 2 deletions.
2 changes: 1 addition & 1 deletion tjax/_src/display/display_generic.py
Original file line number Diff line number Diff line change
@@ -45,7 +45,7 @@
# Extra imports ------------------------------------------------------------------------------------
FlaxModule: type[Any]
try:
from flax.linen import Module as FlaxModule
from flax.experimental.nnx import Module as FlaxModule
flax_loaded = True
except ImportError:
flax_loaded = False
41 changes: 40 additions & 1 deletion tjax/_src/graph.py
Original file line number Diff line number Diff line change
@@ -45,10 +45,49 @@ def tree_flatten(graph: T) -> tuple[Sequence[PyTree], Hashable]:
register_graph_as_jax_pytree(nx.DiGraph)

try:
from flax.serialization import from_state_dict, register_serialization_state, to_state_dict
from flax.experimental.nnx.nnx.graph_utils import register_node_type
except ImportError:
pass
else:
def _flatten_graph(node: nx.Graph
) -> tuple[tuple[dict[str, Any],
dict[str, Any]],
None]:
edge_dict_of_dicts = defaultdict[Any, dict[Any, Any]](dict)
for (source, target), edge_dict in dict(node.edges).items():
edge_dict_of_dicts[source][target] = edge_dict
return ((dict(node.nodes),
dict(edge_dict_of_dicts)),
None)

def _get_key_graph(node: nx.Graph, key: str) -> Any:
return node[key]

def _set_key_graph(node: nx.Graph, key: str, value: Any) -> nx.Graph:
node.nodes[key] = value
return node

def _has_key_graph(node: nx.Graph, key: str) -> bool:
return key in node

def _all_keys_graph(node: nx.Graph) -> tuple[str, ...]:
return tuple(node.keys())

def _create_empty_graph(metadata: None) -> nx.Graph:
return {}

def _init_graph(node: nx.Graph, items: tuple[tuple[str, Any], ...]):
node.update(items)

register_node_type(nx.DiGraph,
_flatten_graph,
_get_key_graph,
_set_key_graph,
_has_key_graph,
_all_keys_graph,
create_empty=_create_empty_graph,
init=_init_graph)

def register_graph_as_flax_state_dict(cls: type[T]) -> None:
def ty_to_state_dict(graph: T) -> dict[str, Any]:
edge_dict_of_dicts = defaultdict[Any, dict[Any, Any]](dict)

0 comments on commit acc0718

Please sign in to comment.