Skip to content

Commit

Permalink
[nnx] non-str State keys
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Mar 30, 2024
1 parent fc6c901 commit c987dbb
Show file tree
Hide file tree
Showing 12 changed files with 173 additions and 119 deletions.
22 changes: 11 additions & 11 deletions flax/experimental/nnx/nnx/filterlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@

import builtins
import dataclasses
from flax.typing import Path
from flax.typing import PathParts
import typing as tp

if tp.TYPE_CHECKING:
ellipsis = builtins.ellipsis
else:
ellipsis = tp.Any

Predicate = tp.Callable[[Path, tp.Any], bool]
Predicate = tp.Callable[[PathParts, tp.Any], bool]
FilterLiteral = tp.Union[type, str, Predicate, bool, ellipsis, None]
Filter = tp.Union[FilterLiteral, tuple[FilterLiteral, ...], list[FilterLiteral]]

Expand All @@ -48,17 +48,17 @@ def to_predicate(filter: Filter) -> Predicate:

@dataclasses.dataclass
class AtPath:
path: str
str_key: str

def __call__(self, path: Path, x: tp.Any):
return self.path == path
def __call__(self, path: PathParts, x: tp.Any):
return self.str_key in path


@dataclasses.dataclass
class OfType:
type: type

def __call__(self, path: Path, x: tp.Any):
def __call__(self, path: PathParts, x: tp.Any):
return isinstance(x, self.type)


Expand All @@ -68,7 +68,7 @@ def __init__(self, *filters: Filter):
to_predicate(collection_filter) for collection_filter in filters
)

def __call__(self, path: Path, x: tp.Any):
def __call__(self, path: PathParts, x: tp.Any):
return any(predicate(path, x) for predicate in self.predicates)


Expand All @@ -78,23 +78,23 @@ def __init__(self, *filters: Filter):
to_predicate(collection_filter) for collection_filter in filters
)

def __call__(self, path: Path, x: tp.Any):
def __call__(self, path: PathParts, x: tp.Any):
return all(predicate(path, x) for predicate in self.predicates)


class Not:
def __init__(self, collection_filter: Filter):
self.predicate = to_predicate(collection_filter)

def __call__(self, path: Path, x: tp.Any):
def __call__(self, path: PathParts, x: tp.Any):
return not self.predicate(path, x)


class Everything:
def __call__(self, path: Path, x: tp.Any):
def __call__(self, path: PathParts, x: tp.Any):
return True


class Nothing:
def __call__(self, path: Path, x: tp.Any):
def __call__(self, path: PathParts, x: tp.Any):
return False
Loading

0 comments on commit c987dbb

Please sign in to comment.