Skip to content

Commit

Permalink
Fix Module.clone in deepclone mode for internal usage.
Browse files Browse the repository at this point in the history
  • Loading branch information
levskaya committed Nov 4, 2023
1 parent c1023d9 commit 5587371
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 0 deletions.
2 changes: 2 additions & 0 deletions flax/linen/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -1495,6 +1495,8 @@ def clone_fn(m: Module) -> Module:
# _map_submodules will map over all submodules inside attrs
# value here can be any pytree, non-module values are ignored
for field_name, value in attrs.items():
if field_name == 'parent':
continue
attrs[field_name] = _map_submodules(clone_fn, value)

module = self.__class__(**attrs)
Expand Down
30 changes: 30 additions & 0 deletions tests/linen/linen_module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2712,6 +2712,36 @@ def __call__(self, x):
with self.assertRaises(errors.NameInUseError):
vs = foo.init(k, x)

def test_internal_deep_clone(self):
class Child(nn.Module):
@nn.compact
def __call__(self, x):
w = self.param('w', nn.initializers.zeros, (5, x.shape[1]))
return x @ w

class Parent(nn.Module):
num_layers: int
child_template: Child

@nn.compact
def __call__(self, x):
for i in range(self.num_layers):
x = self.child_template.clone(
parent=self, _deep_clone=True, name=None
)(x)
return x

model = Parent(num_layers=2, child_template=Child())
x = jnp.ones((32, 5))
variables = model.init(jax.random.key(0), x)
output = model.apply(variables, x)
self.assertTrue(
jnp.all(
variables['params']['Child_0']['w']
== variables['params']['Child_1']['w']
)
)


class FrozenDictTests(absltest.TestCase):
def test_frozendict_flag(self):
Expand Down

0 comments on commit 5587371

Please sign in to comment.