Skip to content

Commit

Permalink
Fix Module.clone in deepclone mode for internal usage.
Browse files Browse the repository at this point in the history
  • Loading branch information
levskaya committed Nov 4, 2023
1 parent 8d09772 commit 6eca28a
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 0 deletions.
1 change: 1 addition & 0 deletions flax/linen/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -1474,6 +1474,7 @@ def clone_fn(m: Module) -> Module:
# _map_submodules will map over all submodules inside attrs
# value here can be any pytree, non-module values are ignored
for field_name, value in attrs.items():
if field_name == 'parent': continue
attrs[field_name] = _map_submodules(clone_fn, value)

module = self.__class__(**attrs)
Expand Down
21 changes: 21 additions & 0 deletions tests/linen/linen_module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2788,6 +2788,27 @@ def __call__(self, x):
with self.assertRaises(errors.NameInUseError):
vs = foo.init(k, x)

def test_internal_deep_clone(self):
class Child(nn.Module):
@nn.compact
def __call__(self, x):
w = self.param('w', nn.initializers.zeros, (5, x.shape[1]))
return x @ w
class Parent(nn.Module):
num_layers: int
child_template: Child
@nn.compact
def __call__(self, x):
for i in range(self.num_layers):
x = self.child_template.clone(parent=self, _deep_clone=True, name=None)(x)
return x

model = Parent(num_layers=2, child_template=Child())
x = jnp.ones((32, 5))
variables = model.init(jax.random.key(0), x)
output = model.apply(variables, x)
self.assertTrue(jnp.all(variables['params']['Child_0']['w'] == variables['params']['Child_1']['w']))


class FrozenDictTests(absltest.TestCase):

Expand Down

0 comments on commit 6eca28a

Please sign in to comment.