def nontransparent(fun: _CallableT) -> _CallableT:
  """Creates compact submodules from a method.

  This is a decorator that allows you to define compact submodules from a
  method. Its intention is to make it easier to port Haiku code to Flax
  by providing the same functionality.

  Example::

    >>> import flax.linen as nn
    >>> import jax
    >>> import jax.numpy as jnp
    >>> from flax.core import pretty_repr
    ...
    >>> class Foo(nn.Module):
    ...   @nn.nontransparent
    ...   def up(self, x):
    ...     return nn.Dense(3)(x)
    ...
    ...   @nn.nontransparent
    ...   def down(self, x):
    ...     return nn.Dense(3)(x)
    ...
    ...   @nn.compact
    ...   def __call__(self, x):
    ...     return self.up(x) + self.down(x) + nn.Dense(3)(x)
    ...
    >>> module = Foo()
    >>> variables = module.init(jax.random.PRNGKey(0), jnp.ones([1, 3]))
    >>> params = variables['params']
    >>> print(pretty_repr(jax.tree_util.tree_map(jnp.shape, params)))
    {
        Dense_0: {
            bias: (3,),
            kernel: (3, 3),
        },
        down: {
            Dense_0: {
                bias: (3,),
                kernel: (3, 3),
            },
        },
        up: {
            Dense_0: {
                bias: (3,),
                kernel: (3, 3),
            },
        },
    }

  Args:
    fun: The Module method to mark as nontransparent.

  Returns:
    A wrapper around ``fun`` that is marked as nontransparent. Calling the
    wrapper forwards the call to the ``<name>_nontransparent`` submodule
    stored on the Module instance instead of running ``fun`` directly.

  Raises:
    ValueError: When the wrapper is called on an instance that has no
      ``<name>_nontransparent`` attribute.
  """

  # ``'Module'`` is a string forward reference: it is never evaluated at
  # runtime and names the class defined in this file, so the wrapper does not
  # depend on an ``nn`` alias being bound in this module.
  @functools.wraps(fun)
  def nontransparent_wrapper(self: 'Module', *args, **kwargs):
    name = fun.__name__
    # The submodule is attached to the instance under this attribute name by
    # ``Module._try_setup``; without it there is nothing to dispatch to.
    if not hasattr(self, f'{name}_nontransparent'):
      raise ValueError(
        f'Cannot call nontransparent method {name!r} on a Module that has not been '
        f'setup. This is likely because you are calling {name!r} '
        'from outside of init or apply.'
      )
    module = getattr(self, f'{name}_nontransparent')
    return module(*args, **kwargs)

  # Marker read by ``Module._find_nontransparent_methods``.
  nontransparent_wrapper.nontransparent = True  # type: ignore[attr-defined]
  # Original method body, retrieved later to build the submodule.
  nontransparent_wrapper.inner_fun = fun  # type: ignore[attr-defined]
  # Same flag that ``nowrap`` sets on functions.
  nontransparent_wrapper.nowrap = True  # type: ignore[attr-defined]
  return nontransparent_wrapper  # type: ignore[return-value]
@classmethod
def _find_nontransparent_methods(cls):
  """Records which callable members of the class are nontransparent.

  Every callable member of ``cls`` is inspected; the names of those carrying
  a ``nontransparent`` attribute are stored, as a tuple, on
  ``cls._nontransparent_methods``.
  """
  marked_names = []
  for member_name, _ in inspect.getmembers(cls, predicate=callable):
    # Re-fetch through the class so the check sees the attribute exactly as
    # a regular ``cls.<name>`` access would.
    member = getattr(cls, member_name)
    if hasattr(member, 'nontransparent'):
      marked_names.append(member_name)
  cls._nontransparent_methods = tuple(marked_names)
class NonTransparent(Module):
  """Module whose call runs a stored function against another Module.

  Calling an instance invokes ``fn`` with the Module returned by
  ``module_fn`` as the first positional argument, followed by the call's own
  positional and keyword arguments. ``__call__`` is decorated with
  ``@compact``.
  """

  # Function to run; it receives the Module produced by ``module_fn`` first.
  fn: Callable
  # Zero-argument callable returning the Module handed to ``fn``.
  module_fn: Callable[[], Module]

  @compact
  def __call__(self, *args, **kwargs) -> Any:
    """Runs ``fn`` on the Module from ``module_fn`` and returns its result."""
    target_fn = self.fn
    owner = self.module_fn()
    return target_fn(owner, *args, **kwargs)