
[Core] draft_model_runner: Implement prepare_inputs on GPU for advance_step #6338

Merged 32 commits on Jul 17, 2024
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -151,6 +151,7 @@ set(VLLM_EXT_SRC
"csrc/quantization/fp8/common.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/moe_align_block_size_kernels.cu"
"csrc/prepare_inputs/advance_step.cu"
"csrc/torch_bindings.cpp")

if(VLLM_GPU_LANG STREQUAL "CUDA")
5 changes: 5 additions & 0 deletions csrc/ops.h
@@ -52,6 +52,11 @@ void gelu_fast(torch::Tensor& out, torch::Tensor& input);

void gelu_quick(torch::Tensor& out, torch::Tensor& input);

void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size,
torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
torch::Tensor& input_positions, torch::Tensor& seq_lens,
torch::Tensor& slot_mapping, torch::Tensor& block_tables);

#ifndef USE_ROCM
torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
const torch::Tensor& codebooks,
131 changes: 131 additions & 0 deletions csrc/prepare_inputs/advance_step.cu
@@ -0,0 +1,131 @@
/*
* The goal of this GPU kernel is to advance input tensors on the GPU directly
 * PR: https://github.com/vllm-project/vllm/pull/6338
* Current restrictions:
* 1. Specialized for DraftModelRunner
* 2. Supports flash_attn only
*/

#include "advance_step.cuh"

namespace prepare_inputs {

//
template <int const num_threads>
__global__ void advance_step_kernel(int num_seqs, int num_queries,
int block_size, long* input_tokens_ptr,
long const* sampled_token_ids_ptr,
long* input_positions_ptr,
int* seq_lens_ptr, long* slot_mapping_ptr,
int const* block_tables_ptr,
int64_t const block_tables_stride) {
int num_query_blocks = div_ceil(num_queries, num_threads);

if (blockIdx.x >= num_query_blocks) {
return;
}

int cur_query_id = blockIdx.x * num_threads + threadIdx.x;

if (cur_query_id >= num_queries) {
return;
}

// Update input_tokens
input_tokens_ptr[cur_query_id] = sampled_token_ids_ptr[cur_query_id];

int seq_len = seq_lens_ptr[cur_query_id];
int next_seq_len = seq_len + 1;
int next_input_pos = next_seq_len - 1;

// Update seq_lens
seq_lens_ptr[cur_query_id] = next_seq_len;
// Update input_positions
input_positions_ptr[cur_query_id] = next_input_pos;

int const* seq_block_tables_ptr =
block_tables_ptr + block_tables_stride * cur_query_id;

int block_index = next_input_pos / block_size;
int block_offset = next_input_pos % block_size;

int slot_num = seq_block_tables_ptr[block_index] * block_size + block_offset;
// Update slot_mapping
slot_mapping_ptr[cur_query_id] = slot_num;
}

inline void verify_tensor(std::string const& name, torch::Tensor& t,
int64_t const size_0, int64_t const size_1,
c10::ScalarType const type) {
bool size_0_cond = true;
if (size_0 != -1) {
size_0_cond = t.size(0) == size_0;
}

bool size_1_cond = true;
if (size_1 != -1) {
size_1_cond = t.size(1) == size_1;
}

bool is_contiguous = t.is_contiguous();
bool same_type = t.dtype() == type;

bool pass = size_0_cond && size_1_cond && is_contiguous && same_type;
if (!pass) {
TORCH_CHECK(false, "tensor: name = ", name, ", shape = ", t.sizes(),
" is_cont = ", t.is_contiguous(), ", type = ", t.dtype(),
" is not as expected: shape = [", size_0, ", ", size_1,
"], type = ", type);
}
}

void advance_step(int num_seqs, int num_queries, int block_size,
torch::Tensor& input_tokens, // type: long
torch::Tensor& sampled_token_ids, // type: long
torch::Tensor& input_positions, // type: long
torch::Tensor& seq_lens, // type: int
torch::Tensor& slot_mapping, // type: long
torch::Tensor& block_tables) { // type: int

if (logging) {
printf("advance_step:\n");
printf(" num_seqs = %d\n", num_seqs);
printf(" num_queries = %d\n", num_queries);
printf(" block_size = %d\n", block_size);
}
// Verify all tensors
verify_tensor("input_tokens", input_tokens, num_seqs, -1, at::kLong);
verify_tensor("sampled_token_ids", sampled_token_ids, num_queries, 1,
at::kLong);
verify_tensor("input_positions", input_positions, num_seqs, -1, at::kLong);
verify_tensor("seq_lens", seq_lens, num_seqs, -1, at::kInt);
verify_tensor("slot_mapping", slot_mapping, num_seqs, -1, at::kLong);
verify_tensor("block_tables", block_tables, num_seqs, -1, at::kInt);

int dev = sampled_token_ids.get_device();
cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);

int blocks;
cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);

advance_step_kernel<max_threads><<<blocks, max_threads, 0, stream>>>(
num_seqs, num_queries, block_size,
reinterpret_cast<long*>(input_tokens.data_ptr()),
reinterpret_cast<long const*>(sampled_token_ids.data_ptr()),
reinterpret_cast<long*>(input_positions.data_ptr()),
reinterpret_cast<int*>(seq_lens.data_ptr()),
reinterpret_cast<long*>(slot_mapping.data_ptr()),
reinterpret_cast<int const*>(block_tables.data_ptr()),
block_tables.stride(0));
}

} // namespace prepare_inputs

void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size,
torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
torch::Tensor& input_positions, torch::Tensor& seq_lens,
torch::Tensor& slot_mapping, torch::Tensor& block_tables) {
prepare_inputs::advance_step(num_seqs, num_queries, block_size, input_tokens,
sampled_token_ids, input_positions, seq_lens,
slot_mapping, block_tables);
}
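For reference, the per-query update performed by advance_step_kernel above is equivalent to the following Python sketch (illustrative only, not part of the diff; names mirror the kernel arguments, and sampled_token_ids is the (num_queries, 1) tensor checked by verify_tensor):

def advance_one_query(q, block_size, input_tokens, sampled_token_ids,
                      input_positions, seq_lens, slot_mapping, block_tables):
    # Feed the token sampled at the previous step back in as the next input token.
    input_tokens[q] = sampled_token_ids[q][0]
    # The sequence grows by one; the new token's position is the old length.
    next_seq_len = seq_lens[q] + 1
    next_input_pos = next_seq_len - 1
    seq_lens[q] = next_seq_len
    input_positions[q] = next_input_pos
    # Translate the logical position into a physical KV-cache slot via the block table.
    block_index = next_input_pos // block_size
    block_offset = next_input_pos % block_size
    slot_mapping[q] = block_tables[q][block_index] * block_size + block_offset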
19 changes: 19 additions & 0 deletions csrc/prepare_inputs/advance_step.cuh
@@ -0,0 +1,19 @@
#pragma once

#include <torch/all.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <iostream>

namespace prepare_inputs {

static constexpr int max_threads = 256;
static constexpr bool logging = false;

constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }

} // namespace prepare_inputs
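The launch in advance_step.cu uses one thread block per SM, so div_ceil and max_threads above determine which of those blocks actually have queries to process; a rough sketch with hypothetical numbers:

def div_ceil(a, b):
    return (a + b - 1) // b

max_threads = 256              # threads per block, as defined above
num_queries = 1000             # hypothetical decode batch size
num_query_blocks = div_ceil(num_queries, max_threads)  # 4 blocks do work
# Launched blocks with blockIdx.x >= num_query_blocks return immediately,
# and threads with cur_query_id >= num_queries also return.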
4 changes: 4 additions & 0 deletions csrc/torch_bindings.cpp
@@ -72,6 +72,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("gelu_quick(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_quick", torch::kCUDA, &gelu_quick);

// prepare_inputs advance_step
ops.def("advance_step", &advance_step);
ops.impl("advance_step", torch::kCUDA, &advance_step);

// Layernorm
// Apply Root Mean Square (RMS) Normalization to the input tensor.
ops.def(
1 change: 1 addition & 0 deletions tests/spec_decode/e2e/conftest.py
@@ -227,6 +227,7 @@ def get_output_from_llm_generator(
maybe_assert_ngram_worker(llm)

outputs = llm.generate(prompts, sampling_params, use_tqdm=True)

token_ids = [output.outputs[0].token_ids for output in outputs]
tokens = [output.outputs[0].text for output in outputs]

48 changes: 48 additions & 0 deletions tests/spec_decode/test_multi_step_worker.py
@@ -642,3 +642,51 @@ def test_draft_proposals_mixed_k():
assert proposals.proposal_lens.tolist() == [
k for _ in range(expected_num_proposal_seqs - 1)
] + [0 for _ in range(expected_num_no_proposal_seqs)] + [k]


@torch.inference_mode()
def test_use_draft_model_runner_advance_step():
"""Verify that draft model runner triggers advance step
when applicable.
"""
seed = 100
model_name = 'JackFram/llama-68m'

k = 5
batch_size = 32
block_size = 32
num_gpu_blocks = 2048 // block_size
worker = create_worker(
MultiStepWorker,
model_name,
block_size,
num_gpu_blocks,
seed,
model_runner_cls=TP1DraftModelRunner,
)

# Mock "_gpu_advance_step" to raise an exception when called.
exception_secret = "artificial stop"
worker.model_runner._gpu_advance_step = MagicMock()
worker.model_runner._gpu_advance_step.side_effect = ValueError(
exception_secret)

seq_group_metadata_list, _, _ = create_batch(batch_size, k)

    # When num_steps == 1, the runner should fall back and not call _gpu_advance_step.
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k,
num_steps=1)
worker.execute_model(execute_model_req=execute_model_req)

# Expect exception if _gpu_advance_step is called.
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k,
num_steps=k)

with pytest.raises(ValueError, match=exception_secret):
worker.execute_model(execute_model_req=execute_model_req)
call_args_list = worker.model_runner._gpu_advance_step.call_args_list
assert len(call_args_list) == 1
12 changes: 12 additions & 0 deletions vllm/_custom_ops.py
@@ -166,6 +166,18 @@ def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon)


def advance_step(num_seqs: int, num_queries: int, block_size: int,
input_tokens: torch.Tensor, sampled_token_ids: torch.Tensor,
input_positions: torch.Tensor, seq_lens: torch.Tensor,
slot_mapping: torch.Tensor,
block_tables: torch.Tensor) -> None:
"""Advance a step on GPU for existing inputs for a multi-step runner"""
return torch.ops._C.advance_step(num_seqs, num_queries, block_size,
input_tokens, sampled_token_ids,
input_positions, seq_lens, slot_mapping,
block_tables)


# quantization ops
# awq
def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
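A minimal, hypothetical usage sketch of the advance_step wrapper above, assuming the CUDA extension is built and a single decode query; shapes and dtypes follow the verify_tensor checks in advance_step.cu (all tensors contiguous and on the GPU):

import torch
from vllm import _custom_ops as ops

num_seqs = num_queries = 1
block_size = 16

input_tokens = torch.zeros(num_seqs, dtype=torch.long, device="cuda")
sampled_token_ids = torch.tensor([[42]], dtype=torch.long, device="cuda")
input_positions = torch.zeros(num_seqs, dtype=torch.long, device="cuda")
seq_lens = torch.tensor([7], dtype=torch.int, device="cuda")
slot_mapping = torch.zeros(num_seqs, dtype=torch.long, device="cuda")
block_tables = torch.tensor([[3, 5]], dtype=torch.int, device="cuda")

ops.advance_step(num_seqs, num_queries, block_size, input_tokens,
                 sampled_token_ids, input_positions, seq_lens,
                 slot_mapping, block_tables)
# Expected result: input_tokens == [42], seq_lens == [8],
# input_positions == [7], slot_mapping == [3 * 16 + 7] == [55].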