use create_block from mamba_ssm
kashif committed Jul 20, 2024
1 parent cb4e984 commit ecb1821
Showing 3 changed files with 17 additions and 38 deletions.
2 changes: 1 addition & 1 deletion requirements/requirements-mamba.txt
@@ -1,2 +1,2 @@
-causal_conv1d>=1.1.0
+causal_conv1d
 mamba-ssm
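
The causal_conv1d pin is dropped, and the code changes below assume the mamba-ssm 2.x package layout (create_block exported from models.mixer_seq_simple, Triton norms under ops.triton.layer_norm). A quick, hypothetical environment check along these lines can confirm that before running the model:

# Hypothetical sanity check (not part of this commit): confirm the installed
# packages expose the mamba-ssm 2.x module names that the new code imports.
from importlib.metadata import PackageNotFoundError, version

for dist in ("mamba-ssm", "causal-conv1d"):
    try:
        print(dist, version(dist))
    except PackageNotFoundError:
        print(dist, "is not installed")

try:
    from mamba_ssm.models.mixer_seq_simple import create_block  # noqa: F401
    from mamba_ssm.ops.triton.layer_norm import layer_norm_fn  # noqa: F401
    print("mamba-ssm 2.x import paths are available")
except ImportError as exc:
    print("older mamba-ssm layout detected:", exc)
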
2 changes: 1 addition & 1 deletion src/gluonts/torch/model/mamba/estimator.py
@@ -147,7 +147,7 @@ def __init__(
         prediction_length: int,
         context_length: Optional[int] = None,
         num_layers: int = 2,
-        hidden_size: int = 40,
+        hidden_size: int = 64,
         lr: float = 1e-3,
         weight_decay: float = 1e-8,
         patience: int = 10,
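
Note that the default hidden_size grows from 40 to 64, so configurations relying on the old default now build a larger model. A minimal usage sketch that pins the previous size explicitly; the class name MambaEstimator and the argument values are assumptions for illustration, and any other required constructor arguments are omitted:

# Sketch only: the estimator class is assumed to be named MambaEstimator, and
# only parameters visible in the hunk above are set; the real constructor may
# require additional arguments.
from gluonts.torch.model.mamba.estimator import MambaEstimator

estimator = MambaEstimator(
    prediction_length=24,
    context_length=48,
    num_layers=2,
    hidden_size=40,  # keep the pre-commit default instead of the new 64
    lr=1e-3,
    weight_decay=1e-8,
    patience=10,
)
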
51 changes: 15 additions & 36 deletions src/gluonts/torch/model/mamba/module.py
@@ -33,39 +33,12 @@
 )
 from gluonts.itertools import prod
 from gluonts.model import Input, InputSpec
-from mamba_ssm.modules.mamba_simple import Mamba, Block
+from mamba_ssm.models.mixer_seq_simple import create_block
 try:
-    from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
+    from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
 except ImportError:
     RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None

-def create_block(
-    d_model,
-    ssm_cfg=None,
-    norm_epsilon=1e-5,
-    rms_norm=False,
-    residual_in_fp32=False,
-    fused_add_norm=False,
-    layer_idx=None,
-    device=None,
-    dtype=None,
-):
-    if ssm_cfg is None:
-        ssm_cfg = {}
-    factory_kwargs = {"device": device, "dtype": dtype}
-    mixer_cls = partial(Mamba, layer_idx=layer_idx, **ssm_cfg, **factory_kwargs)
-    norm_cls = partial(
-        nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
-    )
-    block = Block(
-        d_model,
-        mixer_cls,
-        norm_cls=norm_cls,
-        fused_add_norm=fused_add_norm,
-        residual_in_fp32=residual_in_fp32,
-    )
-    block.layer_idx = layer_idx
-    return block

 class MambaModel(nn.Module):
     """
@@ -133,9 +106,12 @@ def __init__(
         num_feat_static_cat: int = 1,
         cardinality: List[int] = [1],
         embedding_dimension: Optional[List[int]] = None,
+        d_intermediate: int = 0,
         num_layers: int = 2,
-        hidden_size: int = 40,
-        ssm_cfg=None,
+        hidden_size: int = 64,
+        ssm_cfg: Optional[dict] = {"layer": "Mamba2"},
+        attn_layer_idx: Optional[List[int]] = None,
+        attn_cfg: Optional[dict] = None,
         norm_epsilon: float = 1e-5,
         rms_norm: bool = False,
         fused_add_norm: bool = False,
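
One small caveat with the new signature: ssm_cfg defaults to a mutable dict literal, so every instance built with the default shares the same object, and any later in-place change to it would leak across instances. A common defensive pattern, offered here only as a suggestion and not as part of the commit:

# Illustrative pattern (not the committed signature): keep the default as None
# and copy the caller's dict so no instance shares or mutates another's config.
from typing import Optional

class Example:
    def __init__(self, ssm_cfg: Optional[dict] = None) -> None:
        self.ssm_cfg = dict(ssm_cfg) if ssm_cfg is not None else {"layer": "Mamba2"}

a, b = Example(), Example()
assert a.ssm_cfg is not b.ssm_cfg  # each instance owns its own config dict
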
@@ -195,8 +171,11 @@ def __init__(
         self.layers = nn.ModuleList(
             [
                 create_block(
-                    hidden_size,
+                    d_model=hidden_size,
+                    d_intermediate=d_intermediate,
                     ssm_cfg=ssm_cfg,
+                    attn_layer_idx=attn_layer_idx,
+                    attn_cfg=attn_cfg,
                     norm_epsilon=norm_epsilon,
                     rms_norm=rms_norm,
                     residual_in_fp32=residual_in_fp32,
@@ -211,7 +190,7 @@ def __init__(
         )

         self.nonnegative_pred_samples = nonnegative_pred_samples
-        self.param_proj = distr_output.get_args_proj(self.d_model)
+        self.param_proj = distr_output.get_args_proj(hidden_size)


     def describe_inputs(self, batch_size=1) -> InputSpec:
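
The block stack is now produced entirely by the upstream factory, and the projection head is sized from hidden_size directly. A standalone sketch of the same construction; the hyperparameter values, the per-layer layer_idx, and the CUDA device are illustrative assumptions (the visible hunks do not show how layer_idx or fused_add_norm are passed in the actual module):

# Illustrative sketch: build a small stack of mamba-ssm blocks the same way the
# model above does. Values are examples; mamba-ssm's fused kernels need CUDA.
import torch
from torch import nn
from mamba_ssm.models.mixer_seq_simple import create_block

device = "cuda"
hidden_size, num_layers = 64, 2

layers = nn.ModuleList(
    [
        create_block(
            d_model=hidden_size,
            d_intermediate=0,             # 0 -> no intermediate MLP in each block
            ssm_cfg={"layer": "Mamba2"},  # select the Mamba2 mixer
            attn_layer_idx=None,          # no interleaved attention layers
            attn_cfg=None,
            norm_epsilon=1e-5,
            rms_norm=False,
            residual_in_fp32=False,
            fused_add_norm=False,
            layer_idx=i,
        )
        for i in range(num_layers)
    ]
).to(device)

x = torch.randn(8, 32, hidden_size, device=device)  # (batch, time, d_model)
residual = None
for layer in layers:
    x, residual = layer(x, residual)  # each Block returns (hidden_states, residual)
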
@@ -332,15 +311,15 @@ def mamba(self, mamba_input):
             hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
         else:
             # Set prenorm=False here since we don't need the residual
-            fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
-            hidden_states = fused_add_norm_fn(
+            hidden_states = layer_norm_fn(
                 hidden_states,
                 self.norm_f.weight,
                 self.norm_f.bias,
                 eps=self.norm_f.eps,
                 residual=residual,
                 prenorm=False,
                 residual_in_fp32=self.residual_in_fp32,
+                is_rms_norm=isinstance(self.norm_f, RMSNorm),
             )
         return hidden_states
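
In the 2.x API a single layer_norm_fn covers both LayerNorm and RMSNorm through the is_rms_norm flag, which is why the two-function dispatch above could be dropped. A standalone sketch of that fused add-and-norm call; shapes and device are illustrative, and the Triton kernels require CUDA:

# Illustrative sketch of the fused residual-add + norm used for the final
# hidden states; assumes mamba-ssm 2.x where layer_norm_fn also handles RMSNorm.
import torch
from torch import nn
from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn

d_model = 64
norm_f = nn.LayerNorm(d_model).cuda()
hidden_states = torch.randn(8, 32, d_model, device="cuda")
residual = torch.randn(8, 32, d_model, device="cuda")

out = layer_norm_fn(
    hidden_states,
    norm_f.weight,
    norm_f.bias,
    eps=norm_f.eps,
    residual=residual,   # residual is added before normalizing
    prenorm=False,       # return only the normalized output
    residual_in_fp32=False,
    is_rms_norm=isinstance(norm_f, RMSNorm),
)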

@@ -540,7 +519,7 @@ def forward(
             )
             mamba_input = torch.cat((next_lags, next_features), dim=-1)

-            output = self.mamba(rnn_input)
+            output = self.mamba(mamba_input)

             repeated_past_target = torch.cat(
                 (repeated_past_target, scaled_next_sample), dim=1
