Skip to content

Commit

Permalink
add nontransparent
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Jan 12, 2024
1 parent ed44f52 commit 780c9b4
Show file tree
Hide file tree
Showing 3 changed files with 138 additions and 1 deletion.
1 change: 1 addition & 0 deletions flax/linen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@
Variable as Variable,
apply as apply,
compact as compact,
nontransparent as nontransparent,
disable_named_call as disable_named_call,
enable_named_call as enable_named_call,
init_with_output as init_with_output,
Expand Down
111 changes: 110 additions & 1 deletion flax/linen/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,80 @@ def nowrap(fun: _CallableT) -> _CallableT:
return fun


def nontransparent(fun: _CallableT) -> _CallableT:
"""Creates compact submodules from a method.
This is a decorator that allows you to define compact submodules from a
method. It's intention is to make it easier to port code Haiku code to Flax
by providing the same functionality.
Example::
>>> import flax.linen as nn
>>> import jax
>>> import jax.numpy as jnp
>>> from flax.core import pretty_repr
...
>>> class Foo(nn.Module):
... @nn.nontransparent
... def up(self, x):
... return nn.Dense(3)(x)
...
... @nn.nontransparent
... def down(self, x):
... return nn.Dense(3)(x)
...
... @nn.compact
... def __call__(self, x):
... return self.up(x) + self.down(x) + nn.Dense(3)(x)
...
>>> module = Foo()
>>> variables = module.init(jax.random.PRNGKey(0), jnp.ones([1, 3]))
>>> print(pretty_repr(variables['params']))
{
Dense_0: {
bias: (3,),
kernel: (3, 3),
},
down: {
Dense_0: {
bias: (3,),
kernel: (3, 3),
},
},
up: {
Dense_0: {
bias: (3,),
kernel: (3, 3),
},
},
}
Args:
fun: The Module method to mark as nontransparent.
Returns:
The given function ``fun`` marked as nontransparent.
"""

@functools.wraps(fun)
def nontransparent_wrapper(self: nn.Module, *args, **kwargs):
name = fun.__name__
if not hasattr(self, f'{name}_nontransparent'):
raise ValueError(
f'Cannot call nontransparent method {name!r} on a Module that has not been '
f'setup. This is likely because you are calling {name!r} '
'from outside of init or apply.'
)
module = getattr(self, f'{name}_nontransparent')
return module(*args, **kwargs)

nontransparent_wrapper.nontransparent = True # type: ignore[attr-defined]
nontransparent_wrapper.inner_fun = fun # type: ignore[attr-defined]
nontransparent_wrapper.nowrap = True # type: ignore[attr-defined]
return nontransparent_wrapper # type: ignore[return-value]


def _get_local_method_names(
cls: Any, exclude: Iterable[str] = ()
) -> Tuple[str, ...]:
Expand Down Expand Up @@ -955,6 +1029,7 @@ def __init_subclass__(cls, kw_only: bool = False, **kwargs: Any) -> None:
# We wrap user-defined methods including setup and __call__ to enforce
# a number of different checks and to provide clear error messages.
cls._verify_single_or_no_compact()
cls._find_nontransparent_methods()
cls._wrap_module_attributes()
# Set empty class defaults.
cls._state = _uninitialized_module_internal_state # type: ignore[attr-defined]
Expand Down Expand Up @@ -1046,6 +1121,17 @@ def _verify_single_or_no_compact(cls):
if n_compact_fns > 1:
raise errors.MultipleMethodsCompactError()

@classmethod
def _find_nontransparent_methods(cls):
"""Finds all nontransparent methods in the class."""
methods = [m[0] for m in inspect.getmembers(cls, predicate=callable)]
nontransparent_fns = tuple(
method_name
for method_name in methods
if hasattr(getattr(cls, method_name), 'nontransparent')
)
cls._nontransparent_methods = nontransparent_fns

@classmethod
def _wrap_module_attributes(cls):
"""Wraps user-defined non-inherited methods and descriptors with state
Expand Down Expand Up @@ -1347,6 +1433,7 @@ def _register_submodules(self, name, val):

def adopt_attr_modules(cache, queue, suffix, subvalue):
if isinstance(subvalue, Module):
current_name = subvalue.name
adopted_name = None
if subvalue.parent is None:
# Preserve sharing-by-reference relationships during adoption
Expand All @@ -1366,7 +1453,11 @@ def adopt_attr_modules(cache, queue, suffix, subvalue):
if subvalue.name is None:
object.__setattr__(subvalue, 'parent', self)
if adopted_name is None:
adopted_name = f'{name}{suffix}'
adopted_name = (
f'{name}{suffix}'
if not isinstance(subvalue, NonTransparent)
else current_name
)
object.__setattr__(subvalue, 'name', adopted_name)
queue.append(subvalue)
return subvalue
Expand Down Expand Up @@ -1397,6 +1488,15 @@ def _try_setup(self, shallow: bool = False) -> None:
self._register_submodules(field.name, getattr(self, field.name))
if not shallow:
self.setup()
# create NonTransparent Modules
for name in self._nontransparent_methods:
inner_fun = getattr(type(self), name).inner_fun
setattr(
self,
f'{name}_nontransparent',
NonTransparent(inner_fun, lambda: self, name=name),
)

# We run static checks abstractly once for setup before any transforms
# to detect name collisions and other python errors.
elif self._state.setup_called == SetupState.NEW:
Expand Down Expand Up @@ -2835,3 +2935,12 @@ def init_wrapper(*args, **kwargs):
return init_fn(*args, **kwargs)[1]

return init_wrapper


class NonTransparent(Module):
fn: Callable
module_fn: Callable[[], Module]

@compact
def __call__(self, *args, **kwargs) -> Any:
return self.fn(self.module_fn(), *args, **kwargs)
27 changes: 27 additions & 0 deletions tests/linen/linen_module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2487,6 +2487,33 @@ def my_property(self):
self.assertEqual(obj_loaded.b, 'ok')
self.assertEqual(obj_loaded.my_property, 'okok')

def test_nontransparent(self):
class Foo(nn.Module):
@nn.nontransparent
def up(self, x):
return nn.Dense(3)(x)

@nn.nontransparent
def down(self, x):
return nn.Dense(3)(x)

@nn.compact
def __call__(self, x):
return self.up(x) + self.down(x) + nn.Dense(3)(x)

m = Foo()

self.assertEqual(set(m._nontransparent_methods), {'up', 'down'})

variables = m.init(random.PRNGKey(0), jnp.zeros((1, 3)))
params = variables['params']

self.assertIn('Dense_0', params)
self.assertIn('down', params)
self.assertIn('up', params)
self.assertIn('Dense_0', params['down'])
self.assertIn('Dense_0', params['up'])


class LeakTests(absltest.TestCase):
def test_tracer_leaks(self):
Expand Down

0 comments on commit 780c9b4

Please sign in to comment.