Skip to content

Commit

Permalink
[Kernel] Change interface to Mamba causal_conv1d_update for continuou…
Browse files Browse the repository at this point in the history
…s batching (vllm-project#8012)
  • Loading branch information
tlrmchlsmth authored and Jeffwan committed Sep 19, 2024
1 parent 3435fe8 commit f7c0189
Show file tree
Hide file tree
Showing 7 changed files with 114 additions and 16 deletions.
30 changes: 27 additions & 3 deletions csrc/mamba/causal_conv1d/causal_conv1d.cu
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,8 @@ causal_conv1d_update(const at::Tensor &x,
const at::Tensor &conv_state,
const at::Tensor &weight,
const c10::optional<at::Tensor> &bias_,
bool silu_activation) {
bool silu_activation,
const c10::optional<at::Tensor> &conv_state_indices_) {
auto input_type = x.scalar_type();
auto weight_type = weight.scalar_type();
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
Expand All @@ -216,7 +217,6 @@ causal_conv1d_update(const at::Tensor &x,
const int width = weight.size(-1);

CHECK_SHAPE(x, batch_size, dim);
CHECK_SHAPE(conv_state, batch_size, dim, width);
CHECK_SHAPE(weight, dim, width);

TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");
Expand All @@ -241,6 +241,22 @@ causal_conv1d_update(const at::Tensor &x,
params.conv_state_c_stride = conv_state.stride(1);
params.conv_state_l_stride = conv_state.stride(2);

if (conv_state_indices_.has_value()) {
auto conv_state_indices = conv_state_indices_.value();
TORCH_CHECK(conv_state_indices.scalar_type() == torch::kInt32)
TORCH_CHECK(conv_state_indices.is_cuda());
TORCH_CHECK(conv_state_indices.stride(0) == 1)
CHECK_SHAPE(conv_state_indices, batch_size);

int conv_state_entries = conv_state.size(0);
CHECK_SHAPE(conv_state, conv_state_entries, dim, width);

params.conv_state_indices_ptr = conv_state_indices.data_ptr<int32_t>();
} else {
CHECK_SHAPE(conv_state, batch_size, dim, width);
params.conv_state_indices_ptr = nullptr;
}

// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
at::cuda::CUDAGuard device_guard{(char)x.get_device()};
Expand Down Expand Up @@ -646,8 +662,16 @@ void causal_conv1d_update_kernel(ConvParamsBase params) {
const int channel_id = blockIdx.y * kNThreads + tidx;
input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
+ channel_id * params.x_c_stride;
input_t *conv_state = reinterpret_cast<input_t *>(params.conv_state_ptr) + batch_id * params.conv_state_batch_stride

// If params.conv_state_batch_indices is set, then the conv state is gathered from the conv state tensor
// along the batch axis. Otherwise, the conv state coordinate is the same as the batch id.
const int conv_state_batch_coord = params.conv_state_indices_ptr == nullptr
? batch_id
: params.conv_state_indices_ptr[batch_id];
input_t *conv_state = reinterpret_cast<input_t *>(params.conv_state_ptr)
+ conv_state_batch_coord * params.conv_state_batch_stride
+ channel_id * params.conv_state_c_stride;

weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
+ channel_id * params.out_c_stride;
Expand Down
4 changes: 4 additions & 0 deletions csrc/mamba/causal_conv1d/causal_conv1d.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ struct ConvParamsBase {

void *__restrict__ conv_state_ptr;

// For the continuous batching case. Makes it so that the mamba state for
// the current batch doesn't need to be a contiguous tensor.
int32_t *__restrict__ conv_state_indices_ptr;

void *__restrict__ seq_idx_ptr;

// No __restrict__ since initial_states could be the same as final_states.
Expand Down
9 changes: 4 additions & 5 deletions csrc/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -222,11 +222,10 @@ std::vector<torch::Tensor> selective_scan_fwd(
const c10::optional<torch::Tensor>& index_,
const c10::optional<torch::Tensor>& x);

at::Tensor causal_conv1d_update(const at::Tensor& x,
const at::Tensor& conv_state,
const at::Tensor& weight,
const c10::optional<at::Tensor>& bias_,
bool silu_activation);
at::Tensor causal_conv1d_update(
const at::Tensor& x, const at::Tensor& conv_state, const at::Tensor& weight,
const c10::optional<at::Tensor>& bias, bool silu_activation,
const c10::optional<at::Tensor>& conv_state_indices);

at::Tensor causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
const c10::optional<at::Tensor>& bias_,
Expand Down
5 changes: 3 additions & 2 deletions csrc/torch_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -279,8 +279,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"causal_conv1d_update(Tensor! x,"
"Tensor! conv_state,"
"Tensor! weight,"
"Tensor? bias_,"
"bool silu_activation) -> Tensor");
"Tensor? bias,"
"bool silu_activation,"
"Tensor? conv_state_indices) -> Tensor");
ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update);

ops.def(
Expand Down
58 changes: 58 additions & 0 deletions tests/kernels/test_causal_conv1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,3 +203,61 @@ def test_causal_conv1d_update(batch, dim, width, has_bias, silu_activation,

assert torch.equal(conv_state, conv_state_ref)
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)


@pytest.mark.parametrize("itype",
[torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [False, True])
@pytest.mark.parametrize("has_bias", [False, True])
@pytest.mark.parametrize("seqlen", [1, 4, 5])
@pytest.mark.parametrize("width", [2, 3, 4])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
def test_causal_conv1d_update_with_batch_gather(dim, width, seqlen, has_bias,
silu_activation, itype):
device = "cuda"
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
if itype == torch.bfloat16:
rtol, atol = 1e-2, 5e-2

# set seed
torch.random.manual_seed(0)
batch = 64

x = torch.randn(batch, dim, device=device, dtype=itype)

total_entries = 10 * batch
conv_state = torch.randn(total_entries,
dim,
width,
device=device,
dtype=itype)
conv_state_indices = torch.randperm(total_entries)[:batch].to(
dtype=torch.int32, device=device)

weight = torch.randn(dim,
width,
device=device,
dtype=itype,
requires_grad=True)
if has_bias:
bias = torch.randn(dim, device=device, dtype=itype, requires_grad=True)
else:
bias = None
conv_state_ref = conv_state[conv_state_indices, :].detach().clone()
activation = None if not silu_activation else "silu"
out = causal_conv1d_update(x,
conv_state,
weight,
bias,
activation=activation,
conv_state_indices=conv_state_indices)
out_ref = causal_conv1d_update_ref(x,
conv_state_ref,
weight,
bias,
activation=activation)

print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref)
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
14 changes: 10 additions & 4 deletions vllm/_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -768,11 +768,17 @@ def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor,
silu_activation)


def causal_conv1d_update(x: torch.Tensor, conv_state: torch.Tensor,
weight: torch.Tensor, bias_: Optional[torch.Tensor],
silu_activation: bool) -> torch.Tensor:
def causal_conv1d_update(
x: torch.Tensor,
conv_state: torch.Tensor,
weight: torch.Tensor,
bias_: Optional[torch.Tensor],
silu_activation: bool,
conv_state_indices: Optional[torch.Tensor],
) -> torch.Tensor:
return torch.ops._C.causal_conv1d_update(x, conv_state, weight, bias_,
silu_activation)
silu_activation,
conv_state_indices)


def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
Expand Down
10 changes: 8 additions & 2 deletions vllm/model_executor/layers/mamba/ops/causal_conv1d.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) 2024, Tri Dao.
# Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py

from typing import Optional

Expand Down Expand Up @@ -70,17 +71,22 @@ def causal_conv1d_update(x: torch.Tensor,
conv_state: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
activation: Optional[str] = None):
activation: Optional[str] = None,
conv_state_indices: Optional[torch.Tensor] = None):
"""
x: (batch, dim)
conv_state: (batch, dim, width)
weight: (dim, width)
bias: (dim,)
conv_state_indices: (batch,), dtype int32
If not None, the conv_state is a larger tensor along the batch dim,
and we are selecting the batch coords specified by conv_state_indices.
Useful for a continuous batching scenario.
out: (batch, dim)
"""
if activation not in [None, "silu", "swish"]:
raise NotImplementedError("activation must be None, silu, or swish")
activation_bool = activation in ["silu", "swish"]
return ops.causal_conv1d_update(x, conv_state, weight, bias,
activation_bool)
activation_bool, conv_state_indices)

0 comments on commit f7c0189

Please sign in to comment.