Skip to content

Commit

Permalink
[core] multi-step + flashinfer
Browse files Browse the repository at this point in the history
  • Loading branch information
SolitaryThinker committed Sep 4, 2024
1 parent 77d9e51 commit ebca2cc
Show file tree
Hide file tree
Showing 6 changed files with 285 additions and 11 deletions.
8 changes: 8 additions & 0 deletions csrc/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,14 @@ void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size,
torch::Tensor& input_positions, torch::Tensor& seq_lens,
torch::Tensor& slot_mapping, torch::Tensor& block_tables);

void advance_step_flashinfer(
int64_t num_seqs, int64_t num_queries, int64_t block_size,
torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
torch::Tensor& input_positions, torch::Tensor& seq_lens,
torch::Tensor& slot_mapping, torch::Tensor& block_tables,
torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr,
torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bounds);

#ifndef USE_ROCM
torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
const torch::Tensor& codebooks,
Expand Down
173 changes: 173 additions & 0 deletions csrc/prepare_inputs/advance_step.cu
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,81 @@ inline void verify_tensor(std::string const& name, torch::Tensor& t,
}
}

__global__ void advance_step_flashinfer_kernel(
int num_threads, int num_seqs, int num_queries, int block_size,
long* input_tokens_ptr, long const* sampled_token_ids_ptr,
long* input_positions_ptr, int* seq_lens_ptr, long* slot_mapping_ptr,
int const* block_tables_ptr, int64_t const block_tables_stride,
int* paged_kv_last_page_len_ptr, int* block_table_bound_ptr) {
int num_query_blocks = div_ceil(num_queries, num_threads);

if (blockIdx.x < num_query_blocks) {
int cur_query_id = blockIdx.x * num_threads + threadIdx.x;

if (cur_query_id < num_queries) {
// Update input_tokens
input_tokens_ptr[cur_query_id] = sampled_token_ids_ptr[cur_query_id];

int seq_len = seq_lens_ptr[cur_query_id];
int next_seq_len = seq_len + 1;
int next_input_pos = next_seq_len - 1;

// Update seq_lens
seq_lens_ptr[cur_query_id] = next_seq_len;
// Update input_positions
input_positions_ptr[cur_query_id] = next_input_pos;

int const* seq_block_tables_ptr =
block_tables_ptr + block_tables_stride * cur_query_id;

int block_index = next_input_pos / block_size;
int block_offset = next_input_pos % block_size;

// Update paged_kv_last_page_len
paged_kv_last_page_len_ptr[cur_query_id] = block_offset + 1;

int slot_num =
seq_block_tables_ptr[block_index] * block_size + block_offset;
// Update slot_mapping
slot_mapping_ptr[cur_query_id] = slot_num;
block_table_bound_ptr[cur_query_id] = div_ceil(next_seq_len, block_size);
}
}
}

__global__ void advance_step_flashinfer_indptr_kernel(
int num_threads, int num_seqs, int num_queries, int* paged_kv_indptr_ptr,
int* block_table_bound_ptr) {
int idx = blockIdx.x * num_threads + threadIdx.x;

// Update paged_kv_indptr
if (idx < num_queries) {
int sum = 0;
for (int i = 0; i <= idx; ++i) {
sum += block_table_bound_ptr[i];
}
paged_kv_indptr_ptr[idx + 1] = sum;
}
}

__global__ void advance_step_flashinfer_indices_kernel(
int num_threads, int num_seqs, int num_queries, int const* block_tables_ptr,
int64_t const block_tables_stride, int* paged_kv_indices_ptr,
int* paged_kv_indptr_ptr, int* block_table_bound_ptr) {
int idx = blockIdx.x * num_threads + threadIdx.x;
int row = idx / block_tables_stride;
int col = idx % block_tables_stride;

if (row < num_queries && col < block_table_bound_ptr[row]) {
paged_kv_indices_ptr[paged_kv_indptr_ptr[row] + col] =
block_tables_ptr[row * block_tables_stride + col];
}
// if cudagraph, fill padded seqs with the last valid seq's indptr
if (num_queries < row && row <= num_seqs) {
paged_kv_indptr_ptr[row] = paged_kv_indptr_ptr[num_queries];
}
}

