Formatting and renaming
mzusman committed Sep 29, 2024
1 parent 893fdf9 commit e1c018b
Showing 12 changed files with 102 additions and 102 deletions.
32 changes: 16 additions & 16 deletions csrc/mamba/causal_conv1d/causal_conv1d.cu
@@ -55,7 +55,7 @@ void set_conv_params_fwd(ConvParamsBase &params,
const at::Tensor out,
const c10::optional<at::Tensor>& bias,
bool silu_activation,
const c10::optional<at::Tensor>& seq_start_loc = std::nullopt,
const c10::optional<at::Tensor>& query_start_loc = std::nullopt,
const c10::optional<at::Tensor>& cache_indices = std::nullopt,
const c10::optional<at::Tensor>& has_initial_state = std::nullopt) {

@@ -75,10 +75,10 @@ void set_conv_params_fwd(ConvParamsBase &params,
params.bias_ptr = bias.has_value() ? bias.value().data_ptr() : nullptr;
params.out_ptr = out.data_ptr();
// All stride are in elements, not bytes.
params.seq_start_loc_ptr = seq_start_loc.has_value() ? seq_start_loc.value().data_ptr() : nullptr;
params.query_start_loc_ptr = query_start_loc.has_value() ? query_start_loc.value().data_ptr() : nullptr;
params.cache_indices_ptr = cache_indices.has_value() ? cache_indices.value().data_ptr() : nullptr;
params.has_initial_state_ptr = has_initial_state.has_value() ? has_initial_state.value().data_ptr() : nullptr;
const bool varlen = params.seq_start_loc_ptr != nullptr;
const bool varlen = params.query_start_loc_ptr != nullptr;
params.x_batch_stride = x.stride(varlen ? 1 : 0);
params.x_c_stride = x.stride(varlen ? 0 : 1);
params.x_l_stride = x.stride(varlen ? 1 : -1);
@@ -94,7 +94,7 @@ at::Tensor
causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
const c10::optional<at::Tensor> &bias_,
const c10::optional<at::Tensor> &conv_states,
const c10::optional<at::Tensor> &seq_start_loc,
const c10::optional<at::Tensor> &query_start_loc,
const c10::optional<at::Tensor> &cache_indices,
const c10::optional<at::Tensor> &has_initial_state,
bool silu_activation) {
@@ -106,9 +106,9 @@ causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
TORCH_CHECK(x.is_cuda());
TORCH_CHECK(weight.is_cuda());

const bool varlen = seq_start_loc.has_value() ? true : false;
const bool varlen = query_start_loc.has_value() ? true : false;
const auto sizes = x.sizes();
const int batch_size = varlen ? seq_start_loc.value().sizes()[0] - 1 : sizes[0];
const int batch_size = varlen ? query_start_loc.value().sizes()[0] - 1 : sizes[0];
const int dim = varlen ? sizes[0] : sizes[1];
const int seqlen = varlen ? sizes[1] : sizes[2];
const int width = weight.size(-1);
@@ -139,10 +139,10 @@ causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
}


if (seq_start_loc.has_value()) {
auto seq_start_loc_ = seq_start_loc.value();
TORCH_CHECK(seq_start_loc_.scalar_type() == at::ScalarType::Int);
TORCH_CHECK(seq_start_loc_.is_cuda());
if (query_start_loc.has_value()) {
auto query_start_loc_ = query_start_loc.value();
TORCH_CHECK(query_start_loc_.scalar_type() == at::ScalarType::Int);
TORCH_CHECK(query_start_loc_.is_cuda());
}


@@ -159,7 +159,7 @@ causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
bias_,
silu_activation,
seq_start_loc,
query_start_loc,
cache_indices,
has_initial_state
);
@@ -319,13 +319,13 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) {
auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_);
vec_t *smem_exchange = reinterpret_cast<vec_t *>(smem_ + Ktraits::kSmemIOSize);

const bool kVarlen = params.seq_start_loc_ptr != nullptr;
const bool kVarlen = params.query_start_loc_ptr != nullptr;
const int tidx = threadIdx.x;
const int batch_id = blockIdx.x;
const int channel_id = blockIdx.y;
const int *seq_start_loc = kVarlen ? reinterpret_cast<int *>(params.seq_start_loc_ptr) : nullptr;
const int sequence_start_index = kVarlen ? seq_start_loc[batch_id] : batch_id;
const int seqlen = kVarlen ? seq_start_loc[batch_id + 1] - sequence_start_index : params.seqlen;
const int *query_start_loc = kVarlen ? reinterpret_cast<int *>(params.query_start_loc_ptr) : nullptr;
const int sequence_start_index = kVarlen ? query_start_loc[batch_id] : batch_id;
const int seqlen = kVarlen ? query_start_loc[batch_id + 1] - sequence_start_index : params.seqlen;

input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + sequence_start_index * params.x_batch_stride
+ channel_id * params.x_c_stride;
@@ -453,7 +453,7 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) {
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
void causal_conv1d_fwd_launch(ConvParamsBase &params, cudaStream_t stream) {
static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8;
const bool kVarlen = params.seq_start_loc_ptr != nullptr;
const bool kVarlen = params.query_start_loc_ptr != nullptr;
BOOL_SWITCH(params.seqlen % kNElts == 0 && !kVarlen, kIsVecLoad, [&] {
using Ktraits = Causal_conv1d_fwd_kernel_traits<kNThreads, kWidth, kIsVecLoad, input_t, weight_t>;
constexpr int kSmemSize = Ktraits::kSmemSize;
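For intuition about the varlen convention the kernel above relies on (sequence_start_index and seqlen are read from query_start_loc), here is a minimal Python sketch, not part of the commit: the helper name iter_sequences and the example shapes are illustrative assumptions.

import torch

def iter_sequences(x: torch.Tensor, query_start_loc: torch.Tensor):
    # query_start_loc: (batch + 1,) int32, cumulative token offsets starting at 0.
    # x: packed (dim, total_tokens) tensor with all sequences concatenated.
    batch_size = query_start_loc.numel() - 1
    for batch_id in range(batch_size):
        sequence_start_index = int(query_start_loc[batch_id])
        seqlen = int(query_start_loc[batch_id + 1]) - sequence_start_index
        yield x[:, sequence_start_index:sequence_start_index + seqlen]

x = torch.randn(4, 17)  # dim=4, 17 tokens across all sequences
query_start_loc = torch.tensor([0, 10, 16, 17], dtype=torch.int32)
assert [s.shape[-1] for s in iter_sequences(x, query_start_loc)] == [10, 6, 1]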
2 changes: 1 addition & 1 deletion csrc/mamba/causal_conv1d/causal_conv1d.h
@@ -36,7 +36,7 @@ struct ConvParamsBase {
void *__restrict__ out_ptr;

void *__restrict__ conv_state_ptr;
void *__restrict__ seq_start_loc_ptr;
void *__restrict__ query_start_loc_ptr;
void *__restrict__ has_initial_state_ptr;
void *__restrict__ cache_indices_ptr;
int32_t *__restrict__ cache_seqlens;
2 changes: 1 addition & 1 deletion csrc/mamba/mamba_ssm/selective_scan.h
@@ -58,7 +58,7 @@ struct SSMParamsBase {
void *__restrict__ z_ptr;
void *__restrict__ out_z_ptr;

void *__restrict__ seq_start_loc_ptr;
void *__restrict__ query_start_loc_ptr;
void *__restrict__ cache_indices_ptr;
void *__restrict__ has_initial_state_ptr;

28 changes: 14 additions & 14 deletions csrc/mamba/mamba_ssm/selective_scan_fwd.cu
@@ -105,9 +105,9 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
int seqlen = params.seqlen;
int sequence_start_index = batch_id;
if constexpr (kVarlen){
int *seq_start_loc = reinterpret_cast<int *>(params.seq_start_loc_ptr);
sequence_start_index = seq_start_loc[batch_id];
seqlen = seq_start_loc[batch_id + 1] - sequence_start_index;
int *query_start_loc = reinterpret_cast<int *>(params.query_start_loc_ptr);
sequence_start_index = query_start_loc[batch_id];
seqlen = query_start_loc[batch_id + 1] - sequence_start_index;
}
const bool has_initial_state = params.has_initial_state_ptr == nullptr ? false
: reinterpret_cast<bool *>(params.has_initial_state_ptr)[batch_id];
@@ -310,7 +310,7 @@ void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {
constexpr bool kIsVariableC = true;
constexpr bool kHasZ = true;
BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
BOOL_SWITCH(params.seq_start_loc_ptr != nullptr , kVarlen, [&] {
BOOL_SWITCH(params.query_start_loc_ptr != nullptr , kVarlen, [&] {
using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, kVarlen, input_t, weight_t>;
constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
dim3 grid(params.batch, params.dim / kNRows);
@@ -404,7 +404,7 @@ void set_ssm_params_fwd(SSMParamsBase &params,
const torch::Tensor ssm_states,
bool has_z,
bool delta_softplus,
const c10::optional<at::Tensor>& seq_start_loc,
const c10::optional<at::Tensor>& query_start_loc,
const c10::optional<at::Tensor>& cache_indices,
const c10::optional<at::Tensor>& has_initial_state,
bool varlen) {
@@ -437,7 +437,7 @@ void set_ssm_params_fwd(SSMParamsBase &params,
params.ssm_states_ptr = ssm_states.data_ptr();
params.z_ptr = has_z ? z.data_ptr() : nullptr;
params.out_z_ptr = has_z ? out_z.data_ptr() : nullptr;
params.seq_start_loc_ptr = seq_start_loc.has_value() ? seq_start_loc.value().data_ptr() : nullptr;
params.query_start_loc_ptr = query_start_loc.has_value() ? query_start_loc.value().data_ptr() : nullptr;
params.cache_indices_ptr = cache_indices.has_value() ? cache_indices.value().data_ptr() : nullptr;
params.has_initial_state_ptr = has_initial_state.has_value() ? has_initial_state.value().data_ptr() : nullptr;

@@ -504,7 +504,7 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
const c10::optional<torch::Tensor> &z_,
const c10::optional<torch::Tensor> &delta_bias_,
bool delta_softplus,
const c10::optional<torch::Tensor> &seq_start_loc,
const c10::optional<torch::Tensor> &query_start_loc,
const c10::optional<torch::Tensor> &cache_indices,
const c10::optional<torch::Tensor> &has_initial_state,
const torch::Tensor &ssm_states) {
@@ -530,8 +530,8 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1);

const auto sizes = u.sizes();
const bool varlen = seq_start_loc.has_value();
const int batch_size = varlen ? seq_start_loc.value().sizes()[0] - 1 : sizes[0];
const bool varlen = query_start_loc.has_value();
const int batch_size = varlen ? query_start_loc.value().sizes()[0] - 1 : sizes[0];
const int dim = varlen ? sizes[0] : sizes[1];
const int seqlen = varlen ? sizes[1] : sizes[2];
const int dstate = A.size(1);
@@ -588,10 +588,10 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
}


if (seq_start_loc.has_value()) {
auto seq_start_loc_ = seq_start_loc.value();
TORCH_CHECK(seq_start_loc_.scalar_type() == at::ScalarType::Int);
TORCH_CHECK(seq_start_loc_.is_cuda());
if (query_start_loc.has_value()) {
auto query_start_loc_ = query_start_loc.value();
TORCH_CHECK(query_start_loc_.scalar_type() == at::ScalarType::Int);
TORCH_CHECK(query_start_loc_.is_cuda());
}


@@ -636,7 +636,7 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
ssm_states,
has_z,
delta_softplus,
seq_start_loc,
query_start_loc,
cache_indices,
has_initial_state,
varlen
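A side note on the shape handling above: in the varlen path the inputs are packed along the token axis, so the batch size is recovered from query_start_loc rather than from the tensor shape. The sketch below mirrors that derivation in plain Python; the function name and example shapes are assumptions made for illustration.

import torch

def derive_sizes(u: torch.Tensor, query_start_loc: torch.Tensor = None):
    varlen = query_start_loc is not None
    if varlen:
        # Packed layout: u is (dim, total_tokens); batch comes from the offsets.
        batch_size = query_start_loc.numel() - 1
        dim, seqlen = u.shape[0], u.shape[1]
    else:
        # Batched layout: u is (batch, dim, seqlen).
        batch_size, dim, seqlen = u.shape
    return batch_size, dim, seqlen

print(derive_sizes(torch.zeros(2, 4, 8)))  # (2, 4, 8)
print(derive_sizes(torch.zeros(4, 17),
                   torch.tensor([0, 10, 16, 17], dtype=torch.int32)))  # (3, 4, 17)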
36 changes: 18 additions & 18 deletions csrc/ops.h
@@ -215,31 +215,31 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad);

void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
const torch::Tensor &A, const torch::Tensor &B,
const torch::Tensor &C,
const c10::optional<torch::Tensor> &D_,
const c10::optional<torch::Tensor> &z_,
const c10::optional<torch::Tensor> &delta_bias_,
bool delta_softplus,
const c10::optional<torch::Tensor> &seq_start_loc,
const c10::optional<torch::Tensor> &cache_indices,
const c10::optional<torch::Tensor> &has_initial_state,
const torch::Tensor &ssm_states);
void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta,
const torch::Tensor& A, const torch::Tensor& B,
const torch::Tensor& C,
const c10::optional<torch::Tensor>& D_,
const c10::optional<torch::Tensor>& z_,
const c10::optional<torch::Tensor>& delta_bias_,
bool delta_softplus,
const c10::optional<torch::Tensor>& query_start_loc,
const c10::optional<torch::Tensor>& cache_indices,
const c10::optional<torch::Tensor>& has_initial_state,
const torch::Tensor& ssm_states);

at::Tensor causal_conv1d_update(
const at::Tensor& x, const at::Tensor& conv_state, const at::Tensor& weight,
const c10::optional<at::Tensor>& bias_, bool silu_activation,
const c10::optional<at::Tensor>& cache_seqlens_,
const c10::optional<at::Tensor>& conv_state_indices_);

at::Tensor causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
const c10::optional<at::Tensor> &bias_,
const c10::optional<at::Tensor> &conv_states,
const c10::optional<at::Tensor> &seq_start_loc,
const c10::optional<at::Tensor> &cache_indices,
const c10::optional<at::Tensor> &has_initial_state,
bool silu_activation);
at::Tensor causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
const c10::optional<at::Tensor>& bias_,
const c10::optional<at::Tensor>& conv_states,
const c10::optional<at::Tensor>& query_start_loc,
const c10::optional<at::Tensor>& cache_indices,
const c10::optional<at::Tensor>& has_initial_state,
bool silu_activation);

#ifndef USE_ROCM
using fptr_t = int64_t;
4 changes: 2 additions & 2 deletions csrc/torch_bindings.cpp
@@ -275,7 +275,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor! A, Tensor! B, Tensor! C,"
"Tensor? D_, Tensor!? z_, Tensor? delta_bias_,"
"bool delta_softplus,"
"Tensor? seq_start_loc,"
"Tensor? query_start_loc,"
"Tensor? cache_indices,"
"Tensor? has_initial_state,"
"Tensor! ssm_states) -> ()");
@@ -295,7 +295,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"causal_conv1d_fwd(Tensor! x, Tensor! weight,"
"Tensor? bias_,"
"Tensor!? conv_states,"
"Tensor? seq_start_loc,"
"Tensor? query_start_loc,"
"Tensor? cache_indices,"
"Tensor? has_initial_state,"
"bool silu_activation) -> Tensor");
3 changes: 2 additions & 1 deletion tests/kernels/test_causal_conv1d.py
@@ -365,7 +365,8 @@ def test_causal_conv1d_varlen(dim, seqlen, width, has_bias, silu_activation,
assert all(s > 0 for s in seqlens[-1])

cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32)
cumsum = torch.concat([torch.tensor([0],dtype=torch.int32), cumsum], dim=0)
cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum],
dim=0)
x = torch.randn(batch, 4096 + dim + 64, seqlen, device=device,
dtype=itype)[:, 4096:4096 + dim, :]
weight = torch.randn(dim, width, device=device, dtype=itype)
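For readers unfamiliar with the fixture above: the query_start_loc tensor passed to the kernel is simply the per-sequence lengths turned into a 0-prepended cumulative sum in int32. A small standalone sketch of that construction (the lengths here are made up, not the test's random values):

import torch

seqlens = [10, 6, 1]  # illustrative per-sequence lengths
cumsum = torch.cumsum(torch.tensor(seqlens), dim=0).to(torch.int32)
query_start_loc = torch.concat(
    [torch.tensor([0], dtype=torch.int32), cumsum], dim=0)
assert query_start_loc.tolist() == [0, 10, 16, 17]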
8 changes: 3 additions & 5 deletions tests/kernels/test_mamba_ssm.py
@@ -433,10 +433,8 @@ def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups,
assert all(s > 0 for s in seqlens[-1])

cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32)
cumsum = torch.concat(
[torch.tensor([0],dtype=torch.int32), cumsum],
dim=0
).cuda()
cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum],
dim=0).cuda()

dim = 4
dstate = 8
@@ -480,7 +478,7 @@ def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups,
device=u.device)
out = selective_scan_fn(u, prev_state, delta, A, B, C, D, z, delta_bias,
delta_softplus, cumsum, cache_indices,
has_initial_state )
has_initial_state)
outs_ref = []
splits = [
torch.split(var, seqlens[0], dim=-1)
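The reference comparison in this test works per sequence: the packed varlen tensors are split back into their individual sequences with the same lengths that produced the cumulative offsets. A minimal sketch of that splitting pattern (shapes are assumptions, not the test's actual values):

import torch

dim, seqlens = 4, [10, 6, 1]
packed = torch.randn(dim, sum(seqlens))            # (dim, total_tokens)
per_sequence = torch.split(packed, seqlens, dim=-1)
assert [t.shape[-1] for t in per_sequence] == seqlens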
25 changes: 12 additions & 13 deletions vllm/_custom_ops.py
@@ -763,12 +763,12 @@ def ggml_mul_mat_a8(
def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor,
bias_: Optional[torch.Tensor],
conv_states: Optional[torch.Tensor],
seq_start_loc: Optional[torch.Tensor],
query_start_loc: Optional[torch.Tensor],
cache_indices: Optional[torch.Tensor],
has_initial_state: Optional[torch.Tensor],
silu_activation: bool) -> torch.Tensor:
return torch.ops._C.causal_conv1d_fwd(x, weight, bias_, conv_states,
seq_start_loc, cache_indices,
query_start_loc, cache_indices,
has_initial_state, silu_activation)


@@ -782,18 +782,17 @@ def causal_conv1d_update(
conv_state_indices)


def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
B: torch.Tensor, C: torch.Tensor,
D_: Optional[torch.Tensor], z_: Optional[torch.Tensor],
delta_bias_: Optional[torch.Tensor],
delta_softplus: bool,
seq_start_loc: Optional[torch.Tensor],
cache_indices: Optional[torch.Tensor],
has_initial_state: Optional[torch.Tensor],
ssm_states: torch.Tensor):
def selective_scan_fwd(
u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, B: torch.Tensor,
C: torch.Tensor, D_: Optional[torch.Tensor],
z_: Optional[torch.Tensor], delta_bias_: Optional[torch.Tensor],
delta_softplus: bool, query_start_loc: Optional[torch.Tensor],
cache_indices: Optional[torch.Tensor],
has_initial_state: Optional[torch.Tensor], ssm_states: torch.Tensor):
torch.ops._C.selective_scan_fwd(u, delta, A, B, C, D_, z_, delta_bias_,
delta_softplus, seq_start_loc, cache_indices,
has_initial_state, ssm_states)
delta_softplus, query_start_loc,
cache_indices, has_initial_state,
ssm_states)


# moe
10 changes: 5 additions & 5 deletions vllm/model_executor/layers/mamba/ops/causal_conv1d.py
@@ -12,7 +12,7 @@ def causal_conv1d_fn(
x: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
seq_start_loc: Optional[torch.Tensor] = None,
query_start_loc: Optional[torch.Tensor] = None,
cache_indices: Optional[torch.Tensor] = None,
has_initial_state: Optional[torch.Tensor] = None,
conv_states: Optional[torch.Tensor] = None,
@@ -23,10 +23,10 @@ def causal_conv1d_fn(
sequences are concatenated from left to right for varlen
weight: (dim, width)
bias: (dim,)
seq_start_loc: (batch + 1) int32
query_start_loc: (batch + 1) int32
The cumulative sequence lengths of the sequences in
the batch, used to index into sequence.
for example: seq_start_loc = torch.Tensor([0,10,16,17]),
the batch, used to index into sequence. prepended by 0.
for example: query_start_loc = torch.Tensor([0,10,16,17]),
x.shape=(dim,17)
cache_indices: (batch) int32
indicates the corresponding state index,
@@ -46,7 +46,7 @@ def causal_conv1d_fn(
x = x.contiguous()
bias = bias.contiguous() if bias is not None else None

out = ops.causal_conv1d_fwd(x, weight, bias, conv_states, seq_start_loc,
out = ops.causal_conv1d_fwd(x, weight, bias, conv_states, query_start_loc,
cache_indices, has_initial_state, activation
in ["silu", "swish"])
return out
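A hedged usage sketch of the renamed argument at the Python API level, following the docstring example above (query_start_loc = [0, 10, 16, 17], x of shape (dim, 17)). It assumes a CUDA build of vLLM at this revision; the conv_states shape and the choice to pass every optional argument are illustrative assumptions, not verified behaviour of this function.

import torch
from vllm.model_executor.layers.mamba.ops.causal_conv1d import causal_conv1d_fn

dim, width, batch = 64, 4, 3
x = torch.randn(dim, 17, device="cuda", dtype=torch.float16)   # packed tokens
weight = torch.randn(dim, width, device="cuda", dtype=torch.float16)
query_start_loc = torch.tensor([0, 10, 16, 17], dtype=torch.int32, device="cuda")
has_initial_state = torch.zeros(batch, dtype=torch.bool, device="cuda")
cache_indices = torch.arange(batch, dtype=torch.int32, device="cuda")
conv_states = torch.zeros(batch, dim, width - 1,               # assumed state shape
                          device="cuda", dtype=torch.float16)

out = causal_conv1d_fn(x, weight, bias=None,
                       query_start_loc=query_start_loc,
                       cache_indices=cache_indices,
                       has_initial_state=has_initial_state,
                       conv_states=conv_states,
                       activation="silu")
print(out.shape)  # expected to match the packed input: (dim, 17)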
(The remaining 2 changed files are not shown here.)
