Skip the 'parent' field when mapping clone_fn over a module's attrs.

The loop that applies _map_submodules(clone_fn, ...) to every entry of
attrs now leaves the 'parent' entry untouched. The new test clones a child
template with parent=self and _deep_clone=True from inside the parent's
__call__ and checks that both Child_0 and Child_1 parameters are created.

diff --git a/flax/linen/module.py b/flax/linen/module.py
index 1b9680e648..0770391e6e 100644
--- a/flax/linen/module.py
+++ b/flax/linen/module.py
@@ -1495,6 +1495,8 @@ def clone_fn(m: Module) -> Module:
       # _map_submodules will map over all submodules inside attrs
       # value here can be any pytree, non-module values are ignored
       for field_name, value in attrs.items():
+        if field_name == 'parent':
+          continue
         attrs[field_name] = _map_submodules(clone_fn, value)
 
     module = self.__class__(**attrs)
diff --git a/tests/linen/linen_module_test.py b/tests/linen/linen_module_test.py
index 3abf5b66b6..c58331629f 100644
--- a/tests/linen/linen_module_test.py
+++ b/tests/linen/linen_module_test.py
@@ -2712,6 +2712,36 @@ def __call__(self, x):
     with self.assertRaises(errors.NameInUseError):
       vs = foo.init(k, x)
 
+  def test_internal_deep_clone(self):
+    class Child(nn.Module):
+      @nn.compact
+      def __call__(self, x):
+        w = self.param('w', nn.initializers.zeros, (5, x.shape[1]))
+        return x @ w
+
+    class Parent(nn.Module):
+      num_layers: int
+      child_template: Child
+
+      @nn.compact
+      def __call__(self, x):
+        for i in range(self.num_layers):
+          x = self.child_template.clone(
+            parent=self, _deep_clone=True, name=None
+          )(x)
+        return x
+
+    model = Parent(num_layers=2, child_template=Child())
+    x = jnp.ones((32, 5))
+    variables = model.init(jax.random.key(0), x)
+    output = model.apply(variables, x)
+    self.assertTrue(
+      jnp.all(
+        variables['params']['Child_0']['w']
+        == variables['params']['Child_1']['w']
+      )
+    )
+
 
 class FrozenDictTests(absltest.TestCase):
   def test_frozendict_flag(self):