diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.cu b/csrc/mamba/causal_conv1d/causal_conv1d.cu index 88a64a8ece585..32261ec17d897 100644 --- a/csrc/mamba/causal_conv1d/causal_conv1d.cu +++ b/csrc/mamba/causal_conv1d/causal_conv1d.cu @@ -198,7 +198,8 @@ causal_conv1d_update(const at::Tensor &x, const at::Tensor &conv_state, const at::Tensor &weight, const c10::optional &bias_, - bool silu_activation) { + bool silu_activation, + const c10::optional &conv_state_indices_) { auto input_type = x.scalar_type(); auto weight_type = weight.scalar_type(); TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); @@ -216,7 +217,6 @@ causal_conv1d_update(const at::Tensor &x, const int width = weight.size(-1); CHECK_SHAPE(x, batch_size, dim); - CHECK_SHAPE(conv_state, batch_size, dim, width); CHECK_SHAPE(weight, dim, width); TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4"); @@ -241,6 +241,22 @@ causal_conv1d_update(const at::Tensor &x, params.conv_state_c_stride = conv_state.stride(1); params.conv_state_l_stride = conv_state.stride(2); + if (conv_state_indices_.has_value()) { + auto conv_state_indices = conv_state_indices_.value(); + TORCH_CHECK(conv_state_indices.scalar_type() == torch::kInt32) + TORCH_CHECK(conv_state_indices.is_cuda()); + TORCH_CHECK(conv_state_indices.stride(0) == 1) + CHECK_SHAPE(conv_state_indices, batch_size); + + int conv_state_entries = conv_state.size(0); + CHECK_SHAPE(conv_state, conv_state_entries, dim, width); + + params.conv_state_indices_ptr = conv_state_indices.data_ptr(); + } else { + CHECK_SHAPE(conv_state, batch_size, dim, width); + params.conv_state_indices_ptr = nullptr; + } + // Otherwise the kernel will be launched from cuda:0 device // Cast to char to avoid compiler warning about narrowing at::cuda::CUDAGuard device_guard{(char)x.get_device()}; @@ -646,8 +662,16 @@ void causal_conv1d_update_kernel(ConvParamsBase params) { const int channel_id = blockIdx.y * kNThreads + tidx; input_t *x = reinterpret_cast(params.x_ptr) + batch_id * params.x_batch_stride + channel_id * params.x_c_stride; - input_t *conv_state = reinterpret_cast(params.conv_state_ptr) + batch_id * params.conv_state_batch_stride + + // If params.conv_state_batch_indices is set, then the conv state is gathered from the conv state tensor + // along the batch axis. Otherwise, the conv state coordinate is the same as the batch id. + const int conv_state_batch_coord = params.conv_state_indices_ptr == nullptr + ? batch_id + : params.conv_state_indices_ptr[batch_id]; + input_t *conv_state = reinterpret_cast(params.conv_state_ptr) + + conv_state_batch_coord * params.conv_state_batch_stride + channel_id * params.conv_state_c_stride; + weight_t *weight = reinterpret_cast(params.weight_ptr) + channel_id * params.weight_c_stride; input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride + channel_id * params.out_c_stride; diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.h b/csrc/mamba/causal_conv1d/causal_conv1d.h index bb25314c8bbbd..32a7d83c09b8d 100644 --- a/csrc/mamba/causal_conv1d/causal_conv1d.h +++ b/csrc/mamba/causal_conv1d/causal_conv1d.h @@ -36,6 +36,10 @@ struct ConvParamsBase { void *__restrict__ conv_state_ptr; + // For the continuous batching case. Makes it so that the mamba state for + // the current batch doesn't need to be a contiguous tensor. + int32_t *__restrict__ conv_state_indices_ptr; + void *__restrict__ seq_idx_ptr; // No __restrict__ since initial_states could be the same as final_states. diff --git a/csrc/ops.h b/csrc/ops.h index ee89ad32cb025..15e9ebe87408a 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -222,11 +222,10 @@ std::vector selective_scan_fwd( const c10::optional& index_, const c10::optional& x); -at::Tensor causal_conv1d_update(const at::Tensor& x, - const at::Tensor& conv_state, - const at::Tensor& weight, - const c10::optional& bias_, - bool silu_activation); +at::Tensor causal_conv1d_update( + const at::Tensor& x, const at::Tensor& conv_state, const at::Tensor& weight, + const c10::optional& bias, bool silu_activation, + const c10::optional& conv_state_indices); at::Tensor causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight, const c10::optional& bias_, diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 7009180a8687c..045203c3de8a8 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -279,8 +279,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "causal_conv1d_update(Tensor! x," "Tensor! conv_state," "Tensor! weight," - "Tensor? bias_," - "bool silu_activation) -> Tensor"); + "Tensor? bias," + "bool silu_activation," + "Tensor? conv_state_indices) -> Tensor"); ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update); ops.def( diff --git a/tests/kernels/test_causal_conv1d.py b/tests/kernels/test_causal_conv1d.py index 7bf338b36953a..344e07e739454 100644 --- a/tests/kernels/test_causal_conv1d.py +++ b/tests/kernels/test_causal_conv1d.py @@ -203,3 +203,61 @@ def test_causal_conv1d_update(batch, dim, width, has_bias, silu_activation, assert torch.equal(conv_state, conv_state_ref) assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) + + +@pytest.mark.parametrize("itype", + [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("silu_activation", [False, True]) +@pytest.mark.parametrize("has_bias", [False, True]) +@pytest.mark.parametrize("seqlen", [1, 4, 5]) +@pytest.mark.parametrize("width", [2, 3, 4]) +@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) +def test_causal_conv1d_update_with_batch_gather(dim, width, seqlen, has_bias, + silu_activation, itype): + device = "cuda" + rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) + if itype == torch.bfloat16: + rtol, atol = 1e-2, 5e-2 + + # set seed + torch.random.manual_seed(0) + batch = 64 + + x = torch.randn(batch, dim, device=device, dtype=itype) + + total_entries = 10 * batch + conv_state = torch.randn(total_entries, + dim, + width, + device=device, + dtype=itype) + conv_state_indices = torch.randperm(total_entries)[:batch].to( + dtype=torch.int32, device=device) + + weight = torch.randn(dim, + width, + device=device, + dtype=itype, + requires_grad=True) + if has_bias: + bias = torch.randn(dim, device=device, dtype=itype, requires_grad=True) + else: + bias = None + conv_state_ref = conv_state[conv_state_indices, :].detach().clone() + activation = None if not silu_activation else "silu" + out = causal_conv1d_update(x, + conv_state, + weight, + bias, + activation=activation, + conv_state_indices=conv_state_indices) + out_ref = causal_conv1d_update_ref(x, + conv_state_ref, + weight, + bias, + activation=activation) + + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref) + assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index ac90895b11c37..ff5aa8bee3c27 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -768,11 +768,17 @@ def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor, silu_activation) -def causal_conv1d_update(x: torch.Tensor, conv_state: torch.Tensor, - weight: torch.Tensor, bias_: Optional[torch.Tensor], - silu_activation: bool) -> torch.Tensor: +def causal_conv1d_update( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias_: Optional[torch.Tensor], + silu_activation: bool, + conv_state_indices: Optional[torch.Tensor], +) -> torch.Tensor: return torch.ops._C.causal_conv1d_update(x, conv_state, weight, bias_, - silu_activation) + silu_activation, + conv_state_indices) def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, diff --git a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py index 413c8bc227ae8..196d81267f32f 100644 --- a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py +++ b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py @@ -1,4 +1,5 @@ # Copyright (c) 2024, Tri Dao. +# Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py from typing import Optional @@ -70,12 +71,17 @@ def causal_conv1d_update(x: torch.Tensor, conv_state: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, - activation: Optional[str] = None): + activation: Optional[str] = None, + conv_state_indices: Optional[torch.Tensor] = None): """ x: (batch, dim) conv_state: (batch, dim, width) weight: (dim, width) bias: (dim,) + conv_state_indices: (batch,), dtype int32 + If not None, the conv_state is a larger tensor along the batch dim, + and we are selecting the batch coords specified by conv_state_indices. + Useful for a continuous batching scenario. out: (batch, dim) """ @@ -83,4 +89,4 @@ def causal_conv1d_update(x: torch.Tensor, raise NotImplementedError("activation must be None, silu, or swish") activation_bool = activation in ["silu", "swish"] return ops.causal_conv1d_update(x, conv_state, weight, bias, - activation_bool) + activation_bool, conv_state_indices)