-
Notifications
You must be signed in to change notification settings - Fork 177
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fused HQQ Quantization Gemm #153
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
first pass
benchmarks/benchmark_hqq.py
Outdated
@@ -0,0 +1,134 @@ | |||
import torch | |||
from termcolor import colored |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
will need to get rid of this dependency for merge, I'm fine with adding adding colors though so something like this should work
RED = "\033[31m"
GREEN = "\033[32m"
YELLOW = "\033[33m"
BLUE = "\033[34m"
MAGENTA = "\033[35m"
CYAN = "\033[36m"
WHITE = "\033[37m"
RESET = "\033[0m" # Resets the color to default.
name = "Alice"
print(f"{GREEN}Hello, {name}!{RESET}")
torchao/prototype/hqq/mixed_mm.py
Outdated
from triton import cdiv | ||
import triton.language as tl | ||
from .kernels import mixed_mm_kernel_compute_bound, mixed_mm_kernel_max_autotune | ||
#credit jlebar |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
could you link to the original code as well? Will need to double check if the LICENSE allows us to copy paste
test/hqq/test_triton_mm.py
Outdated
@@ -0,0 +1,101 @@ | |||
import itertools |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
both test and benchmark will require skips if triton is less than 3.0 (which is fine because nightlies now ship with 3.0.0) and if hqq is not installed
For hqq I'm fine if we add it as a dev dependency for now
return ref_time, tt_time, int4_time if dtype == torch.bfloat16 else None | ||
|
||
|
||
SHAPES = [ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@cpuhrsch I guess these shapes are fine for now but are there some specific shapes we're more interested in tracking on an ongoing basis if so I wish we could just make them part of our benchmark or test utilities
benchmarks/benchmark_hqq.py
Outdated
|
||
|
||
df = pd.DataFrame(data, columns=HEADERS) | ||
df.to_csv("benchmark_triton.csv", index=False) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we will lose this csv on CI unless its saved to some github artifact so unless this file is huge let's just print it for now
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed
GPU details: | ||
|
||
``` | ||
_CudaDeviceProperties(name='NVIDIA RTX A6000', major=8, minor=6, total_memory=48676MB, multi_processor_count=84) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
once we figure out the installation issues I'll check to see if results repro on an H100
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
apologies meant pip freeze
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When you run on H100, can you run once with DISABLE_MMA_V3=1
? It toggles Hopper
specific specializations in triton
. Curious to see how performance changes.
torchao/prototype/hqq/README.md
Outdated
@@ -0,0 +1,43 @@ | |||
## Fused `int4 / fp16` Quant Matmul | |||
|
|||
Fused gemm for asymmetric quantized weights. Tested and benchmarked for `HQQ` but could theoretically be used for any asymmetric quantization scheme. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sounds like one of the 2 asymetric should be a symetric?
torchao/prototype/hqq/README.md
Outdated
The kernel packs `u4 / s4` weights and fuses dequantization with the matmul. | ||
|
||
- tested for `float16 / bfloat16` activations, scales, and zeros | ||
- autotuned for both compute-bound and io-bound configs |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: could use memory bandwidth bound terminology instead
torchao/prototype/hqq/README.md
Outdated
|
||
Fused gemm for asymmetric quantized weights. Tested and benchmarked for `HQQ` but could theoretically be used for any asymmetric quantization scheme. | ||
|
||
The kernel packs `u4 / s4` weights and fuses dequantization with the matmul. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
n00b q: whu can't we generically do this with torch.compile @HDCharles
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It does work with torch.compile and there's a good speed-up (up to 4x compared to Pytorch), but a dequantize() CUDA kernel + torch.matmul is a bit faster.
I think the bitpacking should be done in such a way that torch.compile can fully optimize it.
|
||
|
||
@triton.jit | ||
def _mixed_mm_kernel( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see a lot of similarities between this code and what you had contributed for galore, can we start modularizing?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
first pass
|
Results on an H100
Although it's very curious how different the pattern is on an A6000
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you!
* add test / benchmark * add kernels * update readme * more readme edits * edit readme * add transpose test * transpose test pass * refactor test * add checks for CI * add more comments for transpose kernel * remove import in test * clean up benchmark * fix test import order * minor README edits * additional readme edits * update readme * update readme * add note about cudamode --------- Co-authored-by: Mark Saroufim <marksaroufim@meta.com>
@msaroufim
Fused
int4 / fp16
Quant MatmulFused kernel that combines asymmetric dequantization and gemm. Useful primarily for compute-bound (M > 16) scenarios and not for memory-bound / inference scenarios.
The kernel fuses two ops:
u4 / s4
weights tofloat16 / bfloat16
, followed by groupwise scaling and shifting by scales / zeropointsTested and benchmarked for
HQQ
but could theoretically be used for any asymmetric quantization scheme.Implementation Details
tinygemm
orfastertransformer
)float16 / bfloat16
activations, scales, and zerosgemm
is is the quantized type.in-features
, i.e., theK
dimension, oraxis=1
, oftorch.linear.weight
.Performance
Initial benchmarking (on
A6000
) demonstrates promising results, scaling well for compute-bound workloads:ms
, seebenchmarks/benchmark_hqq.py
.hqq_ref
is the baseHQQ_Linear
module that is unfused (dequantization followed by call to torch.matmul).tinygemm
callstorch.ops.aten._weight_int4pack_mm
. Implementation is a custom HQQLinear layer that wraps the preprocessing necessary for this kernel, adapted from a benchmark script posted by @mobicham fromCUDA-mode
Discord discussions.GPU details:
NOTE
This implementation requires
triton >= 3.0.0
.Running tests / benchmarks requires installation of
hqq
:TODO
pytest
torch.compile
benchmarkinghqq
)