[RFC] Which low bit CUDA kernels should we merge or write? #697
Comments
Thanks for the survey summary @msaroufim! This is very helpful for understanding which kernels we might be interested in integrating. I think one assumption here is that single-kernel performance of a linear at a given shape (M, N, K) is a proxy for end-to-end model performance. However, I'm not yet sure this is true. For example, we know that llama2 int8wo and int4wo give a speedup over bfloat16: https://github.com/pytorch/ao/tree/main/torchao/quantization#benchmarks. Today I printed the shapes recorded by autoquant and ran a microbenchmark on single linears with the corresponding (M, N, K) sizes from the llama2 model (#695), but every shape looks slower than bfloat16 for every quant method. If our microbenchmarking is done correctly, then single-kernel perf on a given shape may not be a good proxy for model-level performance, and we may need to optimize for specific models instead of just kernels. I do want to see more data on microbenchmarking vs. e2e model-level benchmarking, to understand whether microbenchmark kernel perf results can be a good proxy for model-level perf results. cc @HDCharles, wondering if you have data on this from developing autoquant. If the assumption is not true, then as part of deciding which kernels to merge, we'd also need to say which models, and even which runtime environments (execution engines), we are optimizing for.
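A back-of-the-envelope arithmetic intensity estimate helps frame why the (M, N, K) shape matters so much: at small M (decode), a linear is dominated by weight traffic, so low-bit weights mostly help by moving fewer bytes. This is an idealized plain-Python sketch (each tensor is moved exactly once; dequantization cost, caches and launch overhead are ignored), and the function name is hypothetical.

```python
def linear_arithmetic_intensity(m, n, k, weight_bits, act_bytes=2, out_bytes=2):
    """FLOPs per byte moved for y = x @ W.T with x: (m, k) and W: (n, k).

    Idealized: assumes activations, weights and outputs are each read or
    written exactly once, and ignores dequantization work and caching.
    """
    flops = 2 * m * n * k                 # one multiply + one add per MAC
    weight_bytes = n * k * weight_bits / 8
    act_in_bytes = m * k * act_bytes
    out_bytes_total = m * n * out_bytes
    return flops / (weight_bytes + act_in_bytes + out_bytes_total)


# Decode-like shape (M = 1) vs a prefill-like shape (M = 4096)
for m in (1, 4096):
    for bits in (16, 8, 4):
        ai = linear_arithmetic_intensity(m, 4096, 4096, weight_bits=bits)
        print(f"M={m:5d} W{bits:<2d}: {ai:8.1f} FLOP/byte")
```

With M = 1 every variant sits far below the roughly 150 to 200 FLOP/byte at which an A100 becomes compute-bound (312 TFLOP/s dense bf16 over 1.5 to 2 TB/s), so the kernel is memory-bound and any win has to come from moving fewer bytes; at M = 4096 the picture flips and dequantization overhead starts to matter.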
My feeling is it's very much a do-both kind of exercise. Microbenchmarks can more quickly and reliably find bad kernels and IMHO should be the main criterion for merging work. With e2e benchmarks, the tricky part is that we might not know whether the kernel we have is useful on another model we haven't tried yet. But e2e benchmarks would be the main criterion for deciding whether to blog about something.
@msaroufim #1413 is
Does this work with bf16/fp16 accumulation as well, or just fp8?
@mobicham torchao also supports symmetric quant. I think it should also be easy to support the no-zero_point use case by adding a new layout type for the affine quantized tensor.
@jerryzh168 the quality tends to be worse with symmetric quantization than with asymmetric quantization. Much of the quality in linear quantization actually comes from the zero-point, not the scaling factor. I actually reported this issue here: IST-DASLab/marlin#5 (comment)
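To illustrate the zero-point point, here's a minimal plain-Python sketch (not torchao's API) comparing symmetric and asymmetric 4-bit quantization of one group of weights. The all-positive example group is an assumption chosen to make the effect obvious: for a skewed range like this, symmetric quantization spends about half of its grid on negative values that never occur.

```python
def quantize_symmetric(values, bits):
    """Symmetric per-group quantization (no zero-point); returns dequantized values."""
    qmax = 2 ** (bits - 1) - 1                    # e.g. 7 for 4 bits
    scale = max(abs(v) for v in values) / qmax
    if scale == 0:
        return list(values)
    q = [max(-qmax - 1, min(qmax, round(v / scale))) for v in values]
    return [x * scale for x in q]


def quantize_asymmetric(values, bits):
    """Asymmetric per-group quantization with an integer zero-point; returns dequantized values."""
    qmax = 2 ** bits - 1                          # e.g. 15 for 4 bits
    lo, hi = min(values), max(values)
    scale = (hi - lo) / qmax
    if scale == 0:
        return list(values)
    zero = round(-lo / scale)                     # integer that represents 0.0
    q = [max(0, min(qmax, round(v / scale) + zero)) for v in values]
    return [(x - zero) * scale for x in q]


def mse(ref, approx):
    return sum((a - b) ** 2 for a, b in zip(ref, approx)) / len(ref)


weights = [i / 15 for i in range(16)]             # hypothetical all-positive group
print("symmetric MSE: ", mse(weights, quantize_symmetric(weights, 4)))
print("asymmetric MSE:", mse(weights, quantize_asymmetric(weights, 4)))
```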
Hi all. I highly recommend the gemm in TurboMind, which implements AWQ, GPTQ, and W8A16 (INT8, FP8), and is currently the fastest open-source implementation. At small batch sizes, it is several times faster than cuBLAS. It's also faster than the Marlin used in vLLM. https://github.com/InternLM/lmdeploy/tree/main/src/turbomind/kernels/gemm
We recently made plans to extract this part into a separate library for easier integration with other projects. I'm also very much looking forward to seeing the performance once it's integrated into SGLang. If you're interested, we can discuss this further in depth. Cheers!
@zhyncs is the gemm in TurboMind something you'd be interested in contributing to ao? One nice thing about Marlin, as an example, is that they keep most of their code in a single file and encourage people to copy-paste and package up their kernels.
Ok
I'll discuss the integration approach with @lzhangzz, aiming to finish the integration asap. We're really excited about it! cc @lvhan028
Here is my understanding of the existing state of things and what I think we should be doing to make our lower-bit kernels more performant at both small and larger batch sizes. I'm making this an RFC because I'm curious whether I'm paying attention to the wrong things so if you disagree with any of the below please comment!
First, a quick survey of libraries.
Survey of existing solutions
Interestingly enough, most of the solutions below don't ship their kernels as a package and instead encourage users to copy-paste their code and cite it. It's common to make these libraries header-only to make integration easier.
And thankfully we do have the machinery to support CUDA kernels across multiple versions with minimal headache, thanks to our custom CUDA extension support: https://github.com/pytorch/ao/tree/main/torchao/csrc
So it's easy to merge kernels but which ones should we actually merge?
Marlin
This is the kernel of choice in vLLM, arguably the most popular open-source inference engine on the market. Its fp16xint4 kernels perform well at small batch sizes and keep performing at larger batch sizes than tinygemm and competitors can. The kernels also don't seem particularly affected by GPU power limits, something that has bitten us in the past when running internal performance benchmarks.
There's also a 2:4 sparse variant of the above kernel, which we're already working on upstreaming (#621), but I'm not sure right now whether we should merge both kernels or just the sparse one.
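For anyone unfamiliar with 2:4 (semi-structured) sparsity: in every contiguous group of 4 weights, at most 2 are nonzero, which is the pattern sparse tensor cores on Ampere and later can exploit. Below is a minimal magnitude-pruning sketch in plain Python. It only produces the pattern; hardware 2:4 formats also store the compressed values plus index metadata, and pruned models often need some accuracy recovery, none of which is shown here.

```python
def prune_2_4(row):
    """Zero all but the 2 largest-magnitude values in each contiguous group of 4."""
    if len(row) % 4:
        raise ValueError("row length must be a multiple of 4")
    out = []
    for g in range(0, len(row), 4):
        group = row[g:g + 4]
        # indices of the two largest |values| in this group
        keep = sorted(range(4), key=lambda i: abs(group[i]), reverse=True)[:2]
        out.extend(v if i in keep else 0.0 for i, v in enumerate(group))
    return out


print(prune_2_4([0.1, -0.9, 0.5, 0.2, 3.0, 0.0, -1.0, 0.5]))
```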
Regardless, the lab behind https://github.com/IST-DASLab/marlin consistently does excellent work and is worth following.
tinygemm
tinygemm isn't a full library in core; it's an op, torch.ops.aten._weight_int4pack_mm, and it's the fastest thing we've found for int4 weight-only quantization (w4a16) so far. One of the challenges, though, is that because it's so fast it becomes a hammer and all our performance problems become nails, whereas if we could easily accelerate other dtypes we might not rely on it so much.
CUTLASS
This work leverages the Universal GEMM operator in CUTLASS (NVIDIA/cutlass#1549); no bit packing is needed since CUTLASS has a native type for int4: cutlass::int4b_t
There are also some open PRs in CUTLASS for signed and unsigned int4/int8 multiplication with activations in fp16 NVIDIA/cutlass#1413 by @alexsamardzic
Perhaps the main recurring con that comes up with CUTLASS is that it's hard to learn, but it's generally one of the best perf targets since it's more vertically integrated with the NVIDIA stack. And, well, maybe it's not hard; maybe it's a bit of a skill issue on my end.
gemlite
This is a more recent project by @mobicham that offers GEMV acceleration: https://mobiusml.github.io/gemlite_blogpost/
The core idea is well explained in https://github.com/Bruce-Lee-LY/cuda_hgemv#optimization-method, where they walk from naive implementations to ones that make efficient use of shared memory and warp scheduling.
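The warp-level part of that idea can be sketched without a GPU. The plain-Python simulation below illustrates the general technique (one 32-lane warp per output row, strided loads so neighboring lanes touch neighboring elements, then an XOR-shuffle butterfly reduction like CUDA's __shfl_xor_sync); it is not code from gemlite or cuda_hgemv, and a real kernel would typically also stage x in shared memory and use vectorized loads.

```python
WARP_SIZE = 32


def gemv_warp_per_row(matrix, x):
    """Simulate y = matrix @ x with one warp per output row.

    Lane l accumulates matrix[r][k] * x[k] for k = l, l + 32, l + 64, ...
    The 32 partial sums are then combined with a butterfly reduction in which
    every lane adds the value held by lane (l XOR offset).
    """
    y = []
    for row in matrix:
        partial = [0.0] * WARP_SIZE
        for lane in range(WARP_SIZE):
            for k in range(lane, len(row), WARP_SIZE):
                partial[lane] += row[k] * x[k]
        offset = WARP_SIZE // 2
        while offset > 0:
            # all lanes read the old values before writing, like a shuffle step
            partial = [partial[l] + partial[l ^ offset] for l in range(WARP_SIZE)]
            offset //= 2
        y.append(partial[0])  # after the butterfly every lane holds the full sum
    return y


mat = [[float((r * 7 + c) % 11) for c in range(100)] for r in range(3)]
vec = [float(c % 5) for c in range(100)]
print(gemv_warp_per_row(mat, vec))
```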
GEMV kernels inherently solve a more restricted problem: bs=1 inference à la gpt-fast.
However, despite being limited to batch size 1, gemlite is quite expressive in that it allows arbitrary weight dtypes. If you look at their function definition gemv_A16fWnO16f_int32packing, you can read it as fp16 activations x n-bit weights packed into 32-bit words, with mixed fp16 accumulation. The code is quite short and confined to very few files, so it's quite easy to reuse.
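To make the "int32packing" part of that name concrete, here's a minimal plain-Python sketch of packing unsigned n-bit weights into 32-bit words and unpacking them. The helper names are hypothetical and the low-bits-first order is an assumption; gemlite's actual layout may order or interleave elements differently for coalesced loads.

```python
def pack_int32(values, bits):
    """Pack unsigned `bits`-wide integers into 32-bit words, lowest bits first.

    Assumes 32 % bits == 0 (1, 2, 4 or 8 bits); widths like 3 bits would need
    values that straddle word boundaries and are not handled here.
    """
    if 32 % bits:
        raise ValueError("bits must divide 32")
    per_word = 32 // bits
    mask = (1 << bits) - 1
    words = []
    for start in range(0, len(values), per_word):
        word = 0
        for i, v in enumerate(values[start:start + per_word]):
            if not 0 <= v <= mask:
                raise ValueError(f"{v} does not fit in {bits} bits")
            word |= v << (i * bits)
        words.append(word)
    return words


def unpack_int32(words, bits, count):
    """Inverse of pack_int32; `count` drops padding in the last word."""
    per_word = 32 // bits
    mask = (1 << bits) - 1
    out = [(w >> (i * bits)) & mask for w in words for i in range(per_word)]
    return out[:count]


print(hex(pack_int32([1, 2, 3, 4, 5, 6, 7, 8], bits=4)[0]))
```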
bitblas
https://github.com/microsoft/BitBLAS
This is the only repo with a pip package, so packaging it into ao doesn't make as much sense, although we could explore using it as an optional backend in ao when we don't have the right kernel. Their support matrix is probably the most comprehensive of any repo in this list: https://github.com/microsoft/BitBLAS#support-matrix
Suggested next steps
Merge the obviously useful kernels
The sort of obvious next steps to match the current state of things are
Both of the above work with batch sizes larger than 1 and are industry standards, and people have been frustrated with their installation experience, which makes them worth merging.
Write the non-obvious kernels
The non-obvious kernels haven't been written yet, so our strategy has typically been to use torch.compile() with clever bitpacking as a baseline. End-to-end benchmarks are certainly helpful, but since we're talking about kernels here, we'd also need to run microbenchmarks on various shapes, as @jerryzh168 suggests.
For bs=1, get better performance for dtypes smaller than 4 bits
gemlite is a nice educational library supporting GEMV for a variety of dtypes, so we can leverage it not just for end-to-end performance benchmarks but also for speed-of-light calculations that help us better understand the gaps in bs=1 inference. The idea here is to ensure that performance is great for a variety of intX dtypes, as opposed to overfitting to int4 just because we have tinygemm.
@vayuda has already led some early work here on bitpacking with torch.compile, so we need to start baselining more heavily.
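A speed-of-light estimate for bs=1 is simple arithmetic: if the GEMV is purely memory-bound, latency can't beat the time to read the weights once. A minimal plain-Python sketch, assuming per-group metadata of one fp16 scale plus one fp16 zero (4 bytes per group) and an illustrative 2000 GB/s of bandwidth (roughly an A100 80GB); the function name and defaults are hypothetical, so plug in your own GPU's numbers.

```python
def gemv_speed_of_light_us(n, k, weight_bits, bandwidth_gb_s,
                           group_size=None, meta_bytes_per_group=4):
    """Lower bound (microseconds) on a bs=1 GEMV that only has to read
    an (n, k) weight matrix once, plus optional per-group scale/zero bytes.
    Activation and output traffic are ignored since they are tiny at bs=1.
    """
    weight_bytes = n * k * weight_bits / 8
    if group_size:
        weight_bytes += (n * k // group_size) * meta_bytes_per_group
    return weight_bytes / (bandwidth_gb_s * 1e9) * 1e6


for bits, group in ((16, None), (8, 128), (4, 128), (3, 128), (2, 128)):
    t = gemv_speed_of_light_us(4096, 4096, bits, bandwidth_gb_s=2000, group_size=group)
    print(f"W{bits:<2d} g={group}: {t:6.2f} us")
```

One thing this makes visible: at 2 bits with group size 128, scales and zeros are already about 11% of the bytes read, so metadata layout matters more as we go below 4 bits.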
For bs=n inference, start writing new kernels since they don't exist yet
For H100+
The biggest theme here is that instead of relying on fp16 as the activation dtype, we can rely on fp8.
Some of this work was already mentioned in #663, but we'll add more detail.
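At the numerics level, "fp8 as the activation dtype" roughly means dynamically scaling activations into the float8_e4m3fn range (max 448) and rounding onto the e4m3 grid. The plain-Python sketch below only simulates that rounding; per-tensor scaling is an assumption (per-row scaling is also common), and in PyTorch one would cast to torch.float8_e4m3fn and run a scaled fp8 GEMM instead.

```python
import math

E4M3_MAX = 448.0           # largest finite float8_e4m3fn value
E4M3_MIN_NORMAL_EXP = -6   # smallest normal exponent; subnormals reuse its spacing
E4M3_MANTISSA_BITS = 3


def round_to_e4m3(x):
    """Round a float to the nearest float8_e4m3fn value (ties to even, saturating)."""
    if x == 0.0:
        return 0.0
    sign = -1.0 if x < 0 else 1.0
    a = min(abs(x), E4M3_MAX)
    exp = max(math.floor(math.log2(a)), E4M3_MIN_NORMAL_EXP)
    step = 2.0 ** (exp - E4M3_MANTISSA_BITS)   # spacing of representable values
    return sign * min(round(a / step) * step, E4M3_MAX)


def quantize_fp8_per_tensor(acts):
    """Dynamic per-tensor scaling: map amax to E4M3_MAX, then round each value."""
    amax = max(abs(v) for v in acts)
    scale = amax / E4M3_MAX if amax > 0 else 1.0
    return [round_to_e4m3(v / scale) for v in acts], scale


acts = [0.5, -2.0, 1.0, 0.01]
q, scale = quantize_fp8_per_tensor(acts)
print(q, [v * scale for v in q])   # fp8 grid values and their dequantized versions
```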
For A100
For A100, our options are a bit more obvious: we should be showing compelling dynamic quantization (quantizing the activations to int8) performance at larger batch sizes. gpt-fast has already been extended to support larger batch sizes: https://github.com/pytorch-labs/gpt-fast/tree/batched_generation
For this work we'd focus on int8 dynamic quantization and then work our way down from there.
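As a numerics reference for int8 dynamic quantization, here's a plain-Python sketch: activations get a per-row (per-token) scale computed at runtime, weights get a per-output-channel scale, inner products are accumulated in integers, and the result is rescaled to float. Symmetric per-row/per-channel scaling is a common choice but an assumption here, and the helper names are hypothetical; a real kernel runs the integer GEMM on tensor cores with int32 accumulators.

```python
def quantize_rows_int8(rows):
    """Symmetric per-row int8 quantization: returns (int rows, per-row scales)."""
    q, scales = [], []
    for row in rows:
        amax = max(abs(v) for v in row)
        s = amax / 127 if amax > 0 else 1.0
        q.append([max(-128, min(127, round(v / s))) for v in row])
        scales.append(s)
    return q, scales


def int8_dynamic_linear(x, w):
    """y = x @ w.T with x: (m, k) float activations and w: (n, k) float weights."""
    xq, xs = quantize_rows_int8(x)   # dynamic: recomputed for every input
    wq, ws = quantize_rows_int8(w)   # static in practice: done once ahead of time
    y = []
    for i, xrow in enumerate(xq):
        out = []
        for j, wrow in enumerate(wq):
            acc = sum(a * b for a, b in zip(xrow, wrow))  # integer accumulation
            out.append(acc * xs[i] * ws[j])               # rescale to float
        y.append(out)
    return y


x = [[0.5, -1.0, 2.0], [3.0, 0.25, -0.75]]
w = [[1.0, 0.5, -0.5], [-2.0, 1.5, 0.1]]
print(int8_dynamic_linear(x, w))
```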
Related work