void advance_step(int num_seqs, int num_queries, int block_size,
torch::Tensor& input_tokens, // type: long
torch::Tensor& sampled_token_ids, // type: long
Expand Down Expand Up @@ -119,6 +194,91 @@ void advance_step(int num_seqs, int num_queries, int block_size,
block_tables.stride(0));
}

void advance_step_flashinfer(
int num_seqs, int num_queries, int block_size,
torch::Tensor& input_tokens, // type: long
torch::Tensor& sampled_token_ids, // type: long
torch::Tensor& input_positions, // type: long
torch::Tensor& seq_lens, // type: int
torch::Tensor& slot_mapping, // type: long
torch::Tensor& block_tables, // type: int
torch::Tensor& paged_kv_indices, // type: int
torch::Tensor& paged_kv_indptr, // type: int
torch::Tensor& paged_kv_last_page_len, // type: int
torch::Tensor& block_table_bound) { // type: int

if (logging) {
printf("advance_step:\n");
printf(" num_seqs = %d\n", num_seqs);
printf(" num_queries = %d\n", num_queries);
printf(" block_size = %d\n", block_size);
printf(" block_tables.stride(0) = %d\n", block_tables.stride(0));
}
// Verify all tensors
verify_tensor("input_tokens", input_tokens, num_seqs, -1, at::kLong);
// verify_tensor("sampled_token_ids", sampled_token_ids, num_queries, 1,
// at::kLong);
verify_tensor("input_positions", input_positions, num_seqs, -1, at::kLong);
verify_tensor("seq_lens", seq_lens, num_seqs, -1, at::kInt);
verify_tensor("slot_mapping", slot_mapping, num_seqs, -1, at::kLong);
verify_tensor("block_tables", block_tables, num_seqs, -1, at::kInt);

verify_tensor("paged_kv_indices", paged_kv_indices, -1, -1, at::kInt);
verify_tensor("paged_kv_indptr", paged_kv_indptr, num_seqs + 1, -1, at::kInt);
verify_tensor("paged_kv_last_page_len", paged_kv_last_page_len, num_seqs, -1,
at::kInt);

verify_tensor("block_table_bound", block_table_bound, num_seqs, -1, at::kInt);

int dev = sampled_token_ids.get_device();
cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);

int blocks;
int threads;
cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
cudaDeviceGetAttribute(&threads, cudaDevAttrMaxThreadsPerBlock, dev);
if (logging) {
printf("launching kernel with %d blocks\n", blocks);
}

// TODO(will): support arbitrary block_tables stride
if ((blocks * threads) / block_tables.stride(0) < num_queries) {
TORCH_CHECK(false,
"multi-step: not enough threads to map block_table to"
"FlashInfer's paged_kv_indices on GPU. Try reducing the number "
"of seqs,",
" increasing the block size or take smaller steps.",
" num_queries = ", num_queries,
" block_tables.stride(0) = ", block_tables.stride(0),
" blocks = ", blocks, " max_threads = ", threads);
}

advance_step_flashinfer_kernel<<<blocks, threads, 0, stream>>>(
threads, num_seqs, num_queries, block_size,
reinterpret_cast<long*>(input_tokens.data_ptr()),
reinterpret_cast<long const*>(sampled_token_ids.data_ptr()),
reinterpret_cast<long*>(input_positions.data_ptr()),
reinterpret_cast<int*>(seq_lens.data_ptr()),
reinterpret_cast<long*>(slot_mapping.data_ptr()),
reinterpret_cast<int const*>(block_tables.data_ptr()),
block_tables.stride(0),
reinterpret_cast<int*>(paged_kv_last_page_len.data_ptr()),
reinterpret_cast<int*>(block_table_bound.data_ptr()));

