
Commit 1ce8522
Move wrapper to models and remove concat strategy
yair-schiff committed Dec 20, 2023
1 parent e6ca69c commit 1ce8522
Showing 3 changed files with 94 additions and 112 deletions.
mamba_ssm/models/config_mamba.py: 5 changes (2 additions, 3 deletions)
@@ -13,6 +13,5 @@ class MambaConfig:
     residual_in_fp32: bool = True
     fused_add_norm: bool = True
     pad_vocab_size_multiple: int = 8
-    bidirectional: bool = False,
-    bidirectional_strategy: Union[str, None] = None,
-    bidirectional_weight_tie: bool = False,
+    bidirectional: bool = False
+    bidirectional_strategy: Union[str, None] = None
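
Beyond dropping `bidirectional_weight_tie`, removing the trailing commas is a genuine fix: on Python 3.8+, where annotated assignments accept tuple right-hand sides, each old default parsed as a one-element tuple such as (False,), which is truthy. A minimal sketch of the corrected behavior (not part of the commit; it assumes the remaining MambaConfig fields keep their defaults):

from mamba_ssm.models.config_mamba import MambaConfig

cfg = MambaConfig()
# With the old trailing commas these defaults parsed as one-element tuples,
# e.g. cfg.bidirectional == (False,), so `if cfg.bidirectional:` was always True.
assert cfg.bidirectional is False
assert cfg.bidirectional_strategy is None

# Opting in to bi-directionality with one of the two remaining strategies:
bi_cfg = MambaConfig(bidirectional=True, bidirectional_strategy="ew_multiply")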
mamba_ssm/models/mixer_seq_simple.py: 106 changes (92 additions, 14 deletions)
@@ -12,7 +12,7 @@
 import torch.nn as nn
 
 from mamba_ssm.models.config_mamba import MambaConfig
-from mamba_ssm.modules.mamba_simple import MambaWrapper, Block
+from mamba_ssm.modules.mamba_simple import Mamba, Block
 from mamba_ssm.utils.generation import GenerationMixin
 from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
 
@@ -32,7 +32,6 @@ def create_block(
     layer_idx=None,
     bidirectional=False,
     bidirectional_strategy=None,
-    bidirectional_weight_tie=False,
     device=None,
     dtype=None,
 ):
@@ -42,7 +41,6 @@ def create_block(
     bidirectional_kwargs = {
         "bidirectional": bidirectional,
         "bidirectional_strategy": bidirectional_strategy,
-        "bidirectional_weight_tie": bidirectional_weight_tie,
     }
     mixer_cls = partial(MambaWrapper, layer_idx=layer_idx, **ssm_cfg, **bidirectional_kwargs, **factory_kwargs)
     norm_cls = partial(
@@ -92,6 +90,95 @@ def _init_weights(
                 p /= math.sqrt(n_residuals_per_layer * n_layer)
 
 
+class MambaWrapper(nn.Module):
+    """Thin wrapper around Mamba to support bi-directionality."""
+    def __init__(
+        self,
+        d_model,
+        d_state=16,
+        d_conv=4,
+        expand=2,
+        dt_rank="auto",
+        dt_min=0.001,
+        dt_max=0.1,
+        dt_init="random",
+        dt_scale=1.0,
+        dt_init_floor=1e-4,
+        conv_bias=True,
+        bias=False,
+        use_fast_path=True,  # Fused kernel options
+        layer_idx=None,
+        bidirectional: bool = False,
+        bidirectional_strategy: Optional[str] = None,
+        device=None,
+        dtype=None,
+    ):
+        super().__init__()
+        if bidirectional and bidirectional_strategy is None:
+            bidirectional_strategy = "add"  # Default strategy: `add`
+        if bidirectional and bidirectional_strategy not in ["add", "ew_multiply"]:
+            raise NotImplementedError(f"`{bidirectional_strategy}` strategy for bi-directionality is not implemented!")
+        self.bidirectional = bidirectional
+        self.bidirectional_strategy = bidirectional_strategy
+        self.mamba_fwd = Mamba(
+            d_model=d_model,
+            d_state=d_state,
+            d_conv=d_conv,
+            expand=expand,
+            dt_rank=dt_rank,
+            dt_min=dt_min,
+            dt_max=dt_max,
+            dt_init=dt_init,
+            dt_scale=dt_scale,
+            dt_init_floor=dt_init_floor,
+            conv_bias=conv_bias,
+            bias=bias,
+            use_fast_path=use_fast_path,  # Fused kernel options
+            layer_idx=layer_idx,
+            device=device,
+            dtype=dtype,
+        )
+        if bidirectional:
+            self.mamba_rev = Mamba(
+                d_model=d_model,
+                d_state=d_state,
+                d_conv=d_conv,
+                expand=expand,
+                dt_rank=dt_rank,
+                dt_min=dt_min,
+                dt_max=dt_max,
+                dt_init=dt_init,
+                dt_scale=dt_scale,
+                dt_init_floor=dt_init_floor,
+                conv_bias=conv_bias,
+                bias=bias,
+                use_fast_path=use_fast_path,  # Fused kernel options
+                layer_idx=layer_idx,
+                device=device,
+                dtype=dtype,
+            )
+        else:
+            self.mamba_rev = None
+
+    def forward(self, hidden_states, inference_params=None):
+        """Bidirectional-enabled forward pass
+        hidden_states: (B, L, D)
+        Returns: same shape as hidden_states
+        """
+        out = self.mamba_fwd(hidden_states, inference_params=inference_params)
+        if self.bidirectional:
+            out_rev = self.mamba_rev(
+                hidden_states.flip(dims=(1,)),  # Flip along the sequence length dimension
+                inference_params=inference_params
+            ).flip(dims=(1,))  # Flip back for combining with forward hidden states
+            if self.bidirectional_strategy == "add":
+                out = out + out_rev
+            elif self.bidirectional_strategy == "ew_multiply":
+                out = out * out_rev
+        return out
+
+
 class MixerModel(nn.Module):
     def __init__(
         self,
Expand All @@ -106,7 +193,6 @@ def __init__(
residual_in_fp32=False,
bidirectional: bool = False,
bidirectional_strategy: Optional[str] = None,
bidirectional_weight_tie: bool = False,
device=None,
dtype=None,
) -> None:
@@ -138,7 +224,6 @@ def __init__(
                     layer_idx=i,
                     bidirectional=bidirectional,
                     bidirectional_strategy=bidirectional_strategy,
-                    bidirectional_weight_tie=bidirectional_weight_tie,
                     **factory_kwargs,
                 )
                 for i in range(n_layer)
@@ -208,7 +293,6 @@ def __init__(
         pad_vocab_size_multiple = config.pad_vocab_size_multiple
         bidirectional = config.bidirectional
         bidirectional_strategy = config.bidirectional_strategy
-        bidirectional_weight_tie = config.bidirectional_weight_tie
         factory_kwargs = {"device": device, "dtype": dtype}
 
         super().__init__()
Expand All @@ -225,13 +309,9 @@ def __init__(
residual_in_fp32=residual_in_fp32,
bidirectional=bidirectional,
bidirectional_strategy=bidirectional_strategy,
bidirectional_weight_tie=bidirectional_weight_tie,
**factory_kwargs,
)
if self.config.bidirectional and self.config.bidirectional_strategy == "concat":
self.lm_head = nn.Linear(d_model * 2, vocab_size, bias=False, **factory_kwargs)
else:
self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)
self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)

# Initialize weights and apply final processing
self.apply(
@@ -241,9 +321,7 @@ def __init__(
                 **(initializer_cfg if initializer_cfg is not None else {}),
             )
         )
-        # For bi-directionality using the `concat` strategy, we cannot tie weights since concatenation doubles d_model
-        if not self.config.bidirectional or self.config.bidirectional_strategy != "concat":
-            self.tie_weights()
+        self.tie_weights()
 
     def tie_weights(self):
         self.lm_head.weight = self.backbone.embedding.weight
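
A minimal usage sketch for the relocated wrapper (not part of the commit; it assumes a CUDA device, since Mamba's fused selective-scan kernels run on GPU):

import torch
from mamba_ssm.models.mixer_seq_simple import MambaWrapper

device = "cuda"  # the fused kernels assume a CUDA device
layer = MambaWrapper(d_model=64, bidirectional=True, bidirectional_strategy="add").to(device)
x = torch.randn(2, 128, 64, device=device)  # (B, L, D)
y = layer(x)  # reverse branch flips L, runs its own Mamba, flips back, then combines
assert y.shape == x.shape  # "add" and "ew_multiply" both preserve (B, L, D)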
mamba_ssm/modules/mamba_simple.py: 95 changes (0 additions, 95 deletions)
@@ -294,101 +294,6 @@ def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
         return conv_state, ssm_state
 
 
-class MambaWrapper(nn.Module):
-    """Thin wrapper around Mamba to support bi-directionality."""
-    def __init__(
-        self,
-        d_model,
-        d_state=16,
-        d_conv=4,
-        expand=2,
-        dt_rank="auto",
-        dt_min=0.001,
-        dt_max=0.1,
-        dt_init="random",
-        dt_scale=1.0,
-        dt_init_floor=1e-4,
-        conv_bias=True,
-        bias=False,
-        use_fast_path=True,  # Fused kernel options
-        layer_idx=None,
-        bidirectional: bool = False,
-        bidirectional_strategy: Optional[str] = None,
-        bidirectional_weight_tie: Optional[bool] = False,
-        device=None,
-        dtype=None,
-    ):
-        super().__init__()
-        if bidirectional and bidirectional_strategy is None:
-            bidirectional_strategy = "add"  # Default strategy: `add`
-        if bidirectional and bidirectional_strategy not in ["add", "concat", "ew_multiply"]:
-            raise NotImplementedError(f"{bidirectional_strategy} strategy for bi-directionality is not implemented!")
-        self.bidirectional = bidirectional
-        self.bidirectional_strategy = bidirectional_strategy
-        self.bidirectional_weight_tie = bidirectional_weight_tie
-        self.mamba_fwd = Mamba(
-            d_model=d_model,
-            d_state=d_state,
-            d_conv=d_conv,
-            expand=expand,
-            dt_rank=dt_rank,
-            dt_min=dt_min,
-            dt_max=dt_max,
-            dt_init=dt_init,
-            dt_scale=dt_scale,
-            dt_init_floor=dt_init_floor,
-            conv_bias=conv_bias,
-            bias=bias,
-            use_fast_path=use_fast_path,  # Fused kernel options
-            layer_idx=layer_idx,
-            device=device,
-            dtype=dtype,
-        )
-        if bidirectional and not bidirectional_weight_tie:
-            # If not weight tying, instantiate separate Mamba
-            self.mamba_rev = Mamba(
-                d_model=d_model,
-                d_state=d_state,
-                d_conv=d_conv,
-                expand=expand,
-                dt_rank=dt_rank,
-                dt_min=dt_min,
-                dt_max=dt_max,
-                dt_init=dt_init,
-                dt_scale=dt_scale,
-                dt_init_floor=dt_init_floor,
-                conv_bias=conv_bias,
-                bias=bias,
-                use_fast_path=use_fast_path,  # Fused kernel options
-                layer_idx=layer_idx,
-                device=device,
-                dtype=dtype,
-            )
-        else:
-            self.mamba_rev = None
-
-    def forward(self, hidden_states, inference_params=None):
-        """Bidirectional-enabled forward pass
-        hidden_states: (B, L, D)
-        Returns: same shape as hidden_states
-        """
-        out = self.mamba_fwd(hidden_states, inference_params=inference_params)
-        if self.bidirectional:
-            mamba_rev = self.mamba_rev if self.mamba_rev is not None else self.mamba_fwd
-            out_rev = mamba_rev(
-                hidden_states.flip(dims=(1,)),  # Flip along the sequence length dimension
-                inference_params=inference_params
-            ).flip(dims=(1,))  # Flip back for combining with forward hidden states
-            if self.bidirectional_strategy == "add":
-                out = out + out_rev
-            elif self.bidirectional_strategy == "concat":
-                out = torch.cat([out, out_rev], dim=-1)
-            elif self.bidirectional_strategy == "ew_multiply":
-                out = out * out_rev
-        return out
-
-
 class Block(nn.Module):
     def __init__(
         self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False
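
For reference, the removed `concat` strategy stacked the forward and reverse outputs along the feature dimension, doubling it to 2 * d_model; that is why the LM head needed a special case and could never share weights with the embedding. An illustrative shape sketch (not from the commit):

import torch
import torch.nn as nn

d_model, vocab_size = 256, 32000
embedding = nn.Embedding(vocab_size, d_model)           # weight: (vocab_size, d_model)
out = torch.randn(2, 8, d_model)                        # forward branch output
out_rev = torch.randn(2, 8, d_model)                    # reverse branch output
cat = torch.cat([out, out_rev], dim=-1)                 # (2, 8, 2 * d_model)
head = nn.Linear(2 * d_model, vocab_size, bias=False)   # weight: (vocab_size, 2 * d_model)
# Tying head.weight to embedding.weight would fail: (vocab_size, 2 * d_model)
# vs. (vocab_size, d_model). With "add" or "ew_multiply" the width stays
# d_model, so tie_weights() is now unconditional.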
