[Kernel][Misc] Use TORCH_LIBRARY instead of PYBIND11_MODULE for custom ops #5047

Merged
merged 41 commits on Jun 9, 2024

Changes from 1 commit

Commits (41)
5cd3412
Use TORCH_LIBRARY instead of PYBIND11_MODULE
bnellnm May 25, 2024
77c3e93
fix cpu defs
bnellnm May 25, 2024
d317299
fix typo in cpu pybind
bnellnm May 25, 2024
730ad15
fix moe/punica
bnellnm May 25, 2024
d435732
fixes
bnellnm May 25, 2024
f474a78
fixes
bnellnm May 26, 2024
5ddf7c0
fix punica_pybind signature
bnellnm May 26, 2024
699a373
rebase
bnellnm May 26, 2024
f2646f4
fix format
bnellnm May 26, 2024
497c3f4
more clang-format
bnellnm May 26, 2024
ed5cb0b
cpu fixes
bnellnm May 26, 2024
81c2783
convert cache_ops + cuda_utils
bnellnm May 27, 2024
f0c5e87
add mutable indices to schema registration
bnellnm May 27, 2024
70bf9a9
clang format
bnellnm May 27, 2024
216d3d1
fix format
bnellnm May 27, 2024
4bc89c1
fix cpu binding
bnellnm May 27, 2024
fb2a195
fix intel
bnellnm May 27, 2024
4052198
format
bnellnm May 27, 2024
e8e9af2
add meta functions and signatures
bnellnm May 31, 2024
c72bf4c
update cpu bindings
bnellnm May 31, 2024
c5562c8
convert custom_ar ops
bnellnm May 31, 2024
a0988ac
fix formatting
bnellnm May 31, 2024
a0a2a00
move punica and moe ops into _custom_ops
bnellnm May 31, 2024
d956724
maybe fix all_reduce
bnellnm May 31, 2024
a24f023
fix more formatting
bnellnm May 31, 2024
b1f61f4
fix some stuff
bnellnm Jun 1, 2024
6d35d5c
use python stable api
bnellnm Jun 1, 2024
1fcde34
fix cpu
bnellnm Jun 1, 2024
88dc724
add comment
bnellnm Jun 1, 2024
9d42c29
fix punica
bnellnm Jun 1, 2024
f178aed
try to fix ROCM dockerfile
bnellnm Jun 1, 2024
13e36f0
cleanups
bnellnm Jun 3, 2024
415dc42
rebase + use Tensor instead of std::vector<uint8_t> in custom ar api
bnellnm Jun 3, 2024
0f35b08
fix formatting
bnellnm Jun 3, 2024
3b722e3
fix test_int8_quant.py test
bnellnm Jun 3, 2024
3e5cec2
remove meta ops for now
bnellnm Jun 4, 2024
dc87320
rebase + some review comments
bnellnm Jun 7, 2024
bb2446e
libtorch_python.so no longer needed?
bnellnm Jun 7, 2024
48868e1
rename pybind files to torch_bindings.cpp
bnellnm Jun 7, 2024
4ed8bf2
add comments about const vectors
bnellnm Jun 7, 2024
57088e4
rebase + run format.sh
bnellnm Jun 9, 2024
add mutable indices to schema registration
bnellnm committed Jun 9, 2024
commit f0c5e87da940c28c3b996e5cc7bcecc324557f5e
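To set the stage for the diff below: this PR replaces `PYBIND11_MODULE`-style bindings with PyTorch's `TORCH_LIBRARY` op registration, so each kernel gets a schema (including which arguments it mutates) plus a per-backend `impl`. The following minimal sketch contrasts the two styles with a hypothetical in-place op; the name `my_op` and the `_C_example` namespace are illustrative only and are not taken from this PR.

```cpp
#include <torch/extension.h>

// Hypothetical kernel that writes its result into `out` in place.
void my_op(torch::Tensor& out, torch::Tensor& input) { out.copy_(input); }

// Old style (what the PR removes): expose the function through pybind11.
// The op is a plain Python callable with no schema or alias information.
//
//   PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//     m.def("my_op", &my_op, "Copy input into out");
//   }

// New style (what the PR adds): register a schema that marks `out` as
// mutated (the Tensor(a!) alias annotation) plus a backend-specific impl.
// The op is then dispatched as torch.ops._C_example.my_op from Python.
TORCH_LIBRARY(_C_example, ops) {
  ops.def("my_op(Tensor(a!) out, Tensor input) -> ()");
  ops.impl("my_op", torch::kCUDA, &my_op);
}
```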
59 changes: 23 additions & 36 deletions csrc/cpu/pybind.cpp
@@ -8,75 +8,62 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Attention ops
// Compute the attention between an input query and the cached keys/values
// using PagedAttention.
ops.def(
"paged_attention_v1(Tensor out, Tensor query, Tensor key_cache, "
"Tensor value_cache, int num_kv_heads, float scale, Tensor "
"block_tables, Tensor seq_lens, int block_size, int max_seq_len, "
"Tensor? alibi_slopes, str kv_cache_dtype, float kv_scale, int tp_rank,"
"int blocksparse_local_blocks, int blocksparse_vert_stride, "
"int blocksparse_block_size, int blocksparse_head_sliding_step) -> ()");
ops.def("paged_attention_v1", &paged_attention_v1);
ops.impl("paged_attention_v1", torch::kCPU, &paged_attention_v1);

// PagedAttention V2.
ops.def(
"paged_attention_v2(Tensor out, Tensor exp_sums, Tensor max_logits,"
"Tensor tmp_out, Tensor query, Tensor key_cache, Tensor value_cache,"
"int num_kv_heads, float scale, Tensor block_tables, Tensor seq_lens,"
"int block_size, int max_seq_len, Tensor? alibi_slopes, "
"str kv_cache_dtype, float kv_scale, int tp_rank, "
"int blocksparse_local_blocks, int blocksparse_vert_stride,"
"int blocksparse_block_size, int blocksparse_head_sliding_step) -> ()");
ops.def("paged_attention_v2", &paged_attention_v2);
ops.impl("paged_attention_v2", torch::kCPU, &paged_attention_v2);

// Activation ops

// Activation function used in SwiGLU.
ops.def("silu_and_mul(Tensor out, Tensor input) -> ()");
ops.def("silu_and_mul", &silu_and_mul);
ops.impl("silu_and_mul", torch::kCPU, &silu_and_mul);

// Activation function used in GeGLU with `none` approximation.
ops.def("gelu_and_mul(Tensor out, Tensor input) -> ()");
ops.def("gelu_and_mul", &gelu_and_mul);
ops.impl("gelu_and_mul", torch::kCPU, &gelu_and_mul);

// Activation function used in GeGLU with `tanh` approximation.
ops.def("gelu_tanh_and_mul(Tensor out, Tensor input) -> ()");
ops.def("gelu_tanh_and_mul", &gelu_tanh_and_mul);
ops.impl("gelu_tanh_and_mul", torch::kCPU, &gelu_tanh_and_mul);

// GELU implementation used in GPT-2.
ops.def("gelu_new(Tensor out, Tensor input) -> ()");
ops.def("gelu_new", &gelu_new);
ops.impl("gelu_new", torch::kCPU, &gelu_new);

// Approximate GELU implementation.
ops.def("gelu_fast(Tensor out, Tensor input) -> ()");
ops.def("gelu_fast", &gelu_fast);
ops.impl("gelu_fast", torch::kCPU, &gelu_fast);

// Layernorm
// Apply Root Mean Square (RMS) Normalization to the input tensor.
ops.def(
"rms_norm(Tensor out, Tensor input, Tensor weight, float epsilon) -> ()");
ops.def("rms_norm", &rms_norm);
ops.impl("rms_norm", torch::kCPU, &rms_norm);

// In-place fused Add and RMS Normalization.
ops.def(
"fused_add_rms_norm(Tensor input, Tensor residual, Tensor weight, float "
"epsilon) -> ()");
ops.def("fused_add_rms_norm", &fused_add_rms_norm);
ops.impl("fused_add_rms_norm", torch::kCPU, &fused_add_rms_norm);

// Rotary embedding
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
ops.def(
"rotary_embedding(Tensor positions, Tensor query, Tensor key, int "
"head_size, Tensor cos_sin_cache, bool is_neox) -> ()");
ops.def("rotary_embedding", &rotary_embedding);
ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {

TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
// Cache ops
pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
cache_ops.def("swap_blocks", &swap_blocks,
"Swap in (out) the cache blocks from src to dst");
cache_ops.def("copy_blocks", &copy_blocks,
"Copy the cache blocks from src to dst");
cache_ops.def("reshape_and_cache", &reshape_and_cache,
"Reshape the key and value tensors and cache them");
// Swap in (out) the cache blocks from src to dst.
cache_ops.def("swap_blocks", &swap_blocks);
cache_ops.impl("swap_blocks", torch::kCPU, &swap_blocks);

// Copy the cache blocks from src to dst.
cache_ops.def("copy_blocks", &copy_blocks);
cache_ops.impl("copy_blocks", torch::kCPU, &copy_blocks);

// Reshape the key and value tensors and cache them.
cache_ops.def("reshape_and_cache", &reshape_and_cache);
cache_ops.impl("reshape_and_cache", torch::kCPU, &reshape_and_cache);
}
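A note on the `TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops)` pattern above: `TORCH_LIBRARY` pastes and stringizes its first argument, which suppresses macro expansion, so the `CONCAT`/`TORCH_LIBRARY_EXPAND` indirection is needed to expand the extension name first. A rough sketch of what it reduces to, assuming the build defines `TORCH_EXTENSION_NAME` as `_C` (consistent with the `torch.ops._C_cache_ops.*` calls in `vllm/_custom_ops.py` later in this diff):

```cpp
#include <torch/library.h>

// Helpers as defined in csrc/register.h in this PR.
#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)
#define _CONCAT(A, B) A##B
#define CONCAT(A, B) _CONCAT(A, B)

// Assumption for this sketch: the real build passes -DTORCH_EXTENSION_NAME=_C.
#define TORCH_EXTENSION_NAME _C

// Expands to TORCH_LIBRARY(_C_cache_ops, cache_ops): the CONCAT/_CONCAT pair
// forces expansion of the macro-valued extension name before TORCH_LIBRARY
// stringizes its namespace argument.
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
  // Schemas and kCPU/kCUDA impls for swap_blocks, copy_blocks and
  // reshape_and_cache are registered here, as in the diff above.
}
```

The resulting ops land in a separate namespace and are reachable from Python as `torch.ops._C_cache_ops.swap_blocks`, `torch.ops._C_cache_ops.copy_blocks`, and so on.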
10 changes: 2 additions & 8 deletions csrc/moe/moe_ops.cpp
@@ -1,16 +1,10 @@
#include "moe_ops.h"

#include <torch/extension.h>

#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)

TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
// Apply topk softmax to the gating outputs.
m.def(
"topk_softmax(Tensor topk_weights, Tensor topk_indices, Tensor "
"token_expert_indices, Tensor gating_output) -> ()");
vllm::def(m, "topk_softmax", &topk_softmax, {0, 1});
m.impl("topk_softmax", torch::kCUDA, &topk_softmax);
}

// TODO: get rid of this
// TODO: get rid of this?
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}
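One detail worth noting in this file: all ops now go through `TORCH_LIBRARY`, yet an empty `PYBIND11_MODULE` remains (with a "TODO: get rid of this?"). My reading, which is an assumption rather than something stated in the diff, is that it preserves the `PyInit_*` entry point so the shared library can still be loaded with a plain Python `import`, which in turn triggers the static `TORCH_LIBRARY` registration:

```cpp
#include <torch/extension.h>

// Importing the extension module from Python (e.g. `import vllm._moe_C`;
// module name assumed here) runs the static initializers that TORCH_LIBRARY
// generates, after which torch.ops.<namespace>.topk_softmax resolves. No
// pybind-level functions are exposed anymore, so the module body stays empty.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}
```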
2 changes: 1 addition & 1 deletion csrc/moe/moe_ops.h
@@ -1,6 +1,6 @@
#pragma once

#include <torch/extension.h>
#include "register.h"

void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices,
torch::Tensor& token_expert_indices,
6 changes: 1 addition & 5 deletions csrc/ops.h
@@ -1,10 +1,6 @@
#pragma once

#include <torch/extension.h>

#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)
#define _CONCAT(A, B) A##B
#define CONCAT(A, B) _CONCAT(A, B)
#include "register.h"

void paged_attention_v1(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
2 changes: 1 addition & 1 deletion csrc/punica/punica_ops.h
@@ -1,6 +1,6 @@
#pragma once

#include <torch/extension.h>
#include "register.h"

void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
torch::Tensor indicies, int64_t layer_idx, double scale);
12 changes: 2 additions & 10 deletions csrc/punica/punica_pybind.cpp
@@ -1,18 +1,10 @@
#include <torch/extension.h>

#include "punica_ops.h"

#define TORCH_LIBRARY_EXPAND(NAME, M) TORCH_LIBRARY(NAME, M)

TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
m.def(
"dispatch_bgmv(Tensor y, Tensor x, Tensor w, Tensor indicies, int "
"layer_idx, float scale) -> ()");
vllm::def(m, "dispatch_bgmv", &dispatch_bgmv, {0});
m.impl("dispatch_bgmv", torch::kCUDA, &dispatch_bgmv);

m.def(
"dispatch_bgmv_low_level(Tensor y, Tensor x, Tensor w, Tensor indicies, "
"int layer_idx, float scale, int h_in, int h_out, int y_offset) -> ()");
vllm::def(m, "dispatch_bgmv_low_level", &dispatch_bgmv_low_level, {0});
m.impl("dispatch_bgmv_low_level", torch::kCUDA, &dispatch_bgmv_low_level);
}

124 changes: 39 additions & 85 deletions csrc/pybind.cpp
@@ -3,179 +3,132 @@
#include "ops.h"
#include <torch/extension.h>

using vllm::def;

TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// vLLM custom ops

// Attention ops
// Compute the attention between an input query and the cached
// keys/values using PagedAttention.
ops.def(
"paged_attention_v1(Tensor out, Tensor query, Tensor key_cache, "
"Tensor value_cache, int num_kv_heads, float scale, Tensor "
"block_tables, Tensor seq_lens, int block_size, int max_seq_len, "
"Tensor? alibi_slopes, str kv_cache_dtype, float kv_scale, int tp_rank,"
"int blocksparse_local_blocks, int blocksparse_vert_stride, "
"int blocksparse_block_size, int blocksparse_head_sliding_step) -> ()");
//ops.def("paged_attention_v1", &paged_attention_v1);
def(ops, "paged_attention_v1", &paged_attention_v1, {0});
ops.impl("paged_attention_v1", torch::kCUDA, &paged_attention_v1);

// PagedAttention V2.
ops.def(
"paged_attention_v2(Tensor out, Tensor exp_sums, Tensor max_logits,"
"Tensor tmp_out, Tensor query, Tensor key_cache, Tensor value_cache,"
"int num_kv_heads, float scale, Tensor block_tables, Tensor seq_lens,"
"int block_size, int max_seq_len, Tensor? alibi_slopes, "
"str kv_cache_dtype, float kv_scale, int tp_rank, "
"int blocksparse_local_blocks, int blocksparse_vert_stride,"
"int blocksparse_block_size, int blocksparse_head_sliding_step) -> ()");
def(ops, "paged_attention_v2", &paged_attention_v2, {0});
ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2);

// Activation ops
// Activation function used in SwiGLU.
ops.def("silu_and_mul(Tensor out, Tensor input) -> ()");
def(ops, "silu_and_mul", &silu_and_mul, {0});
ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);

// Activation function used in GeGLU with `none` approximation.
ops.def("gelu_and_mul(Tensor out, Tensor input) -> ()");
def(ops, "gelu_and_mul", &gelu_and_mul, {0});
ops.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul);

// Activation function used in GeGLU with `tanh` approximation.
ops.def("gelu_tanh_and_mul(Tensor out, Tensor input) -> ()");
def(ops, "gelu_tanh_and_mul", &gelu_tanh_and_mul, {0});
ops.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul);

// GELU implementation used in GPT-2.
ops.def("gelu_new(Tensor out, Tensor input) -> ()");
def(ops, "gelu_new", &gelu_new, {0});
ops.impl("gelu_new", torch::kCUDA, &gelu_new);

// Approximate GELU implementation.
ops.def("gelu_fast(Tensor out, Tensor input) -> ()");
def(ops, "gelu_fast", &gelu_fast, {0});
ops.impl("gelu_fast", torch::kCUDA, &gelu_fast);

// Layernorm
// Apply Root Mean Square (RMS) Normalization to the input tensor.
ops.def(
"rms_norm(Tensor out, Tensor input, Tensor weight, float epsilon) -> ()");
def(ops, "rms_norm", &rms_norm, {0});
//ops.def("rms_norm", &rms_norm);
// ops.def(torch::schema("rms_norm(Tensor out, Tensor input, Tensor weight,
// float epsilon) -> ()"), c10::AliasAnalysisKind::CONSERVATIVE);
ops.impl("rms_norm", torch::kCUDA, &rms_norm);

// In-place fused Add and RMS Normalization.
ops.def(
"fused_add_rms_norm(Tensor input, Tensor residual, Tensor weight, float "
"epsilon) -> ()");
def(ops, "fused_add_rms_norm", &fused_add_rms_norm, {0, 1});
ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm);

// Rotary embedding
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
ops.def(
"rotary_embedding(Tensor positions, Tensor query, Tensor key, int "
"head_size, Tensor cos_sin_cache, bool is_neox) -> ()");
def(ops, "rotary_embedding", &rotary_embedding, {1, 2});
ops.impl("rotary_embedding", torch::kCUDA, &rotary_embedding);

// Apply GPT-NeoX or GPT-J style rotary embedding to query and key
// (supports multiple loras).
ops.def(
"batched_rotary_embedding(Tensor positions, Tensor query, Tensor "
"key, int head_size, Tensor cos_sin_cache, bool is_neox, int "
"rot_dim, Tensor cos_sin_cache_offsets) -> ()");
def(ops, "batched_rotary_embedding", &batched_rotary_embedding, {1, 2}); // ?
ops.impl("batched_rotary_embedding", torch::kCUDA, &batched_rotary_embedding);

// Quantization ops
#ifndef USE_ROCM
// Quantized GEMM for AQLM.
ops.def(
"aqlm_gemm(Tensor input, Tensor codes, Tensor codebooks, Tensor scales, "
"Tensor codebook_partition_sizes, Tensor? bias) -> Tensor");
ops.def("aqlm_gemm", &aqlm_gemm);
ops.impl("aqlm_gemm", torch::kCUDA, &aqlm_gemm);

// Decompression method for AQLM.
ops.def(
"aqlm_dequant(Tensor codes, Tensor codebooks, Tensor "
"codebook_partition_sizes) -> Tensor");
ops.def("aqlm_dequant", &aqlm_dequant);
ops.impl("aqlm_dequant", torch::kCUDA, &aqlm_dequant);

// Quantized GEMM for AWQ.
ops.def(
"awq_gemm(Tensor _in_feats, Tensor _kernel, Tensor _scaling_factors, "
"Tensor _zeros, int split_k_iters) -> Tensor");
ops.def("awq_gemm", &awq_gemm);
ops.impl("awq_gemm", torch::kCUDA, &awq_gemm);

// Marlin (Dense) Optimized Quantized GEMM for GPTQ.
ops.def(
"marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, Tensor "
"workspace, int size_m, int size_n, int size_k) -> Tensor");
ops.def("marlin_gemm", &marlin_gemm);
ops.impl("marlin_gemm", torch::kCUDA, &marlin_gemm);

// Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ.
ops.def(
"gptq_marlin_24_gemm(Tensor a, Tensor b_q_weight, Tensor b_meta, Tensor "
"b_scales, Tensor workspace, int num_bits, int size_m, int size_n, int "
"size_k) -> Tensor");
ops.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm);
ops.impl("gptq_marlin_24_gemm", torch::kCUDA, &gptq_marlin_24_gemm);

// gptq_marlin Optimized Quantized GEMM for GPTQ.
ops.def(
"gptq_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, Tensor "
"g_idx, Tensor perm, Tensor workspace, int num_bits, int size_m, int "
"size_n, int size_k, bool is_k_full) -> Tensor");
ops.def("gptq_marlin_gemm", &gptq_marlin_gemm);
ops.impl("gptq_marlin_gemm", torch::kCUDA, &gptq_marlin_gemm);

// gptq_marlin repack from GPTQ.
ops.def(
"gptq_marlin_repack(Tensor b_q_weight, Tensor perm, int size_k, int "
"size_n, int num_bits) -> Tensor");
ops.def("gptq_marlin_repack", &gptq_marlin_repack);
ops.impl("gptq_marlin_repack", torch::kCUDA, &gptq_marlin_repack);

// Dequantization for AWQ.
ops.def(
"awq_dequantize(Tensor _kernel, Tensor _scaling_factors, Tensor _zeros, "
"int split_k_iters, int thx, int thy) -> Tensor");
ops.def("awq_dequantize", &awq_dequantize);
ops.impl("awq_dequantize", torch::kCUDA, &awq_dequantize);

// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
// quantization.
ops.def(
"cutlass_scaled_mm_dq(Tensor out, Tensor a, Tensor b, Tensor a_scales, "
"Tensor b_scales) -> ()");
def(ops, "cutlass_scaled_mm_dq", &cutlass_scaled_mm_dq, {0});
ops.impl("cutlass_scaled_mm_dq", torch::kCUDA, &cutlass_scaled_mm_dq);
#endif

// Quantized GEMM for GPTQ.
ops.def(
"gptq_gemm(Tensor a, Tensor b_q_weight, Tensor b_gptq_qzeros, Tensor "
"b_gptq_scales, Tensor b_g_idx, bool use_exllama, int bit) -> Tensor");
ops.def("gptq_gemm", &gptq_gemm);
ops.impl("gptq_gemm", torch::kCUDA, &gptq_gemm);

// Post processing for GPTQ.
ops.def("gptq_shuffle(Tensor q_weight, Tensor q_perm, int bit) -> ()");
def(ops, "gptq_shuffle", &gptq_shuffle, {0});
ops.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle);

// Quantized GEMM for SqueezeLLM.
ops.def(
"squeezellm_gemm(Tensor vec, Tensor mat, Tensor mul, Tensor "
"lookup_table) -> ()");
def(ops, "squeezellm_gemm", &squeezellm_gemm, {2});
ops.impl("squeezellm_gemm", torch::kCUDA, &squeezellm_gemm);

// Compute FP8 quantized tensor for given scaling factor.
ops.def(
"static_scaled_fp8_quant(Tensor out, Tensor input, Tensor scale) -> ()");
def(ops, "static_scaled_fp8_quant", &static_scaled_fp8_quant, {0});
ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant);

// Compute FP8 quantized tensor and scaling factor.
ops.def(
"dynamic_scaled_fp8_quant(Tensor out, Tensor input, Tensor scale) -> ()");
def(ops, "dynamic_scaled_fp8_quant", &dynamic_scaled_fp8_quant, {0});
ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant);

// Aligning the number of tokens to be processed by each expert such
// that it is divisible by the block size.
ops.def(
"moe_align_block_size(Tensor topk_ids, int num_experts, int block_size,"
"Tensor sorted_token_ids, Tensor experts_ids, Tensor "
"num_tokens_post_pad) -> ()");
def(ops, "moe_align_block_size", &moe_align_block_size, {3, 4, 5});
ops.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);

// Compute int8 quantized tensor for given scaling factor.
ops.def(
"static_scaled_int8_quant(Tensor out, Tensor input, float scale) -> ()");
def(ops, "static_scaled_int8_quant", &static_scaled_int8_quant, {0});
ops.impl("static_scaled_int8_quant", torch::kCUDA, &static_scaled_int8_quant);

// Compute int8 quantized tensor and scaling factor
@@ -187,35 +140,36 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
// Cache ops
// Swap in (out) the cache blocks from src to dst.
cache_ops.def("swap_blocks(Tensor src, Tensor dst, Tensor block_mapping) -> ()");
def(cache_ops, "swap_blocks", &swap_blocks, {0,1});
cache_ops.impl("swap_blocks", torch::kCUDA, &swap_blocks);

// Copy the cache blocks from src to dst.
cache_ops.def("copy_blocks", &copy_blocks);
def(cache_ops, "copy_blocks", &copy_blocks, {0, 1});
cache_ops.impl("copy_blocks", torch::kCUDA, &copy_blocks);

// Reshape the key and value tensors and cache them.
cache_ops.def("reshape_and_cache", &reshape_and_cache);
def(cache_ops, "reshape_and_cache", &reshape_and_cache, {2, 3}); // 4?
cache_ops.impl("reshape_and_cache", torch::kCUDA, &reshape_and_cache);

// Reshape the key and value tensors and cache them.
cache_ops.def("reshape_and_cache_flash", &reshape_and_cache_flash);
def(cache_ops, "reshape_and_cache_flash", &reshape_and_cache_flash, {2, 3}); // 4?
cache_ops.impl("reshape_and_cache_flash", torch::kCUDA, &reshape_and_cache_flash);

// Convert the key and value cache to fp8 data type.
cache_ops.def("convert_fp8", &convert_fp8);
def(cache_ops, "convert_fp8", &convert_fp8, {0});
cache_ops.impl("convert_fp8", torch::kCUDA, &convert_fp8);
}

TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
// Cuda utils

// Gets the specified device attribute.
cuda_utils.def("get_device_attribute(int attribute, int device_id) -> int");
cuda_utils.def("get_device_attribute", &get_device_attribute);
cuda_utils.impl("get_device_attribute", torch::kCUDA, &get_device_attribute);

// Gets the maximum shared memory per block device attribute.
cuda_utils.def("get_max_shared_memory_per_block_device_attribute(int device_id) -> int");
cuda_utils.def("get_max_shared_memory_per_block_device_attribute",
&get_max_shared_memory_per_block_device_attribute);
cuda_utils.impl("get_max_shared_memory_per_block_device_attribute", torch::kCUDA,
&get_max_shared_memory_per_block_device_attribute);
}
51 changes: 51 additions & 0 deletions csrc/register.h
@@ -0,0 +1,51 @@
#pragma once

#include <iostream>

#include <torch/extension.h>

#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)
#define _CONCAT(A, B) A##B
#define CONCAT(A, B) _CONCAT(A, B)

namespace vllm {

template <typename FnType>
void def(torch::Library& lib, std::string const& name, FnType* fn,
std::initializer_list<int> mutating_arg_indices = {}) {
#if 1
auto raw_schema =
c10::detail::inferFunctionSchemaFromFunctor<std::decay_t<FnType>>();
auto named_schema = raw_schema->cloneWithName(name, "");

if (mutating_arg_indices.size() != 0) {
std::vector<c10::Argument> const& args = named_schema.arguments();
std::vector<c10::Argument> new_args;
for (size_t i = 0; i < args.size(); ++i) {
auto const& arg = args[i];
if (std::find(mutating_arg_indices.begin(), mutating_arg_indices.end(),
i) == mutating_arg_indices.end()) {
new_args.push_back(arg);
} else {
c10::AliasInfo new_alias_info;
if (arg.alias_info()) {
new_alias_info = *arg.alias_info();
}
new_alias_info.setIsWrite(true);

new_args.emplace_back(
arg.name(), arg.type(), arg.real_type(), arg.N(),
arg.default_value(), arg.kwarg_only(), new_alias_info);
}
}

named_schema = named_schema.cloneWithArguments(std::move(new_args));
}

lib.def(std::move(named_schema));
#else
lib.def(name.c_str(), fn);
#endif
}

} // namespace vllm
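To make the helper concrete: `vllm::def(lib, name, fn, {...})` infers a schema from the C++ signature via `inferFunctionSchemaFromFunctor` and then rewrites the alias info of the arguments at the listed indices so they are marked as written-to (rendered as `Tensor(a!)` in schema notation). A sketch of how it is used, mirroring the `silu_and_mul` registration in this diff but placed in a hypothetical `_C_sketch` namespace so it does not imply the real file layout:

```cpp
#include <torch/extension.h>

#include "register.h"  // vllm::def, TORCH_LIBRARY_EXPAND, CONCAT

// Real kernel declared in csrc/ops.h; repeated here only for self-containment
// (linking still requires the actual CUDA kernel).
void silu_and_mul(torch::Tensor& out, torch::Tensor& input);

TORCH_LIBRARY(_C_sketch, ops) {
  // The schema is inferred from the C++ signature (two Tensors, void return;
  // argument names are not recoverable from the function pointer), and
  // argument 0 is then marked mutable, i.e. roughly equivalent to writing
  //   ops.def("silu_and_mul(Tensor(a!) out, Tensor input) -> ()");
  vllm::def(ops, "silu_and_mul", &silu_and_mul, {0});
  ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);
}
```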
18 changes: 12 additions & 6 deletions vllm/_custom_ops.py
@@ -3,7 +3,7 @@
import torch

try:
# ruff: noqa: SIM105
# ruff: noqa: F401 SIM105
import vllm._C
except ImportError as e:
from vllm.logger import init_logger
@@ -315,8 +315,10 @@ def reshape_and_cache(
kv_cache_dtype: str,
kv_scale: float,
) -> None:
torch.ops._C_cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
slot_mapping, kv_cache_dtype, kv_scale)
torch.ops._C_cache_ops.reshape_and_cache(
key, value, key_cache,
value_cache, slot_mapping,
kv_cache_dtype, kv_scale)


def reshape_and_cache_flash(
@@ -327,8 +329,10 @@ def reshape_and_cache_flash(
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
) -> None:
torch.ops._C_cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
slot_mapping, kv_cache_dtype)
torch.ops._C_cache_ops.reshape_and_cache_flash(
key, value, key_cache,
value_cache, slot_mapping,
kv_cache_dtype)


def copy_blocks(key_caches: torch.Tensor, value_caches: torch.Tensor,
@@ -353,7 +357,9 @@ def get_device_attribute(attribute: int, device: int) -> int:


def get_max_shared_memory_per_block_device_attribute(device: int) -> int:
return torch.ops._C_cuda_utils.get_max_shared_memory_per_block_device_attribute(device)
# ruff: noqa: E501
return torch.ops._C_cuda_utils.get_max_shared_memory_per_block_device_attribute(
device)


#TODO: custom_ar