
[linen] add share_scope #4102

Merged · 1 commit merged into main from linen-transparent · Aug 1, 2024

Conversation

@cgarciae (Collaborator) commented Jul 23, 2024

What does this PR do?

Adds nn.share_scope, which can be used to share a scope between two Modules. This means that any parameters created by either Module will be added to their common scope. This is useful when you want to wrap a Module and extend its functionality without changing the parameter structure.

import flax.linen as nn
import jax
from jax import numpy as jnp, random

class DenseLoRA(nn.Module):
  base: nn.Dense
  rank: int

  def setup(self):
    nn.share_scope(self, self.base)

  @nn.compact
  def __call__(self, x: jax.Array):
    din, dout = x.shape[-1], self.base.features
    A = self.param('A', nn.zeros_init(), (din, self.rank))
    B = self.param('B', nn.zeros_init(), (self.rank, dout))
    return self.base(x) + x @ A @ B

class Model(nn.Module):
  @nn.compact
  def __call__(self, x: jax.Array):
    base = nn.Dense(10)  # base scope
    return DenseLoRA(base, rank=2)(x)  # reuse base scope

model = Model()

params = model.init(random.key(0), jnp.ones((1, 5)))['params']
list(params['Dense_0'].keys())
# ['A', 'B', 'kernel', 'bias']
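
For completeness (not part of the PR description), a quick sketch of applying the merged parameters, reusing the definitions above; since A and B are zero-initialized, DenseLoRA initially behaves exactly like the wrapped Dense:

y = model.apply({'params': params}, jnp.ones((1, 5)))
y.shape
# (1, 10)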

class Transparent(Module, Generic[M]):
  """A Module that shares its scope with an inner Module. This combinator is useful
  when you want to wrap a Module and extend its functionality without changing the
  parameter structure.
Contributor (inline review comment):
I would also highlight the difference between

class DenseLora(Dense):
...

vs

class DenseLora(Transparent[Dense]):
...

(The important distinction is that Transparent replaces inheritance with composition, which lets you swap out the inner module for any implementation with a Dense-compatible interface, without having to create a separate DenseLora for each implementation.)
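
A minimal sketch of the composition benefit described here, reusing the share_scope-based DenseLoRA from the PR description; MyQuantDense is a hypothetical stand-in for any module exposing a Dense-compatible interface (a `features` attribute and a matching `__call__`), and in practice the `base` annotation could be loosened to nn.Module:

class MyQuantDense(nn.Module):
  # hypothetical Dense-compatible module: only `features` and `__call__` matter here
  features: int

  @nn.compact
  def __call__(self, x: jax.Array):
    kernel = self.param('kernel', nn.initializers.lecun_normal(),
                        (x.shape[-1], self.features))
    return x @ kernel

# the same wrapper composes with either implementation, no subclassing needed
lora_over_dense = DenseLoRA(nn.Dense(10), rank=2)
lora_over_quant = DenseLoRA(MyQuantDense(10), rank=2)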

@cgarciae (Collaborator, Author) commented Jul 24, 2024

@marksandler2 wondering if we should just expose _assign_shared_scope as e.g. nn.share_scope and have users make the call manually, instead of introducing the Transparent abstraction? The LoRA example would look like this:

class DenseLoRA(nn.Module):
  inner: nn.Dense
  rank: int

  def setup(self):
    nn.share_scope(self, self.inner)

  @nn.compact
  def __call__(self, x: jax.Array):
    din, dout = x.shape[-1], self.inner.features
    A = self.param('A', nn.zeros_init(), (din, self.rank))
    B = self.param('B', nn.zeros_init(), (self.rank, dout))
    return self.inner(x) + x @ A @ B

class Model(nn.Module):
  @nn.compact
  def __call__(self, x: jax.Array):
    return DenseLoRA(nn.Dense(10), rank=2)(x)
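
For clarity, a hedged sketch of initialization under this proposal (imports as in the PR description above); the point is that 'A', 'B', 'kernel' and 'bias' all end up in one shared parameter collection rather than being nested under a separate DenseLoRA scope:

params = Model().init(random.key(0), jnp.ones((1, 5)))['params']
jax.tree_util.tree_map(jnp.shape, params)
# one flat sub-tree containing 'A', 'B', 'kernel' and 'bias'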

@marksandler2 (Contributor) commented:

Hmm, I like this new proposed function a lot actually. It certainly feels a lot less heavyweight (e.g. we don't need to muck around with Transparent[...] inheritance), and the logic could be made a lot more flexible.

@cgarciae cgarciae force-pushed the linen-transparent branch 2 times, most recently from 62f2f4a to a52ffdc on July 25, 2024 12:24
@cgarciae cgarciae changed the title [linen] add Transparent [linen] add share_scope Jul 25, 2024
@marksandler2 (Contributor) commented:

Can we merge this?

@copybara-service copybara-service bot merged commit d20f594 into main Aug 1, 2024
18 checks passed
@copybara-service copybara-service bot deleted the linen-transparent branch August 1, 2024 13:20