# Patch summary (explanatory preamble; the patch itself starts at the first "diff --git" header).
#
# NOTE(review): the patch text below appears collapsed onto a single line, so the
# original line breaks and indentation are not visible here -- presumably an
# extraction artifact; confirm against the upstream commit before applying.
#
# Hunk 1 -- flax/linen/module.py (hunk header names `clone_fn` as enclosing context):
#   Inside the `for field_name, value in attrs.items():` loop, adds
#   `if field_name == 'parent': continue`, so the 'parent' entry of `attrs` is no
#   longer passed through `_map_submodules(clone_fn, value)`. All other entries
#   are still mapped, and `self.__class__(**attrs)` is built afterwards as before.
#   Presumably this keeps a deep clone from cloning the parent module itself --
#   TODO confirm against the rest of the clone implementation, which is not shown.
#
# Hunk 2 -- tests/linen/linen_module_test.py: adds `test_internal_deep_clone`.
#   * `Child.__call__` declares a param 'w' with `nn.initializers.zeros` and shape
#     `(5, x.shape[1])`, and returns `x @ w`.
#   * `Parent` has fields `num_layers` and `child_template`; its `__call__` calls
#     `self.child_template.clone(parent=self, _deep_clone=True, name=None)(x)`
#     `num_layers` times, feeding each result into the next iteration.
#   * The test builds `Parent(num_layers=2, child_template=Child())`, runs `init`
#     and `apply` on `jnp.ones((32, 5))`, and asserts that
#     `variables['params']['Child_0']['w']` equals `variables['params']['Child_1']['w']`.
#   NOTE(review): `output` is computed but never asserted on, and the loop variable
#   `i` is unused. Both 'w' params use the zeros initializer, so the equality check
#   looks like it mainly verifies that both 'Child_0' and 'Child_1' entries exist --
#   confirm this is the intended coverage.
diff --git a/flax/linen/module.py b/flax/linen/module.py index d2f0eee77f..4c284af288 100644 --- a/flax/linen/module.py +++ b/flax/linen/module.py @@ -1474,6 +1474,7 @@ def clone_fn(m: Module) -> Module: # _map_submodules will map over all submodules inside attrs # value here can be any pytree, non-module values are ignored for field_name, value in attrs.items(): + if field_name == 'parent': continue attrs[field_name] = _map_submodules(clone_fn, value) module = self.__class__(**attrs) diff --git a/tests/linen/linen_module_test.py b/tests/linen/linen_module_test.py index adbedb0a90..a76d368bae 100644 --- a/tests/linen/linen_module_test.py +++ b/tests/linen/linen_module_test.py @@ -2788,6 +2788,27 @@ def __call__(self, x): with self.assertRaises(errors.NameInUseError): vs = foo.init(k, x) + def test_internal_deep_clone(self): + class Child(nn.Module): + @nn.compact + def __call__(self, x): + w = self.param('w', nn.initializers.zeros, (5, x.shape[1])) + return x @ w + class Parent(nn.Module): + num_layers: int + child_template: Child + @nn.compact + def __call__(self, x): + for i in range(self.num_layers): + x = self.child_template.clone(parent=self, _deep_clone=True, name=None)(x) + return x + + model = Parent(num_layers=2, child_template=Child()) + x = jnp.ones((32, 5)) + variables = model.init(jax.random.key(0), x) + output = model.apply(variables, x) + self.assertTrue(jnp.all(variables['params']['Child_0']['w'] == variables['params']['Child_1']['w'])) + class FrozenDictTests(absltest.TestCase):