Add support for mixed 4-bit/8-bit data types GEMM #1413
Conversation
More to come here: support for
Force-pushed from 0653dc3 to 6cf7e62
Added more tests.
Force-pushed from 6cf7e62 to c51824c
#include <iostream>
leftover? (just lurking around)
Yep. But it's there in pretty much all the tests in test/unit/gemm/device; it seems we've just been copying it around. It would be best to remove all of them in a separate PR.
Force-pushed from c51824c to af18f2c
Added generator support for S8/S4 and S4/S8. AFAIK, implementing generator support for a given operation is not specifically documented, so I want to clarify the steps I've taken here. Basically, I've copied code from
I did the verification as @manishucsd suggested here. As mentioned above, I did the build with all the relevant kernels included, and then I verified that. Overall, this PR now contains everything that I intended to do for
Force-pushed from af18f2c to 7c91218
Hi @alexsamardzic, thanks for working on this. Just wanted to clarify: will this kernel support int4 grouped per-channel weight quantization + int8 per-token dynamic activation quantization?
This kernel is just an int4/int8 GEMM, producing an int32 (or int8) result. Quantization is not supported by CUTLASS directly, but could be implemented using an EVT epilogue. In particular, I'm trying to get this feature into CUTLASS mainly in order to have this operation supported in PyTorch, with using it together with quantization as the primary motivator.
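To make the EVT idea a bit more concrete, here is a minimal NumPy sketch (illustrative only, not CUTLASS code) of the dequantization math such an epilogue would apply to the int32 accumulator; the per-token and per-channel scale names and shapes are assumptions for the example:

```python
import numpy as np

rng = np.random.default_rng(0)
M, N, K = 4, 8, 16

# int8 activations (per-token dynamic quantization) and int4-range weights
# (per-channel quantization); the scales are hypothetical placeholders.
A = rng.integers(-128, 128, size=(M, K), dtype=np.int8)
B = rng.integers(-8, 8, size=(K, N), dtype=np.int8)   # values within int4 range
a_scale = rng.random(M).astype(np.float32)            # one scale per token (row of A)
b_scale = rng.random(N).astype(np.float32)            # one scale per channel (column of B)

# The GEMM itself accumulates in int32, which is what the kernel produces.
acc = A.astype(np.int32) @ B.astype(np.int32)

# Dequantization is then a purely elementwise epilogue on the accumulator:
# D[m, n] = acc[m, n] * a_scale[m] * b_scale[n]
D = acc.astype(np.float32) * a_scale[:, None] * b_scale[None, :]
```

Because the scales factor out of the inner product, applying them in the epilogue gives the same result as dequantizing the operands first, which is what makes this kind of fusion legal.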
Force-pushed from 492a0c2 to 1dfe9c0
@manishucsd, @hwu36: Would it be possible for someone to review this PR (and eventually #1350 too)? These should not be controversial, they are needed by PyTorch, and for this one I'd like to proceed with another PR to add the other 4-bit/8-bit integer combinations that make sense.
Working on it now.
Great job! How can I integrate this PR with PyTorch? Are there any example codes available? @alexsamardzic
The primary motivation for this PR is to have this combination of operands supported by PyTorch, so the integration should be coming soon.
I'm a beginner with CUTLASS, and I have no idea how to use my own constructed s4/s8 data to run this GEMM.
These changes are not for Hopper, but for the Ampere architecture. The code to run an s4/s8 GEMM is the same as for any other GEMM, for example s8/s8, except that when the GEMM template is instantiated, the data types and other arguments should be specified accordingly. For some examples of this, see
On a quick look, your strides may be wrong.
Thank you for your prompt reply. I don't know much about this parameter, and I can't find many references. Could you give me some more details? Thank you very much.
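Not an answer from the thread, just an illustration: the stride the profiler layouts refer to is the leading dimension of the matrix, i.e. the distance in elements between consecutive rows (row-major) or consecutive columns (column-major). A small NumPy sketch of the convention:

```python
import numpy as np

M, K = 3, 5
buf = np.arange(M * K, dtype=np.int8)   # the raw linear buffer

# Row-major ("row" layout): elements of a row are contiguous, so the
# stride (leading dimension) is the row length K, and (i, j) -> i * K + j.
A_row = buf.reshape(M, K)
assert A_row[1, 0] == buf[1 * K + 0]

# Column-major ("column" layout): elements of a column are contiguous, so
# the stride is the column length M, and (i, j) -> j * M + i.
A_col = buf.reshape(K, M).T             # same buffer viewed column-major
assert A_col[0, 1] == buf[1 * M + 0]
```

Passing a stride that does not match the layout (e.g. K where M is expected) silently reinterprets the buffer, which typically shows up as wrong results rather than an error.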
I have two s4 values packed in a single byte (uint8). Do I need to manually unpack the uint8 data to get the s4 data before the GEMM?
No, s4 values should be packed, two values per byte.
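A NumPy sketch of that packing, assuming low nibble = even index (the actual nibble order a given kernel expects should be checked against the CUTLASS sources):

```python
import numpy as np

def pack_int4(values):
    """Pack an even-length sequence of ints in [-8, 7] into bytes, two per byte."""
    v = np.asarray(values, dtype=np.int8)
    lo = (v[0::2] & 0xF).astype(np.uint8)   # even indices -> low nibble
    hi = (v[1::2] & 0xF).astype(np.uint8)   # odd indices  -> high nibble
    return (hi << 4) | lo

def unpack_int4(packed):
    """Inverse of pack_int4: sign-extend each 4-bit field back to int8."""
    p = np.asarray(packed, dtype=np.uint8)
    lo = (p & 0xF).astype(np.int8)
    hi = (p >> 4).astype(np.int8)
    # A nibble >= 8 encodes a negative value in two's complement.
    lo = np.where(lo >= 8, lo - 16, lo)
    hi = np.where(hi >= 8, hi - 16, hi)
    out = np.empty(2 * p.size, dtype=np.int8)
    out[0::2], out[1::2] = lo, hi
    return out

vals = np.array([-8, 7, -1, 0, 3, -3], dtype=np.int8)
assert (unpack_int4(pack_int4(vals)) == vals).all()
```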
Thanks for your help! I can get correct results now, but I have another question:
If matrix
Thanks, I’m trying this, but it’s not going well currently.
That is:
This code uses PyTorch; can you post a reproducible example that uses CUTLASS only?
Hi @alexsamardzic, I have pushed my code here: https://github.com/Hongbosherlock/cutlass/blob/add-mixed-4bit-8bit-gemm/examples/61_s4s8_gemm/s4s8_gemm.cu#L114 You can add this example, then compile and run it:
when But when I am really at a loss and would greatly appreciate any guidance or help you can provide. Thank you very much in advance for your time and assistance!
Replace
Works for me. Thanks!
Good. Remember that CUTLASS is a heavily templated library, but only a small number of all the possible template argument combinations actually work together, so one cannot just paste pieces of code from different sources and expect them to work.
Yeah, that was a mistake.
Going through relevant examples, as well as unit tests, in the CUTLASS source tree is probably still the best way to start. |
Hi @alexsamardzic, in fact I am working on a GEMM + de-quantization fusion kernel for W4A8 based on your PR, similar to the W8A8 kernel for PyTorch here (GEMM + de-quantization), which also used EVT. I have completed most of the work and used EVT to finish the de-quantization. Can you please have a look at what the possible issues might be?
I'm sorry, but this has nothing to do with this particular PR, and unfortunately I don't have cycles to help you with this. You need to be sure that you understand building EVT epilogues, as well as specifying the corresponding arguments. In your position I would start with a simple epilogue that just stores values from the accumulator into the output tensor. If the results match the expected ones, I would then add nodes into the EVT epilogue that do the multiplication, one by one, and would keep comparing results with the expected ones. When there is a mismatch, you will know where to look for the fix.
Hi @alexsamardzic, thanks for your help, I have made it work. When profiling a single GEMM, do you think the performance of
The GEMM actually performed is the same:
Force-pushed from 1dfe9c0 to 95a7d30
Thank you for working on this. Apologies for a delayed review. LGTM.
Over to NVIDIA/CUTLASS (cc: @hwu36) for merging this.
TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc),
TileDescription([256, 64, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc),
TileDescription([ 32, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc),
TileDescription([128, 128, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 128, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc),
TileDescription([256, 64, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc),
TileDescription([256, 32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc),
TileDescription([ 32, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc),
TileDescription([128, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc),
How did you come up with this TileDescription list for S8 x S4? I guess you carried these over from S8 x S8. Please make sure all of these pass verification. You can follow steps similar to here to instantiate all the tile shapes listed here by using -DCUTLASS_LIBRARY_KERNELS="s8_s4,s4_s8". By default the build process only instantiates the 128x128 tile shape.
Can you please run the verification and profiling with --m=3456 --n=4096 --k=2048 on an A100? Please compile using -DCUTLASS_LIBRARY_KERNELS="s8_s4,s4_s8,s4,s8" to also have the s4 x s4 and s8 x s8 kernels in the runs.
The tile selection is described in a comment above; also, as mentioned in this comment, I did the verification. I will repeat the verification procedure, together with profiling, and report the outcome here.
There was a build issue after rebasing on the latest main: basically, OpMultiplyAddSaturate for MmaTensorOpPolicy in the specialization of struct DefaultMmaTensorOp (in include/cutlass/gemm/warp/default_mma_tensor_op_sm80.h) seems to be obligatory now, as the build fails if OpMultiplyAdd is used. The branch is updated accordingly.
I've configured the build using the -DCUTLASS_NVCC_ARCHS=80 -DCUTLASS_LIBRARY_KERNELS="s8_s4,s4_s8,s4,s8" CMake options, and then verified that the cutlass_test_unit_gemm_device_mixed_input_tensorop_sm80 unit test passes. Then, I did profiler runs as follows:
./build/tools/profiler/cutlass_profiler --operation=gemm --m=3456 --n=4096 --k=2048 --A=s8:row --B=s8:column >& s8_s8.txt
./build/tools/profiler/cutlass_profiler --operation=gemm --m=3456 --n=4096 --k=2048 --A=s4:row --B=s8:column >& s4_s8.txt
./build/tools/profiler/cutlass_profiler --operation=gemm --m=3456 --n=4096 --k=2048 --A=s8:row --B=s4:column >& s8_s4.txt
The corresponding profiler outputs are here: s8_s8.txt, s4_s8.txt, s8_s4.txt.
The disposition values for the mixed data type cases with s8 accumulator are still incorrect. Also, the timings are somewhat slower than for the corresponding s8xs8 cases (with the same configurations: tile sizes etc.).
Thank you for running and sharing these results.
The accumulator for all of these runs should be S32, as shown at the bottom of the output in csv format with accum type = S32. The Incorrect disposition with mixed-input is happening only for S8 output, i.e., when the accumulators are S32 but the output is downcast to S8.
We do not see incorrect results for S8xS8 with S32 accumulators and S8 output. Can you pick one row of an incorrect run from
(elementD/elementC type) <= (elementA type) x (elementB type) + (accum type)
S8 <= S8 x S4 + S32
and compare the same kernel configuration against
S8 <= S8 x S8 + S32
to find where the difference is?
I believe it has to do with initialization of the operands during profiling, or with the S32-to-S8 quantization inside the kernel epilogue.
Also, you can just upload the csv that can be produced by adding --output=filename.csv to the profiler runs.
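For context on the downcast being discussed: the epilogue's S32-to-S8 conversion saturates rather than wraps, which a NumPy model (illustrative, not the CUTLASS implementation) captures as a clamp:

```python
import numpy as np

def downcast_s32_to_s8(acc):
    """Saturating downcast of an int32 accumulator to int8 output:
    out-of-range values clamp to [-128, 127] instead of wrapping."""
    return np.clip(acc, -128, 127).astype(np.int8)

acc = np.array([300, -5000, 17, -3], dtype=np.int32)
assert (downcast_s32_to_s8(acc) == np.array([127, -128, 17, -3], dtype=np.int8)).all()
```

A reference implementation that wraps (a plain cast) instead of clamping would disagree with the kernel exactly on out-of-range accumulators, which is one way such mismatches can arise.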
Here is what I found so far regarding the incorrect cases:
First, I made the following change in the input-generating code, in order to generate the same profiler inputs for the S4xS8 and S8xS8 cases:
diff --git a/tools/profiler/src/device_context.cu b/tools/profiler/src/device_context.cu
index 2cbfa5d2..7b488fe8 100644
--- a/tools/profiler/src/device_context.cu
+++ b/tools/profiler/src/device_context.cu
@@ -105,7 +105,7 @@ DeviceAllocation *DeviceContext::allocate_tensor(
data_distribution.set_uniform(-1, 1, 0);
break;
case library::NumericTypeID::kS4:
- data_distribution.set_uniform(-2, 2, 0);
+ data_distribution.set_uniform(-3, 3, 0);
break;
case library::NumericTypeID::kU2:
data_distribution.set_uniform(0, 2, 0);
I used the following profiler runs to compare the S4xS8 and S8xS8 cases (BTW, I found that a smaller input shape that still reproduces the problem is --m=32 --n=64 --k=512):
cutlass_profiler --operation=gemm --gemm_kind=universal --m=3456 --n=4096 --k=2048 --A=s8:row --B=s8:column --C=s8:column --D=s8:column --accum=s32 --cta_m=256 --cta_n=128 --cta_k=64 --stages=3 --save-workspace=always
cutlass_profiler --operation=gemm --gemm_kind=universal --m=3456 --n=4096 --k=2048 --A=s4:row --B=s8:column --C=s8:column --D=s8:column --accum=s32 --cta_m=256 --cta_n=128 --cta_k=64 --stages=3 --save-workspace=always
By comparing the saved .mat files, I verified that the input matrices A, B and C are the same, and also that the output matrices D are the same. What differ are actually the Reference matrices, which means that the reference results calculated for the S4xS8 case are wrong. If I understood it correctly, cuBLAS is used for the reference calculations, so I'll check what's going on there...
I am not sure that cuBLAS is called for the reference check here. The output should show which references are called. You can use --verification-providers=cublas,host,device to run them all. Is there a host reference for this? You should check in /tools/library/src/reference.
Indeed, the device provider is actually used for the reference check here. I posted an update with the fix in the reference calculations, so for most of the cases with S8 output, cutlass_profiler now reports success. However, there are still a couple of cases where incorrect is reported; I'm looking into this...
Pushed another update. The problem with the remaining incorrect cases was that I hadn't copied the C operand alignment update from the S8xS8 case in the generator code. Everything is now reported as passed by the profiler; the output files are attached below. I believe this one should be ready for merging now.
@alexsamardzic, thank you for digging through it. LGTM!
@hwu36, @thakkarV, @IonThruster, can you please help merge it?
Force-pushed from 95a7d30 to 1e5ed24
Force-pushed from 1e5ed24 to fcebe62
Force-pushed from fcebe62 to a30806a
Thanks for the fix!
While we are at this, I think we can improve the int4->int8 upcasting. We currently use 11 instructions to upcast 8 elements, which is quite a lot. We used a look-up-table method to do int->fp8 upcasting (https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/numeric_conversion.h#L2983-L3027); I think we may be able to use the same approach here. @alexsamardzic, do you want to give it a try? I am setting things up now so it won't take me months to merge your code. cc @rhenry-nv
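The semantics of the look-up-table approach can be sketched in NumPy (the table layout and nibble order here are assumptions for illustration; the actual CUTLASS converter keeps the table in registers and samples it with PRMT):

```python
import numpy as np

# Table indexed by the raw 4-bit pattern (0..15), holding the sign-extended value.
LUT = np.array([0, 1, 2, 3, 4, 5, 6, 7, -8, -7, -6, -5, -4, -3, -2, -1],
               dtype=np.int8)

def upcast_int4_to_int8(packed):
    """Convert bytes holding two int4 values each into int8, using one table
    gather per nibble instead of explicit shift/mask/sign-extend arithmetic."""
    p = np.asarray(packed, dtype=np.uint8)
    out = np.empty(2 * p.size, dtype=np.int8)
    out[0::2] = LUT[p & 0xF]   # low nibble (assumed to come first)
    out[1::2] = LUT[p >> 4]    # high nibble
    return out

# 0xF9: low nibble 9 encodes -7, high nibble 0xF encodes -1.
assert (upcast_int4_to_int8([0xF9]) == np.array([-7, -1], dtype=np.int8)).all()
```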
Sure. Below is a patch implementing the look-up table method for int4->int8 (pretty much the same as the existing int4->fp8 code), along with the profiler outputs for the original and patched versions. It seems that the look-up table method is slower. I ran the profiler in both cases as follows:
and here are the mentioned files: I was least happy about the conversion code in this PR, but this is the best I was able to come up with...
* Add support for mixed 4-bit/8-bit data types GEMM
* fix ( and )
---------
Co-authored-by: Aleksandar Samardžić <asamardzic@matf.bg.ac.rs>
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>