Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Kernel] Change interface to Mamba causal_conv1d_update for continuous batching #8012

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 27 additions & 3 deletions csrc/mamba/causal_conv1d/causal_conv1d.cu
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,8 @@ causal_conv1d_update(const at::Tensor &x,
const at::Tensor &conv_state,
const at::Tensor &weight,
const c10::optional<at::Tensor> &bias_,
bool silu_activation) {
bool silu_activation,
const c10::optional<at::Tensor> &conv_state_indices_) {
auto input_type = x.scalar_type();
auto weight_type = weight.scalar_type();
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
Expand All @@ -216,7 +217,6 @@ causal_conv1d_update(const at::Tensor &x,
const int width = weight.size(-1);

CHECK_SHAPE(x, batch_size, dim);
CHECK_SHAPE(conv_state, batch_size, dim, width);
CHECK_SHAPE(weight, dim, width);

TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");
Expand All @@ -241,6 +241,22 @@ causal_conv1d_update(const at::Tensor &x,
params.conv_state_c_stride = conv_state.stride(1);
params.conv_state_l_stride = conv_state.stride(2);

if (conv_state_indices_.has_value()) {
auto conv_state_indices = conv_state_indices_.value();
TORCH_CHECK(conv_state_indices.scalar_type() == torch::kInt32)
TORCH_CHECK(conv_state_indices.is_cuda());
TORCH_CHECK(conv_state_indices.stride(0) == 1)
CHECK_SHAPE(conv_state_indices, batch_size);

int conv_state_entries = conv_state.size(0);
CHECK_SHAPE(conv_state, conv_state_entries, dim, width);

params.conv_state_indices_ptr = conv_state_indices.data_ptr<int32_t>();
} else {
CHECK_SHAPE(conv_state, batch_size, dim, width);
params.conv_state_indices_ptr = nullptr;
}

// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
at::cuda::CUDAGuard device_guard{(char)x.get_device()};
Expand Down Expand Up @@ -646,8 +662,16 @@ void causal_conv1d_update_kernel(ConvParamsBase params) {
const int channel_id = blockIdx.y * kNThreads + tidx;
input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
+ channel_id * params.x_c_stride;
input_t *conv_state = reinterpret_cast<input_t *>(params.conv_state_ptr) + batch_id * params.conv_state_batch_stride

// If params.conv_state_batch_indices is set, then the conv state is gathered from the conv state tensor
// along the batch axis. Otherwise, the conv state coordinate is the same as the batch id.
const int conv_state_batch_coord = params.conv_state_indices_ptr == nullptr
? batch_id
: params.conv_state_indices_ptr[batch_id];
input_t *conv_state = reinterpret_cast<input_t *>(params.conv_state_ptr)
+ conv_state_batch_coord * params.conv_state_batch_stride
+ channel_id * params.conv_state_c_stride;

weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
+ channel_id * params.out_c_stride;
Expand Down
4 changes: 4 additions & 0 deletions csrc/mamba/causal_conv1d/causal_conv1d.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ struct ConvParamsBase {

void *__restrict__ conv_state_ptr;

// For the continuous batching case. Makes it so that the mamba state for
// the current batch doesn't need to be a contiguous tensor.
int32_t *__restrict__ conv_state_indices_ptr;

void *__restrict__ seq_idx_ptr;

// No __restrict__ since initial_states could be the same as final_states.
Expand Down
9 changes: 4 additions & 5 deletions csrc/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -222,11 +222,10 @@ std::vector<torch::Tensor> selective_scan_fwd(
const c10::optional<torch::Tensor>& index_,
const c10::optional<torch::Tensor>& x);

at::Tensor causal_conv1d_update(const at::Tensor& x,
const at::Tensor& conv_state,
const at::Tensor& weight,
const c10::optional<at::Tensor>& bias_,
bool silu_activation);
at::Tensor causal_conv1d_update(
const at::Tensor& x, const at::Tensor& conv_state, const at::Tensor& weight,
const c10::optional<at::Tensor>& bias, bool silu_activation,
const c10::optional<at::Tensor>& conv_state_indices);

at::Tensor causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
const c10::optional<at::Tensor>& bias_,
Expand Down
5 changes: 3 additions & 2 deletions csrc/torch_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -279,8 +279,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"causal_conv1d_update(Tensor! x,"
"Tensor! conv_state,"
"Tensor! weight,"
"Tensor? bias_,"
"bool silu_activation) -> Tensor");
"Tensor? bias,"
"bool silu_activation,"
"Tensor? conv_state_indices) -> Tensor");
ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update);

ops.def(
Expand Down
58 changes: 58 additions & 0 deletions tests/kernels/test_causal_conv1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,3 +203,61 @@ def test_causal_conv1d_update(batch, dim, width, has_bias, silu_activation,

assert torch.equal(conv_state, conv_state_ref)
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)


@pytest.mark.parametrize("itype",
[torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [False, True])
@pytest.mark.parametrize("has_bias", [False, True])
@pytest.mark.parametrize("seqlen", [1, 4, 5])
@pytest.mark.parametrize("width", [2, 3, 4])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
def test_causal_conv1d_update_with_batch_gather(dim, width, seqlen, has_bias,
silu_activation, itype):
device = "cuda"
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
if itype == torch.bfloat16:
rtol, atol = 1e-2, 5e-2

# set seed
torch.random.manual_seed(0)
batch = 64

x = torch.randn(batch, dim, device=device, dtype=itype)

total_entries = 10 * batch
conv_state = torch.randn(total_entries,
dim,
width,
device=device,
dtype=itype)
conv_state_indices = torch.randperm(total_entries)[:batch].to(
dtype=torch.int32, device=device)

weight = torch.randn(dim,
width,
device=device,
dtype=itype,
requires_grad=True)
if has_bias:
bias = torch.randn(dim, device=device, dtype=itype, requires_grad=True)
else:
bias = None
conv_state_ref = conv_state[conv_state_indices, :].detach().clone()
activation = None if not silu_activation else "silu"
out = causal_conv1d_update(x,
conv_state,
weight,
bias,
activation=activation,
conv_state_indices=conv_state_indices)
out_ref = causal_conv1d_update_ref(x,
conv_state_ref,
weight,
bias,
activation=activation)

print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref)
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
14 changes: 10 additions & 4 deletions vllm/_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -768,11 +768,17 @@ def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor,
silu_activation)


def causal_conv1d_update(x: torch.Tensor, conv_state: torch.Tensor,
weight: torch.Tensor, bias_: Optional[torch.Tensor],
silu_activation: bool) -> torch.Tensor:
def causal_conv1d_update(
x: torch.Tensor,
conv_state: torch.Tensor,
weight: torch.Tensor,
bias_: Optional[torch.Tensor],
silu_activation: bool,
conv_state_indices: Optional[torch.Tensor],
) -> torch.Tensor:
return torch.ops._C.causal_conv1d_update(x, conv_state, weight, bias_,
silu_activation)
silu_activation,
conv_state_indices)


def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
Expand Down
10 changes: 8 additions & 2 deletions vllm/model_executor/layers/mamba/ops/causal_conv1d.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) 2024, Tri Dao.
# Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py

from typing import Optional

Expand Down Expand Up @@ -70,17 +71,22 @@ def causal_conv1d_update(x: torch.Tensor,
conv_state: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
activation: Optional[str] = None):
activation: Optional[str] = None,
conv_state_indices: Optional[torch.Tensor] = None):
"""
x: (batch, dim)
conv_state: (batch, dim, width)
weight: (dim, width)
bias: (dim,)
conv_state_indices: (batch,), dtype int32
If not None, the conv_state is a larger tensor along the batch dim,
and we are selecting the batch coords specified by conv_state_indices.
Useful for a continuous batching scenario.

out: (batch, dim)
"""
if activation not in [None, "silu", "swish"]:
raise NotImplementedError("activation must be None, silu, or swish")
activation_bool = activation in ["silu", "swish"]
return ops.causal_conv1d_update(x, conv_state, weight, bias,
activation_bool)
activation_bool, conv_state_indices)
Loading