Merge pull request #97 from basf/AB_layer
adding AB layernorm and weight decay to Mamba
AnFreTh authored Aug 2, 2024
2 parents 56801dd + 22dad68 commit 0d4442a
Showing 3 changed files with 75 additions and 9 deletions.
mambular/arch_utils/mamba_arch.py: 51 changes (48 additions, 3 deletions)
@@ -43,6 +43,9 @@ def __init__(
activation=F.silu,
bidirectional=False,
use_learnable_interaction=False,
layer_norm_eps=1e-05,
AB_weight_decay=False,
AB_layer_norm=True,
):
super().__init__()

@@ -66,6 +69,9 @@ def __init__(
activation,
bidirectional,
use_learnable_interaction,
layer_norm_eps,
AB_weight_decay,
AB_layer_norm,
)
for _ in range(n_layers)
]
@@ -105,6 +111,9 @@ def __init__(
activation=F.silu,
bidirectional=False,
use_learnable_interaction=False,
layer_norm_eps=1e-05,
AB_weight_decay=False,
AB_layer_norm=False,
):
super().__init__()

@@ -149,8 +158,11 @@ def __init__(
activation=activation,
bidirectional=bidirectional,
use_learnable_interaction=use_learnable_interaction,
layer_norm_eps=layer_norm_eps,
AB_weight_decay=AB_weight_decay,
AB_layer_norm=AB_layer_norm,
)
self.norm = norm(d_model)
self.norm = norm(d_model, eps=layer_norm_eps)

def forward(self, x):
output = self.layers(self.norm(x)) + x
@@ -189,6 +201,9 @@ def __init__(
activation=F.silu,
bidirectional=False,
use_learnable_interaction=False,
layer_norm_eps=1e-05,
AB_weight_decay=False,
AB_layer_norm=False,
):
super().__init__()
self.d_inner = d_model * expand_factor
@@ -239,6 +254,7 @@ def __init__(
elif dt_init == "random":
nn.init.uniform_(self.dt_proj_fwd.weight, -dt_init_std, dt_init_std)
if self.bidirectional:

nn.init.uniform_(self.dt_proj_bwd.weight, -dt_init_std, dt_init_std)
else:
raise NotImplementedError
@@ -262,17 +278,35 @@ def __init__(

A = torch.arange(1, d_state + 1, dtype=torch.float32).repeat(self.d_inner, 1)
self.A_log_fwd = nn.Parameter(torch.log(A))
self.D_fwd = nn.Parameter(torch.ones(self.d_inner))

if self.bidirectional:
self.A_log_bwd = nn.Parameter(torch.log(A))
self.D_bwd = nn.Parameter(torch.ones(self.d_inner))

if not AB_weight_decay:
self.A_log_fwd._no_weight_decay = True
self.D_fwd._no_weight_decay = True

self.D_fwd = nn.Parameter(torch.ones(self.d_inner))
if self.bidirectional:
self.D_bwd = nn.Parameter(torch.ones(self.d_inner))

if not AB_weight_decay:
self.A_log_bwd._no_weight_decay = True
self.D_bwd._no_weight_decay = True

self.out_proj = nn.Linear(self.d_inner, d_model, bias=bias)
self.dt_rank = dt_rank
self.d_state = d_state

if AB_layer_norm:
self.dt_layernorm = RMSNorm(self.dt_rank, eps=layer_norm_eps)
self.B_layernorm = RMSNorm(self.d_state, eps=layer_norm_eps)
self.C_layernorm = RMSNorm(self.d_state, eps=layer_norm_eps)
else:
self.dt_layernorm = None
self.B_layernorm = None
self.C_layernorm = None

def forward(self, x):
_, L, _ = x.shape

@@ -316,6 +350,15 @@ def forward(self, x):

return output

def _apply_layernorms(self, dt, B, C):
if self.dt_layernorm is not None:
dt = self.dt_layernorm(dt)
if self.B_layernorm is not None:
B = self.B_layernorm(B)
if self.C_layernorm is not None:
C = self.C_layernorm(C)
return dt, B, C

def ssm(self, x, forward=True):
if forward:
A = -torch.exp(self.A_log_fwd.float())
@@ -324,6 +367,7 @@ def ssm(self, x, forward=True):
delta, B, C = torch.split(
deltaBC, [self.dt_rank, self.d_state, self.d_state], dim=-1
)
delta, B, C = self._apply_layernorms(delta, B, C)
delta = F.softplus(self.dt_proj_fwd(delta))
else:
A = -torch.exp(self.A_log_bwd.float())
@@ -332,6 +376,7 @@ def ssm(self, x, forward=True):
delta, B, C = torch.split(
deltaBC, [self.dt_rank, self.d_state, self.d_state], dim=-1
)
delta, B, C = self._apply_layernorms(delta, B, C)
delta = F.softplus(self.dt_proj_bwd(delta))

y = self.selective_scan_seq(x, delta, A, B, C, D)
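For reference, the AB_layer_norm option attaches an RMSNorm to each of the dt, B, and C projections before the selective scan (the _apply_layernorms helper above). Below is a minimal sketch of that pattern, assuming RMSNorm is the usual root-mean-square normalization with a learned scale; the class actually imported in mamba_arch.py may differ in detail, and the shapes here are illustrative only.

import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    """Root-mean-square norm: rescale by the RMS over the last dim, then apply a learned gain."""

    def __init__(self, d, eps=1e-05):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d))

    def forward(self, x):
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * inv_rms * self.weight


# Illustrative shapes; in MambaBlock the tensor comes from x_proj and is split
# exactly as in the ssm() hunks above.
dt_rank, d_state = 8, 16
dt_layernorm = RMSNorm(dt_rank, eps=1e-05)
B_layernorm = RMSNorm(d_state, eps=1e-05)
C_layernorm = RMSNorm(d_state, eps=1e-05)

deltaBC = torch.randn(4, 32, dt_rank + 2 * d_state)  # (batch, seq_len, features)
delta, B, C = torch.split(deltaBC, [dt_rank, d_state, d_state], dim=-1)
delta, B, C = dt_layernorm(delta), B_layernorm(B), C_layernorm(C)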
mambular/base_models/mambular.py: 26 changes (20 additions, 6 deletions)
@@ -109,19 +109,33 @@ def __init__(
use_learnable_interaction=self.hparams.get(
"use_learnable_interactions", config.use_learnable_interaction
),
AB_weight_decay=self.hparams.get("AB_weight_decay", config.AB_weight_decay),
AB_layer_norm=self.hparams.get("AB_layer_norm", config.AB_layer_norm),
layer_norm_eps=self.hparams.get("layer_norm_eps", config.layer_norm_eps),
)

norm_layer = self.hparams.get("norm", config.norm)
if norm_layer == "RMSNorm":
self.norm_f = RMSNorm(self.hparams.get("d_model", config.d_model))
self.norm_f = RMSNorm(
self.hparams.get("d_model", config.d_model), eps=config.layer_norm_eps
)
elif norm_layer == "LayerNorm":
self.norm_f = LayerNorm(self.hparams.get("d_model", config.d_model))
self.norm_f = LayerNorm(
self.hparams.get("d_model", config.d_model), eps=config.layer_norm_eps
)
elif norm_layer == "BatchNorm":
self.norm_f = BatchNorm(self.hparams.get("d_model", config.d_model))
self.norm_f = BatchNorm(
self.hparams.get("d_model", config.d_model), eps=config.layer_norm_eps
)
elif norm_layer == "InstanceNorm":
self.norm_f = InstanceNorm(self.hparams.get("d_model", config.d_model))
self.norm_f = InstanceNorm(
self.hparams.get("d_model", config.d_model), eps=config.layer_norm_eps
)
elif norm_layer == "GroupNorm":
self.norm_f = GroupNorm(1, self.hparams.get("d_model", config.d_model))
self.norm_f = GroupNorm(
1,
self.hparams.get("d_model", config.d_model),
eps=config.layer_norm_eps,
)
elif norm_layer == "LearnableLayerScaling":
self.norm_f = LearnableLayerScaling(
self.hparams.get("d_model", config.d_model)
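Mambular forwards AB_weight_decay into the Mamba blocks, where (per the mamba_arch.py hunk above) A_log and D are tagged with a _no_weight_decay attribute when the flag is False. The diff does not show how that tag is consumed; a common pattern, sketched here as an assumption rather than the library's actual optimizer setup, is to split parameters into decay and no-decay groups:

import torch


def build_param_groups(model, weight_decay=1e-2):
    """Put parameters tagged with `_no_weight_decay` into a zero-decay group."""
    decay, no_decay = [], []
    for param in model.parameters():
        if not param.requires_grad:
            continue
        if getattr(param, "_no_weight_decay", False):
            no_decay.append(param)
        else:
            decay.append(param)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]


# optimizer = torch.optim.AdamW(build_param_groups(model), lr=1e-4)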
mambular/configs/mambular_config.py: 7 changes (7 additions, 0 deletions)
@@ -73,6 +73,10 @@ class DefaultMambularConfig:
Whether to append a CLS token to the end of each 'sequence'.
shuffle_embeddings : bool, default=False.
Whether to shuffle the embeddings before being passed to the Mamba layers.
layer_norm_eps : float, default=1e-05
Epsilon value for layer normalization.
AB_weight_decay : bool, default=False
Whether weight decay is also applied to the A and D matrices.
"""

lr: float = 1e-04
@@ -107,3 +111,6 @@
use_learnable_interaction: bool = False
use_cls: bool = False
shuffle_embeddings: bool = False
layer_norm_eps: float = 1e-05
AB_weight_decay: bool = False
AB_layer_norm: bool = True
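Assuming DefaultMambularConfig is a plain dataclass (as the field syntax above suggests), the new options can be overridden per run like any other hyperparameter; the import path below simply mirrors the file location shown in this diff.

from mambular.configs.mambular_config import DefaultMambularConfig

# Illustrative overrides of the options added in this commit; the defaults are
# layer_norm_eps=1e-05, AB_weight_decay=False, AB_layer_norm=True.
config = DefaultMambularConfig(
    layer_norm_eps=1e-06,
    AB_weight_decay=True,   # also apply weight decay to the A_log and D parameters
    AB_layer_norm=False,    # skip the RMSNorm on the dt/B/C projections
)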