advance_step_flashinfer_indptr_kernel<<<blocks, threads, 0, stream>>>(
threads, num_seqs, num_queries,
reinterpret_cast<int*>(paged_kv_indptr.data_ptr()),
reinterpret_cast<int*>(block_table_bound.data_ptr()));

advance_step_flashinfer_indices_kernel<<<blocks, threads, 0, stream>>>(
threads, num_seqs, num_queries,
reinterpret_cast<int const*>(block_tables.data_ptr()),
block_tables.stride(0),
reinterpret_cast<int*>(paged_kv_indices.data_ptr()),
reinterpret_cast<int*>(paged_kv_indptr.data_ptr()),
reinterpret_cast<int*>(block_table_bound.data_ptr()));
}

} // namespace prepare_inputs

void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size,
Expand All @@ -128,4 +288,17 @@ void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size,
prepare_inputs::advance_step(num_seqs, num_queries, block_size, input_tokens,
sampled_token_ids, input_positions, seq_lens,
slot_mapping, block_tables);
}

void advance_step_flashinfer(
int64_t num_seqs, int64_t num_queries, int64_t block_size,
torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
torch::Tensor& input_positions, torch::Tensor& seq_lens,
torch::Tensor& slot_mapping, torch::Tensor& block_tables,
torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr,
torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bound) {
prepare_inputs::advance_step_flashinfer(
num_seqs, num_queries, block_size, input_tokens, sampled_token_ids,
input_positions, seq_lens, slot_mapping, block_tables, paged_kv_indices,
paged_kv_indptr, paged_kv_last_page_len, block_table_bound);
}
3 changes: 3 additions & 0 deletions csrc/torch_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("advance_step", &advance_step);
ops.impl("advance_step", torch::kCUDA, &advance_step);

ops.def("advance_step_flashinfer", &advance_step_flashinfer);
ops.impl("advance_step_flashinfer", torch::kCUDA, &advance_step_flashinfer);

// Layernorm
// Apply Root Mean Square (RMS) Normalization to the input tensor.
ops.def(
Expand Down
18 changes: 18 additions & 0 deletions vllm/_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,24 @@ def advance_step(num_seqs: int, num_queries: int, block_size: int,
block_tables)


def advance_step_flashinfer(num_seqs: int, num_queries: int, block_size: int,
input_tokens: torch.Tensor,
sampled_token_ids: torch.Tensor,
input_positions: torch.Tensor,
seq_lens: torch.Tensor, slot_mapping: torch.Tensor,
block_tables: torch.Tensor,
paged_kv_indices: torch.Tensor,
paged_kv_indptr: torch.Tensor,
paged_kv_last_page_len: torch.Tensor,
block_table_bound: torch.Tensor) -> None:

return torch.ops._C.advance_step_flashinfer(
num_seqs, num_queries, block_size, input_tokens, sampled_token_ids,
input_positions, seq_lens, slot_mapping, block_tables,
paged_kv_indices, paged_kv_indptr, paged_kv_last_page_len,
block_table_bound)


# quantization ops
# awq
def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
Expand Down
42 changes: 33 additions & 9 deletions vllm/attention/backends/flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,10 @@ class FlashInferMetadata(AttentionMetadata):
query_start_loc: Optional[torch.Tensor] = None
block_tables: Optional[torch.Tensor] = None

# used for GPU in-place advance_step
seq_lens_tensor: Optional[torch.Tensor] = None
block_table_bound: Optional[torch.Tensor] = None

# An example for paged_kv_indices, paged_kv_indptr:
# request 1, page indices [0, 5, 8]
# request 2, page indices [1, 6, 7]
Expand Down Expand Up @@ -315,6 +319,8 @@ def begin_forward(self):
assert self.paged_kv_indices is not None
assert self.paged_kv_indptr is not None
assert self.paged_kv_last_page_len is not None
assert self.block_table_bound is not None
assert self.seq_lens_tensor is not None
batch_size = self.query_start_loc.shape[0] - 1
assert batch_size >= 0
# We will use flash attention for profiling to
Expand All @@ -324,6 +330,8 @@ def begin_forward(self):
self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
self.device)
self.block_table_bound = self.block_table_bound.to(self.device)
self.seq_lens_tensor = self.seq_lens_tensor.to(self.device)
self.paged_kv_indices = self.paged_kv_indices.to(self.device)
self.prefill_wrapper.end_forward()
self.prefill_wrapper.begin_forward(
Expand All @@ -332,14 +340,18 @@ def begin_forward(self):
self.num_qo_heads, self.num_kv_heads, self.head_dim,
self.page_size)
else:
if not self.use_cuda_graph:
assert self.paged_kv_indices is not None
assert self.paged_kv_indptr is not None
assert self.paged_kv_last_page_len is not None
self.paged_kv_indices = self.paged_kv_indices.to(self.device)
self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
self.device)
assert self.paged_kv_indices is not None
assert self.paged_kv_indptr is not None
assert self.paged_kv_last_page_len is not None
self.paged_kv_indices = self.paged_kv_indices.to(self.device)
self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
self.device)
# handle model warmup path
if self.block_table_bound is not None:
self.block_table_bound = self.block_table_bound.to(self.device)
if self.seq_lens_tensor is not None:
self.seq_lens_tensor = self.seq_lens_tensor.to(self.device)

