Skip to content

Commit

Permalink
Merge pull request #3704 from google:nnx-fix-state-sub
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 611783222
  • Loading branch information
Flax Authors committed Mar 1, 2024
2 parents 1abfa87 + 4dcc0ef commit a1ecdf7
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
11 changes: 7 additions & 4 deletions flax/experimental/nnx/nnx/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from flax import traverse_util
from flax.experimental.nnx.nnx import filterlib, reprlib
from flax.experimental.nnx.nnx.variables import Variable
from flax.typing import Path, Leaf
from flax.typing import Leaf, Path

A = tp.TypeVar('A')

Expand Down Expand Up @@ -180,7 +180,7 @@ def flat_state(self) -> dict[Key, Variable[Leaf]]:
return traverse_util.flatten_dict(self._mapping, sep='/') # type: ignore

@classmethod
def from_flat_path(cls, flat_state: FlatState) -> State:
def from_flat_path(cls, flat_state: FlatState, /) -> State:
nested_state = traverse_util.unflatten_dict(flat_state, sep='/')
return cls(nested_state)

Expand Down Expand Up @@ -274,8 +274,11 @@ def __sub__(self, other: 'State') -> 'State':
if not other:
return self

_mapping = {k: v for k, v in self._mapping.items() if k not in other}
return State(_mapping)
self_flat = self.flat_state()
other_flat = other.flat_state()
diff = {k: v for k, v in self_flat.items() if k not in other_flat}

return State.from_flat_path(diff)


def _state_flatten_with_keys(x: State):
Expand Down
4 changes: 2 additions & 2 deletions flax/experimental/nnx/nnx/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -857,8 +857,8 @@ def scan_apply(
args,
is_leaf=lambda x: x is None,
)
broadcast_args = jax.tree_util.tree_map(
lambda axis, node: None if axis is not None else node,
broadcast_args = jax.tree_map(
lambda axis, node: node if axis is None else None,
options.in_args_axes,
args,
is_leaf=lambda x: x is None,
Expand Down

0 comments on commit a1ecdf7

Please sign in to comment.