From 5587371b12f1316944289bb99c8aac552a279fe7 Mon Sep 17 00:00:00 2001 From: Anselm Levskaya Date: Fri, 3 Nov 2023 19:43:25 -0700 Subject: [PATCH] Fix Module.clone in deepclone mode for internal usage. --- flax/linen/module.py | 2 ++ tests/linen/linen_module_test.py | 30 ++++++++++++++++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/flax/linen/module.py b/flax/linen/module.py index 1b9680e64..0770391e6 100644 --- a/flax/linen/module.py +++ b/flax/linen/module.py @@ -1495,6 +1495,8 @@ def clone_fn(m: Module) -> Module: # _map_submodules will map over all submodules inside attrs # value here can be any pytree, non-module values are ignored for field_name, value in attrs.items(): + if field_name == 'parent': + continue attrs[field_name] = _map_submodules(clone_fn, value) module = self.__class__(**attrs) diff --git a/tests/linen/linen_module_test.py b/tests/linen/linen_module_test.py index 3abf5b66b..c58331629 100644 --- a/tests/linen/linen_module_test.py +++ b/tests/linen/linen_module_test.py @@ -2712,6 +2712,36 @@ def __call__(self, x): with self.assertRaises(errors.NameInUseError): vs = foo.init(k, x) + def test_internal_deep_clone(self): + class Child(nn.Module): + @nn.compact + def __call__(self, x): + w = self.param('w', nn.initializers.zeros, (5, x.shape[1])) + return x @ w + + class Parent(nn.Module): + num_layers: int + child_template: Child + + @nn.compact + def __call__(self, x): + for i in range(self.num_layers): + x = self.child_template.clone( + parent=self, _deep_clone=True, name=None + )(x) + return x + + model = Parent(num_layers=2, child_template=Child()) + x = jnp.ones((32, 5)) + variables = model.init(jax.random.key(0), x) + output = model.apply(variables, x) + self.assertTrue( + jnp.all( + variables['params']['Child_0']['w'] + == variables['params']['Child_1']['w'] + ) + ) + class FrozenDictTests(absltest.TestCase): def test_frozendict_flag(self):