assert self.decode_wrapper is not None
self.decode_wrapper.end_forward()
Expand Down Expand Up @@ -422,7 +434,7 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"):
self.paged_kv_indptr: List[int] = [0]
# paged_kv_last_page_len is the length of the last page of each request
self.paged_kv_last_page_len: List[int] = []

self.total_blocks = 0
self.is_profile_run: bool = False

def _add_seq_group(
Expand Down Expand Up @@ -493,6 +505,7 @@ def _update_paged_kv_tensors(self, block_table: List[int], seq_len: int):
# block_table_bound is 1 with 1 valid block.
# If seq_len = 15, block_size = 16,
# block_table_bound is 0 + 1 with 1 valid block.
self.total_blocks += len(block_table)
block_table_bound = seq_len // self.block_size + 1 \
if seq_len % self.block_size != 0 \
else seq_len // self.block_size
Expand Down Expand Up @@ -577,6 +590,10 @@ def build(self, seq_lens: List[int], query_lens: List[int],
out=query_start_loc[1:])

if len(self.paged_kv_indptr) > 0:
# extend to the maximum number of blocks as returned by the
# scheduler
self.paged_kv_indices.extend(
[0] * (self.total_blocks - len(self.paged_kv_indices)))
paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices,
device="cpu",
dtype=torch.int)
Expand All @@ -585,10 +602,15 @@ def build(self, seq_lens: List[int], query_lens: List[int],
dtype=torch.int)
paged_kv_last_page_len_tensor = torch.tensor(
self.paged_kv_last_page_len, device="cpu", dtype=torch.int)
block_table_bound_tensor = torch.zeros(len(self.paged_kv_indptr) -
1,
device="cpu",
dtype=torch.int)
else:
paged_kv_indices_tensor = None
paged_kv_indptr_tensor = None
paged_kv_last_page_len_tensor = None
block_table_bound_tensor = None

if self.runner.kv_cache_dtype.startswith("fp8"):
kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
Expand All @@ -607,6 +629,8 @@ def build(self, seq_lens: List[int], query_lens: List[int],
paged_kv_indptr=paged_kv_indptr_tensor,
paged_kv_indices=paged_kv_indices_tensor,
paged_kv_last_page_len=paged_kv_last_page_len_tensor,
block_table_bound=block_table_bound_tensor,
seq_lens_tensor=seq_lens_tensor,
num_qo_heads=self.runner.model_config.get_num_attention_heads(
self.runner.parallel_config),
num_kv_heads=self.runner.model_config.get_num_kv_heads(
Expand Down
Loading

0 comments on commit ebca2cc

Please sign in to comment.