This repository has been archived by the owner on Oct 11, 2024. It is now read-only.

[Misc]: Move from using the PYBIND11_MODULE macro to using the TORCH_LIBRARY macro to bind C++/CUDA kernels to Python #133

Closed
LucasWilkinson opened this issue Mar 17, 2024 · 1 comment

@LucasWilkinson
Collaborator

LucasWilkinson commented Mar 17, 2024

Motivation

Currently vLLM uses the PYBIND11_MODULE macro to bind its C++/CUDA kernels to Python; the binding code lives in csrc/pybind.cpp. This means calls to these kernels bypass the torch dispatcher (more information on the torch dispatcher can be found here and here). While bypassing the dispatcher works, going through it has a few distinct advantages, namely:

  1. Better integration with the PyTorch profiler
  2. A more natural way to support CPU-only inference, or other hardware, in the future

Regarding 1: at Neural Magic we are building more in-depth profiling tools for vLLM on top of the PyTorch profiler. By using the torch dispatcher (i.e. registering the C++/CUDA kernels with the TORCH_LIBRARY macro instead of PYBIND11_MODULE), we can provide richer traces, since the profiler will be able to capture metadata (namely type and shape information) for the inputs to each operation (kernel). Below is an example of the traces we are generating (note: this is a work in progress):

name                                                         | cpu_time_us  | cuda_time_us | pct_cuda_... | trace                                                       
========================================================================================================================================================================
LlamaForCausalLM                                             |      9424.95 |     31087.00 |        93.80 |                                                             
|- LlamaModel                                                |      9403.02 |     31087.00 |        93.80 |                                                             
||- VocabParallelEmbedding(weight=bfloat16[32064, 4096])     |        93.30 |         7.00 |         0.02 |                                                             
|||- void at::native::(anonymous namespace)::indexSelectL... |         0.00 |         7.00 |         0.02 | index_select(bfloat16[32064, 4096], 0, int64[128]) <- emb...
||- LlamaDecoderLayer                                        |      1555.91 |       760.00 |         2.29 |                                                             
|||- RMSNorm(weight=bfloat16[4096])                          |       271.12 |         6.00 |         0.02 |                                                             
||||- void vllm::rms_norm_kernel<c10::BFloat16>(c10::BFlo... |         0.00 |         6.00 |         0.02 |                                                             
|||- LlamaAttention                                          |      1003.32 |       173.00 |         0.52 |                                                             
||||- QKVParallelLinear(weight=bfloat16[6144, 4096])         |       173.64 |        95.00 |         0.29 |                                                             
|||||- ampere_bf16_s16816gemm_bf16_256x128_ldg8_f2f_stage... |         0.00 |        95.00 |         0.29 | mm(bfloat16[128, 4096], bfloat16[4096, 6144]) <- matmul(b...
||||- RotaryEmbedding                                        |        19.37 |         4.00 |         0.01 |                                                             
|||||- void vllm::rotary_embedding_kernel<c10::BFloat16, ... |         0.00 |         4.00 |         0.01 |                                                             
||||- Attention                                              |       534.66 |        15.00 |         0.05 |                                                             
|||||- void vllm::reshape_and_cache_kernel<__nv_bfloat16,... |         0.00 |         6.00 |         0.02 |                                                             
|||||- void flash_fwd_kernel<Flash_fwd_kernel_traits<128,... |         0.00 |         9.00 |         0.03 | FlashAttnFunc(bfloat16[8, 16, 32, 128], bfloat16[8, 16, 8...
...

In the final trace column we can see tensor type and shape information; however, this information is only available for TorchOp events (i.e. kernels registered using TORCH_LIBRARY). For example, flash_fwd_kernel and ampere_bf16_s16816gemm... have this type and shape information while vllm::reshape_and_cache_kernel does not, because the former two kernels go through the torch dispatcher while the latter does not.
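The reason the dispatched kernels carry this metadata can be illustrated with a toy model. This is not the torch dispatcher or profiler API: FakeTensor, ToyDispatcher, and the trace format below are hypothetical names chosen for this sketch. What it shows is that when every call is routed through one central entry point, a hook there can record the op name plus the dtype and shape of each input, while a direct call (analogous to a PYBIND11_MODULE binding) never passes that hook.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Tuple


@dataclass
class FakeTensor:
    """Hypothetical stand-in for a tensor; only dtype and shape matter here."""
    dtype: str
    shape: Tuple[int, ...]

    def describe(self) -> str:
        # Formats like the trace column above, e.g. "bfloat16[128, 8, 128]".
        return f"{self.dtype}{list(self.shape)}"


class ToyDispatcher:
    """Toy registry: every call made through call() passes one hook point,
    which is where a profiler can record op names and input metadata."""

    def __init__(self) -> None:
        self._ops: Dict[str, Callable[..., None]] = {}
        self.trace: List[str] = []

    def register(self, name: str, fn: Callable[..., None]) -> None:
        self._ops[name] = fn

    def call(self, name: str, *args: FakeTensor) -> None:
        # Record dtype/shape of every input before running the kernel.
        self.trace.append(f"{name}({', '.join(a.describe() for a in args)})")
        self._ops[name](*args)


def reshape_and_cache(key: FakeTensor, value: FakeTensor) -> None:
    """Hypothetical kernel; its body does nothing in this sketch."""


dispatcher = ToyDispatcher()
dispatcher.register("reshape_and_cache", reshape_and_cache)

key = FakeTensor("bfloat16", (128, 8, 128))
value = FakeTensor("bfloat16", (128, 8, 128))

dispatcher.call("reshape_and_cache", key, value)  # goes through the hook: recorded
reshape_and_cache(key, value)                     # direct call: never recorded

print(dispatcher.trace)
```

Only the dispatched call leaves a trace entry, `reshape_and_cache(bfloat16[128, 8, 128], bfloat16[128, 8, 128])`; the direct call is invisible to the hook, mirroring how the pybind-bound vllm::reshape_and_cache_kernel shows no shapes in the real trace.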

Regarding 2: at Neural Magic we have ambitions to extend vLLM to support CPU inference, which will require dispatching to the CPU or CUDA version of the same kernel depending on where the input tensors live (this applies to other hardware too, not just CPUs). The torch dispatcher does this automatically, removing the need for a chain of if statements.
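A similarly hedged toy sketch of device-based dispatch: the operator is declared once, per-device implementations are registered separately, and the call site picks a kernel from the inputs' device without any if/else chain. ToyOp, FakeTensor, and the `scale` op are hypothetical; the real torch dispatcher derives dispatch keys from the tensor arguments rather than looking up a device string.

```python
from dataclasses import dataclass
from typing import Callable, Dict


@dataclass
class FakeTensor:
    """Hypothetical stand-in for a tensor that knows which device it lives on."""
    device: str  # e.g. "cpu" or "cuda"
    data: list


class ToyOp:
    """One operator declared once, with per-device kernels registered separately."""

    def __init__(self, name: str) -> None:
        self.name = name
        self._impls: Dict[str, Callable[..., FakeTensor]] = {}

    def impl(self, device: str):
        # Decorator that registers a kernel for one device.
        def decorator(fn):
            self._impls[device] = fn
            return fn
        return decorator

    def __call__(self, *tensors: FakeTensor) -> FakeTensor:
        devices = {t.device for t in tensors}
        if len(devices) != 1:
            raise RuntimeError(f"{self.name}: inputs on multiple devices {sorted(devices)}")
        device = devices.pop()
        try:
            kernel = self._impls[device]  # table lookup replaces an if/elif chain
        except KeyError:
            raise NotImplementedError(f"{self.name} has no kernel for device '{device}'") from None
        return kernel(*tensors)


scale = ToyOp("scale")  # hypothetical op name


@scale.impl("cpu")
def scale_cpu(x: FakeTensor) -> FakeTensor:
    return FakeTensor("cpu", [2 * v for v in x.data])


@scale.impl("cuda")
def scale_cuda(x: FakeTensor) -> FakeTensor:
    # A real backend would launch a CUDA kernel; this sketch only tags the result.
    return FakeTensor("cuda", [2 * v for v in x.data])


print(scale(FakeTensor("cpu", [1, 2, 3])))
```

Adding support for new hardware in this model means registering one more `impl` rather than editing every call site; a tensor on a device with no registered kernel raises NotImplementedError instead of silently falling through an if chain.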

Implementation

There appear to be two primary ways to register operations (kernels) with the torch dispatcher. The first is from C++, using the TORCH_LIBRARY macro mentioned in the motivation. An example of this can be found in the xformers repository, with an SpMM operation declared here and its implementations bound for CUDA here and for CPU here. The second is from Python; xformers also has an example of this for the flash_fwd operation, with the operation declaration found here and the CUDA implementation bound here.

For the implementation, given that vLLM controls the Python-to-C++/CUDA bindings for the kernels in csrc, I think it would be cleaner to go with the TORCH_LIBRARY approach, as it wouldn't require much more boilerplate than the existing PYBIND11_MODULE bindings.
