Skip to content

Commit

Permalink
Implement varlen generation
Browse files Browse the repository at this point in the history
  • Loading branch information
tridao committed Jul 1, 2024
1 parent 3462302 commit 03a38fb
Show file tree
Hide file tree
Showing 8 changed files with 372 additions and 23 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ with an efficient hardware-aware design and implementation in the spirit of [Fla

## Installation

- [Option] `pip install causal-conv1d>=1.2.0`: an efficient implementation of a simple causal Conv1d layer used inside the Mamba block.
- [Option] `pip install causal-conv1d>=1.4.0`: an efficient implementation of a simple causal Conv1d layer used inside the Mamba block.
- `pip install mamba-ssm`: the core Mamba package.

It can also be built from source with `pip install .` from this repository.
Expand Down
2 changes: 1 addition & 1 deletion mamba_ssm/models/mixer_seq_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def forward(self, input_ids, inference_params=None, **mixer_kwargs):
residual = None
for layer in self.layers:
hidden_states, residual = layer(
hidden_states, residual, inference_params=inference_params
hidden_states, residual, inference_params=inference_params, **mixer_kwargs
)
if not self.fused_add_norm:
residual = (hidden_states + residual) if residual is not None else hidden_states
Expand Down
46 changes: 33 additions & 13 deletions mamba_ssm/modules/mamba2.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@
except ImportError:
causal_conv1d_fn, causal_conv1d_update = None, None

try:
from causal_conv1d.causal_conv1d_varlen import causal_conv1d_varlen_states
except ImportError:
causal_conv1d_varlen_states = None

try:
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
except ImportError:
Expand Down Expand Up @@ -144,7 +149,7 @@ def __init__(
process_group=self.process_group, sequence_parallel=self.sequence_parallel,
**factory_kwargs)

def forward(self, u, seqlen=None, seq_idx=None, inference_params=None):
def forward(self, u, seqlen=None, seq_idx=None, cu_seqlens=None, inference_params=None):
"""
u: (batch, seqlen, hidden_dim) if seqlen=None.
If seqlen is not None, u is (batch * seqlen, hidden_dim). This is so that when we
Expand All @@ -161,7 +166,8 @@ def forward(self, u, seqlen=None, seq_idx=None, inference_params=None):

conv_state, ssm_state = None, None
if inference_params is not None:
conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
inference_batch = cu_seqlens.shape[0] - 1 if cu_seqlens is not None else batch
conv_state, ssm_state = self._get_states_from_cache(inference_params, inference_batch)
if inference_params.seqlen_offset > 0:
# The states are updated inplace
out, _, _ = self.step(u, conv_state, ssm_state)
Expand Down Expand Up @@ -206,14 +212,22 @@ def forward(self, u, seqlen=None, seq_idx=None, inference_params=None):
dim=-1
)
if conv_state is not None:
# If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
# Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
xBC_t = rearrange(xBC, "b l d -> b d l")
conv_state.copy_(F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0))) # Update state (B D W)
if cu_seqlens is None:
# If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
# Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
xBC_t = rearrange(xBC, "b l d -> b d l")
conv_state.copy_(F.pad(xBC_t, (self.d_conv - xBC_t.shape[-1], 0))) # Update state (B D W)
else:
assert causal_conv1d_varlen_states is not None, "varlen inference requires causal_conv1d package"
assert batch == 1, "varlen inference only supports batch dimension 1"
conv_varlen_states = causal_conv1d_varlen_states(
xBC.squeeze(0), cu_seqlens, state_len=conv_state.shape[-1]
)
conv_state.copy_(conv_varlen_states)
assert self.activation in ["silu", "swish"]
if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
xBC = self.act(
self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)
self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[:, -(self.dconv - 1):]
) # (B, L, self.d_ssm + 2 * ngroups * d_state)
else:
xBC = causal_conv1d_fn(
Expand All @@ -235,12 +249,18 @@ def forward(self, u, seqlen=None, seq_idx=None, inference_params=None):
dt_bias=self.dt_bias,
dt_softplus=True,
seq_idx=seq_idx,
cu_seqlens=cu_seqlens,
**dt_limit_kwargs,
return_final_states=ssm_state is not None,
return_varlen_states=cu_seqlens is not None and inference_params is not None,
)
if ssm_state is not None:
y, last_state = y
ssm_state.copy_(last_state)
y, last_state, *rest = y
if cu_seqlens is None:
ssm_state.copy_(last_state)
else:
varlen_states = rest[0]
ssm_state.copy_(varlen_states)
y = rearrange(y, "b l h p -> b l (h p)")
if self.rmsnorm:
y = self.norm(y, z)
Expand Down Expand Up @@ -322,8 +342,8 @@ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs)
device = self.out_proj.weight.device
conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
conv_state = torch.zeros(
batch_size, self.conv1d.weight.shape[0], self.d_conv, device=device, dtype=conv_dtype
)
batch_size, self.d_conv, self.conv1d.weight.shape[0], device=device, dtype=conv_dtype
).transpose(1, 2)
ssm_dtype = self.in_proj.weight.dtype if dtype is None else dtype
ssm_state = torch.zeros(
batch_size, self.nheads, self.headdim, self.d_state, device=device, dtype=ssm_dtype
Expand All @@ -336,11 +356,11 @@ def _get_states_from_cache(self, inference_params, batch_size, initialize_states
batch_shape = (batch_size,)
conv_state = torch.zeros(
batch_size,
self.conv1d.weight.shape[0],
self.d_conv,
self.conv1d.weight.shape[0],
device=self.conv1d.weight.device,
dtype=self.conv1d.weight.dtype,
)
).transpose(1, 2)
ssm_state = torch.zeros(
batch_size,
self.nheads,
Expand Down
120 changes: 120 additions & 0 deletions mamba_ssm/ops/triton/ssd_chunk_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,97 @@ def _chunk_state_bwd_ddAcs_stable_kernel(
tl.atomic_add(ddA_cumsum_ptrs + stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size - 1)


@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2),
],
key=['hdim', 'dstate', 'chunk_size'],
)
@triton.jit
def _chunk_state_varlen_kernel(
# Pointers to matrices
x_ptr, b_ptr, dt_ptr, dA_cumsum_ptr, chunk_states_ptr, cu_seqlens_ptr, states_ptr,
# Matrix dimensions
hdim, dstate, chunk_size,
seqlen, nheads_ngroups_ratio,
# Strides
stride_x_seqlen, stride_x_head, stride_x_hdim,
stride_b_seqlen, stride_b_head, stride_b_dstate,
stride_dt_chunk, stride_dt_head, stride_dt_csize,
stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
stride_chunk_states_chunk, stride_chunk_states_head, stride_chunk_states_hdim, stride_chunk_states_dstate,
stride_states_batch, stride_states_head, stride_states_hdim, stride_states_dstate,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
):
pid_b = tl.program_id(axis=1)
pid_h = tl.program_id(axis=2)
num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
pid_m = tl.program_id(axis=0) // num_pid_n
pid_n = tl.program_id(axis=0) % num_pid_n
end_idx = tl.load(cu_seqlens_ptr + pid_b + 1)
pid_c = (end_idx - 1) // chunk_size
b_ptr += pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head
x_ptr += pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head
dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
chunk_states_ptr += pid_c * stride_chunk_states_chunk + pid_h * stride_chunk_states_head

offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen)
b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen)
dt_ptrs = dt_ptr + offs_k * stride_dt_csize
dA_cs_last = tl.load(dA_cumsum_ptr + (end_idx - pid_c * chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize

chunk_size_limit = end_idx - pid_c * chunk_size
start_idx = tl.load(cu_seqlens_ptr + pid_b)
start_idx_cur = tl.maximum(start_idx - pid_c * chunk_size, 0)

acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
x = tl.load(x_ptrs, mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k) & (offs_k[None, :] >= start_idx_cur - k), other=0.0)
b = tl.load(b_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate) & (offs_k[:, None] >= start_idx_cur - k), other=0.0).to(tl.float32)
dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32)
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32)
scale = tl.where((offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k),
tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0)
b *= scale[:, None]
b = b.to(x_ptr.dtype.element_ty)
acc += tl.dot(x, b)
x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
b_ptrs += BLOCK_SIZE_K * stride_b_seqlen
dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize

# If the sequence starts after the last chunk idx, we don't need to add the contribution from the last chunk
if start_idx < pid_c * chunk_size:
chunk_states_ptrs = chunk_states_ptr + (offs_m[:, None] * stride_chunk_states_hdim + offs_n[None, :] * stride_chunk_states_dstate)
chunk_states = tl.load(chunk_states_ptrs, mask=(offs_m[:, None] < hdim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)
# scale = tl.where(start_idx < pid_c * chunk_size, tl.exp(dA_cs_last), 0.0)
scale = tl.exp(dA_cs_last)
acc += chunk_states * scale

states = acc.to(states_ptr.dtype.element_ty)

states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate)
c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)
tl.store(states_ptrs, states, mask=c_mask)


def _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf"))):
batch, seqlen, nheads = dt.shape
assert A.shape == (nheads,)
Expand Down Expand Up @@ -790,6 +881,35 @@ def _chunk_state_bwd_ddAcs_stable(B, x, dt, dA_cumsum, dstates, seq_idx=None):
return ddA_cumsum


def chunk_state_varlen(B, x, dt, dA_cumsum, cu_seqlens, chunk_states):
total_seqlen, nheads, headdim = x.shape
_, nchunks, chunk_size = dt.shape
_, ngroups, dstate = B.shape
batch = cu_seqlens.shape[0] - 1
cu_seqlens = cu_seqlens.contiguous()
assert nheads % ngroups == 0
assert B.shape == (total_seqlen, ngroups, dstate)
assert dt.shape == (nheads, nchunks, chunk_size)
assert dA_cumsum.shape == dt.shape
assert chunk_states.shape == (nchunks, nheads, headdim, dstate)
states = torch.empty(batch, nheads, headdim, dstate, dtype=chunk_states.dtype, device=chunk_states.device)
grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']),
batch, nheads)
with torch.cuda.device(x.device.index):
_chunk_state_varlen_kernel[grid](
x, B, dt, dA_cumsum, chunk_states, cu_seqlens, states,
headdim, dstate, chunk_size,
total_seqlen, nheads // ngroups,
x.stride(0), x.stride(1), x.stride(2),
B.stride(0), B.stride(1), B.stride(2),
dt.stride(1), dt.stride(0), dt.stride(2),
dA_cumsum.stride(1), dA_cumsum.stride(0), dA_cumsum.stride(2),
chunk_states.stride(0), chunk_states.stride(1), chunk_states.stride(2), chunk_states.stride(3),
states.stride(0), states.stride(1), states.stride(2), states.stride(3),
)
return states


class ChunkStateFn(torch.autograd.Function):

@staticmethod
Expand Down
Loading

0 comments on commit 03a38fb

Please sign in to comment.