
Add support for mixed 4-bit/8-bit data types GEMM #1413

Merged
merged 2 commits into NVIDIA:main from add-mixed-4bit-8bit-gemm on Aug 30, 2024

Conversation

@alexsamardzic
Contributor

No description provided.

@alexsamardzic
Contributor Author

More to come here: support for U4, support for the generator in the CUTLASS library, etc. Still, I'm opening this PR to solicit feedback on the S8/S4 and S4/S8 GEMMs that are now available; in particular, I'm interested in any suggestions for a faster approach to the S4->S8 conversion.

@alexsamardzic
Contributor Author

Added more tests.


*/

#include <iostream>

leftover? (just lurking around)

Contributor Author

Yep. But it's there in pretty much all the tests in test/unit/gemm/device; it seems we've just been copying it around... It would be best to remove all of them in a separate PR.

@alexsamardzic
Contributor Author

Added generator support for S8/S4 and S4/S8.


AFAIK, implementing generator support for a given operation is not specifically documented, so I want to clarify the steps I've taken here. Basically, I copied the code from the GenerateSM80_TensorOp_16832_TN method into GenerateSM80_TensorOp_16832_TN_mixed_input_upcast_(a|b), and then made some changes:

  • Obviously, I've changed the math_instructions assignments according to the data types actually used for the mixed input case.
  • I'm not sure where smem_usage = 164 in GenerateSM80_TensorOp_16832_TN comes from, and the variable is not used further anyway, so I skipped it in the new methods.
  • I've used two sets of alignment constraints. The alignments for operands A and B are the same, but operand C (and thus the result too) could be either 32-bit or 8-bit. The code at the end of the mixed input methods, within the last if statement, handles the latter case, and the alignments are changed there accordingly. (Note that for GenerateSM80_TensorOp_16816_mixed_input_upcast_(a|b) there are snippets of code at the end of the methods doing a similar thing, but they're slightly different from each other, and also from what I did here.)
  • The tile_descriptions were initially copied from GenerateSM80_TensorOp_16832_TN; I then made sure that all the relevant kernels would be compiled (by adding CUTLASS_LIBRARY_KERNELS="*i16832gemm*" to the CMake command line), and removed tiles that failed to compile.

I did the verification as @manishucsd suggested here: as mentioned above, I did the build with all the relevant kernels included, and then verified that cutlass_profiler runs all the tile variations specified in GenerateSM80_TensorOp_16832_TN_mixed_input_upcast_(a|b). Note that the profiler reports Disposition: Incorrect for all the kernels with 8-bit output; I suppose it's related to saturation - I'm not sure whether I should somehow apply saturation for this combination of input data types?


Overall, this PR now contains everything that I intended to do for S4/S8 and S8/S4 GEMM, and it's ready for review. It has grown somewhat large, so I'd suggest having it reviewed and eventually merged; I can then add U4/U8 and U8/U4, and maybe U4/S8 and S8/U4, support in follow-up PR(s).

@andrewor14

Hi @alexsamardzic, thanks for working on this. Just wanted to clarify, will this kernel support int4 grouped per channel weight quantization + int8 per token dynamic activation quantization?

@alexsamardzic
Contributor Author

Hi @alexsamardzic, thanks for working on this. Just wanted to clarify, will this kernel support int4 grouped per channel weight quantization + int8 per token dynamic activation quantization?

This kernel is just an int4/int8 GEMM, producing an int32 (or int8) result. Quantization is not supported by CUTLASS directly, but could be implemented using an EVT epilogue. In particular, I'm trying to get this feature into CUTLASS mainly in order to have this particular operation supported in PyTorch, with its use alongside quantization as the primary motivation.

@alexsamardzic
Contributor Author

@manishucsd, @hwu36: Would it be possible for someone to review this PR (and perhaps #1350 too)? These should not be controversial and are needed by PyTorch, and once this one is in I'd like to proceed with another PR adding the other 4-bit/8-bit integer combinations that make sense.

@hwu36
Collaborator

hwu36 commented Apr 18, 2024

working on it now.

@Hongbosherlock

Hi @alexsamardzic, thanks for working on this. Just wanted to clarify, will this kernel support int4 grouped per channel weight quantization + int8 per token dynamic activation quantization?

This kernel is just an int4/int8 GEMM, producing an int32 (or int8) result. Quantization is not supported by CUTLASS directly, but could be implemented using an EVT epilogue. In particular, I'm trying to get this feature into CUTLASS mainly in order to have this particular operation supported in PyTorch, with its use alongside quantization as the primary motivation.

Great job! How can I integrate this PR with PyTorch? Are there any example codes available? @alexsamardzic

@alexsamardzic
Contributor Author

How can I integrate this PR with PyTorch? Are there any example codes available? @alexsamardzic

The primary motivation for this PR is to have this combination of operands supported by PyTorch, so the integration should be coming soon.

@Hongbosherlock

How can I integrate this PR with PyTorch? Are there any example codes available? @alexsamardzic

The primary motivation for this PR is to have this combination of operands supported by PyTorch, so the integration should be coming soon.

I'm a beginner with CUTLASS, and I have no idea how to use my own constructed s4/s8 data to run this GEMM.
Could you please provide example code for testing this s4/s8 GEMM, like the official example here:
https://github.com/NVIDIA/cutlass/blob/main/examples/55_hopper_mixed_dtype_gemm/README.md

@alexsamardzic
Contributor Author

I'm a beginner with CUTLASS, and I have no idea how to use my own constructed s4/s8 data to run this GEMM. Could you please provide example code for testing this s4/s8 GEMM, like the official example here: https://github.com/NVIDIA/cutlass/blob/main/examples/55_hopper_mixed_dtype_gemm/README.md

These changes are not for the Hopper, but for the Ampere architecture. The code to run an s4/s8 GEMM would be the same as for any other GEMM, for example s8/s8, except that when the GEMM template is instantiated, the data types and other arguments should be specified accordingly. For some examples of this, see the using Gemm = cutlass::gemm::device::GemmUniversal... template instantiations in the test cases added by this PR into the test/unit/gemm/device directory. As far as your data is concerned, s4 data should be provided as two successive values packed into a single byte, and that's all.
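
For illustration, here is a minimal sketch of such an instantiation, modeled on the test cases in this PR; the tile shapes are taken from the generator list (256x128x64, 3 stages), but other combinations from that list would work too, and the epilogue and alignment choices here are one valid option rather than the only one:

#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/epilogue/thread/linear_combination.h"

// Sketch of an Ampere s4 x s8 GEMM with s32 accumulation and s32 output.
using Gemm = cutlass::gemm::device::GemmUniversal<
    cutlass::int4b_t, cutlass::layout::RowMajor,    // A: s4, row-major
    int8_t, cutlass::layout::ColumnMajor,           // B: s8, column-major
    int32_t, cutlass::layout::RowMajor,             // C/D: s32
    int32_t,                                        // accumulator: s32
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<256, 128, 64>,         // threadblock tile
    cutlass::gemm::GemmShape<64, 64, 64>,           // warp tile
    cutlass::gemm::GemmShape<16, 8, 32>,            // i16832 MMA instruction
    cutlass::epilogue::thread::LinearCombination<
        int32_t, 4, int32_t, int32_t>,              // 128-bit stores of s32
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    3,                                              // pipeline stages
    32,                                             // alignment A: 128 bits of s4
    16,                                             // alignment B: 128 bits of s8
    cutlass::arch::OpMultiplyAddMixedInputUpcast>;  // in-kernel s4 -> s8 upcast

Setting up Gemm::Arguments and launching then follow the usual GemmUniversal pattern from those tests.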

@alexsamardzic
Contributor Author

On a quick look, your strides may be wrong.

@zkf331

zkf331 commented May 13, 2024

On a quick look, your strides may be wrong.

Thank you for your prompt reply. I don't know much about this parameter, and I can't find many references. Could you give me some more details? Thank you very much.

@Hongbosherlock

I'm a beginner with CUTLASS, and I have no idea how to use my own constructed s4/s8 data to run this GEMM. Could you please provide example code for testing this s4/s8 GEMM, like the official example here: https://github.com/NVIDIA/cutlass/blob/main/examples/55_hopper_mixed_dtype_gemm/README.md

These changes are not for the Hopper, but for the Ampere architecture. The code to run an s4/s8 GEMM would be the same as for any other GEMM, for example s8/s8, except that when the GEMM template is instantiated, the data types and other arguments should be specified accordingly. For some examples of this, see the using Gemm = cutlass::gemm::device::GemmUniversal... template instantiations in the test cases added by this PR into the test/unit/gemm/device directory. As far as your data is concerned, s4 data should be provided as two successive values packed into a single byte, and that's all.

I have two s4 values packed in a single byte (uint8). Do I need to manually unpack the uint8 data to get the s4 data before the GEMM?

@alexsamardzic
Contributor Author

alexsamardzic commented May 16, 2024

I have two s4 values packed in a single byte (uint8). Do I need to manually unpack the uint8 data to get the s4 data before the GEMM?

No, s4 values should be packed, two values per byte.
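
If it helps, here is a scalar sketch of that packing (pack_s4x2 is a hypothetical helper of mine, not a CUTLASS API; it assumes the low-nibble-first element ordering that CUTLASS uses for sub-byte arrays):

#include <cstdint>

// Pack two signed 4-bit values (each in [-8, 7]) into one byte:
// element 0 goes into the low nibble, element 1 into the high nibble.
inline uint8_t pack_s4x2(int lo, int hi) {
  return static_cast<uint8_t>((lo & 0xF) | ((hi & 0xF) << 4));
}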

@Hongbosherlock

Hongbosherlock commented May 23, 2024

I have two s4 values packed in a single byte (uint8). Do I need to manually unpack the uint8 data to get the s4 data before the GEMM?

No, s4 values should be packed, two values per byte.

Thanks for your help! I can get the correct result now, but I have another question:
Assuming that A is int8 and (M, K), B is int4 and (K, N), after the GEMM C = A·B, C will be (M, N). Now, I have another matrix E, which is fp32 and also (M, N). I want to perform the element-wise multiplication E * C. Can I complete this element-wise multiplication within this s4/s8 GEMM operation, for example by passing matrix E to Arguments? I am not sure how to do this.
Maybe here is an example of what I want to do: https://github.com/NVIDIA/TensorRT-LLM/blob/5d8ca2faf74c494f220c8f71130340b513eea9a9/cpp/tensorrt_llm/kernels/cutlass_kernels/int8_gemm/int8_gemm_template.h#L131

@alexsamardzic
Contributor Author

Assuming that A is int8 and (M, K), B is int4 and (K, N), after the GEMM C = A·B, C will be (M, N). Now, I have another matrix E, which is fp32 and also (M, N). I want to perform the element-wise multiplication E * C. Can I complete this element-wise multiplication within this s4/s8 GEMM operation, for example by passing matrix E to Arguments?

If matrix E is really MxN (i.e. not broadcast), it doesn't seem that the code you linked is doing this exact operation. I'd say the simplest way to achieve this would be through EVT epilogues; these exist exactly for the purpose of fusing matrix multiplications with arbitrary operations. For Ampere, there is the examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk_broadcast.cu example demonstrating how to use EVT epilogues; you'd have to remove everything related to the Bias/C1 matrices in this example, use C2 as your matrix E, and then replace cutlass::plus with cutlass::multiplies in using Compute2 = ... (also, you should take care that all of the data types in the template instantiations are correctly specified).

@Hongbosherlock

Assuming that A is int8 and (M, K), B is int4 and (K, N), after the GEMM C = A·B, C will be (M, N). Now, I have another matrix E, which is fp32 and also (M, N). I want to perform the element-wise multiplication E * C. Can I complete this element-wise multiplication within this s4/s8 GEMM operation, for example by passing matrix E to Arguments?

If matrix E is really MxN (i.e. not broadcast), it doesn't seem that the code you linked is doing this exact operation. I'd say the simplest way to achieve this would be through EVT epilogues; these exist exactly for the purpose of fusing matrix multiplications with arbitrary operations. For Ampere, there is the examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk_broadcast.cu example demonstrating how to use EVT epilogues; you'd have to remove everything related to the Bias/C1 matrices in this example, use C2 as your matrix E, and then replace cutlass::plus with cutlass::multiplies in using Compute2 = ... (also, you should take care that all of the data types in the template instantiations are correctly specified).

Thanks, I’m trying this, but it’s not going well currently.
To make it clearer, what I want to do is exactly the following:

    // inputs
    //     A           [M, K]    int8
    //     B           [N, K]    int4
    //     alphaCol    [M, 1]    fp32
    //     alphaRow    [1, N]    fp32
    // outputs
    //     mat [M, N]            fp32

That is: (alphaCol x alphaRow) * (A x B)
I think here is an s8/s8 example (A and B are both int8): https://github.com/NVIDIA/TensorRT-LLM/blob/5d8ca2faf74c494f220c8f71130340b513eea9a9/cpp/tensorrt_llm/kernels/cutlass_kernels/int8_gemm/int8_gemm_template.h#L131, which also uses EVT, and the inputs are passed from here.
I wonder if I could use the same EVT code and using Gemm = cutlass::gemm::device::GemmUniversalBaseCompat<GemmKernel> with this s4/s8 GEMM.
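
For reference, here is a naive host loop of the exact result I want the fused kernel to produce (written with B's int4 values widened to int8 for clarity, and assuming the [N, K] storage from the comment above):

#include <cstdint>

// Naive reference for mat = (alphaCol x alphaRow) * (A x B):
// mat[m][n] = alphaCol[m] * alphaRow[n] * sum_k A[m][k] * B[n][k]
void reference(int M, int N, int K,
               const int8_t* A,        // [M, K], row-major, int8
               const int8_t* B,        // [N, K], int4 values widened to int8
               const float* alphaCol,  // [M, 1]
               const float* alphaRow,  // [1, N]
               float* mat) {           // [M, N], row-major
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      int32_t acc = 0;
      for (int k = 0; k < K; ++k)
        acc += int32_t(A[m * K + k]) * int32_t(B[n * K + k]);
      mat[m * N + n] = alphaCol[m] * alphaRow[n] * float(acc);
    }
  }
}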

@Hongbosherlock

This comment was marked as duplicate.

@alexsamardzic
Contributor Author

This code uses PyTorch, can you post a reproducible example that uses CUTLASS only?

@Hongbosherlock

Hongbosherlock commented Jun 5, 2024

This code uses PyTorch, can you post a reproducible example that uses CUTLASS only?

Hi @alexsamardzic , I have pushed my code here: https://github.com/Hongbosherlock/cutlass/blob/add-mixed-4bit-8bit-gemm/examples/61_s4s8_gemm/s4s8_gemm.cu#L114

You can add this example, then compile and run it:

cutlass/build$ cmake .. -DCUTLASS_NVCC_ARCHS=80
cutlass/build$ make 61_s4s8_gemm
cutlass/build$ ./examples/61_s4s8_gemm/61_s4s8_gemm

When ElementB = int8_t, it seems OK, and you can get the expected result.

But when ElementB = cutlass::int4b_t, lots of compilation errors occur.

I am really at a loss and would greatly appreciate any guidance or help you can provide. Thank you very much in advance for your time and assistance!

@alexsamardzic
Contributor Author

I have pushed my code here: https://github.com/Hongbosherlock/cutlass/blob/add-mixed-4bit-8bit-gemm/examples/61_s4s8_gemm/s4s8_gemm.cu

Replace cutlass::arch::OpMultiplyAddSaturate with cutlass::arch::OpMultiplyAddMixedInputUpcast.

@Hongbosherlock

I have pushed my code here: https://github.com/Hongbosherlock/cutlass/blob/add-mixed-4bit-8bit-gemm/examples/61_s4s8_gemm/s4s8_gemm.cu

Replace cutlass::arch::OpMultiplyAddSaturate with cutlass::arch::OpMultiplyAddMixedInputUpcast.

Works for me. Thanks!

@alexsamardzic
Contributor Author

Works for me. Thanks!

Good. Remember that CUTLASS is a heavily templated library, but only a small number of all the possible template argument combinations actually work together - so one cannot just paste together pieces of code from different sources and expect it to work.

@Hongbosherlock

Hongbosherlock commented Jun 7, 2024

Works for me. Thanks!

Good. Remember that CUTLASS is a heavily templated library, but only a small number of all the possible template argument combinations actually work together - so one cannot just paste together pieces of code from different sources and expect it to work.

Yeah, that was a mistake; OpMultiplyAddMixedInputUpcast does appear in the test folder. I think I was a bit disoriented - there are too many arguments.
I think CUTLASS is somewhat challenging for beginners. Do you have any recommended learning paths?

@alexsamardzic
Contributor Author

I think CUTLASS is somewhat challenging for beginners. Do you have any recommended learning paths?

Going through relevant examples, as well as unit tests, in the CUTLASS source tree is probably still the best way to start.

@Hongbosherlock

Hongbosherlock commented Jun 19, 2024

Hi @alexsamardzic, in fact I am working on a GEMM + de-quantization fusion kernel for W4A8 based on your PR, similar to the W8A8 kernel for PyTorch here (GEMM + de-quantization), which also uses EVT.

I have completed most of the work and used EVT to finish the de-quantization.
For simplicity, I started with W8A8 to verify that my EVT is correct, and then I will modify it for W4A8.
Here is the W8A8 GEMM+DQ code. The code compiles, but the results are incorrect compared to the Python simulation results.
I tried printing some elements, and they are very large values like 18766625, which might indicate an address access error.

Can you please have a look what the possible issues might be?
Thanks!

@alexsamardzic
Contributor Author

Can you please have a look what the possible issues might be? Thanks!

I'm sorry, but this has nothing to do with this particular PR, and unfortunately I don't have the cycles to help you with this. You need to be sure that you understand building EVT epilogues, as well as specifying the corresponding arguments. In your position, I would start with a simple epilogue that just stores values from the accumulator into the output tensor. If the results match the expected ones, I would then add the nodes that do the multiplication into the EVT epilogue, one by one, and keep comparing the results with the expected ones. When there is a mismatch, you'll know where to look for the fix.

@Hongbosherlock

Hi @alexsamardzic, thanks for your help, I have got it done.

When profiling a single GEMM, do you think the performance of s4/s8 will be better than that of s8/s8?
In my test, s8/s8 (int8 GEMM) is faster.

@alexsamardzic
Contributor Author

When profiling a single GEMM, do you think the performance of s4/s8 will be better than that of s8/s8? In my test, s8/s8 (int8 GEMM) is faster.

The GEMM actually performed is the same: S8/S8 in both cases. The bandwidth used in transfers between global and shared memory, and then between shared memory and registers, is smaller in the S4/S8 case, but on the other hand there are additional calculations for the S4->S8 conversion. Thus, very broadly: S4/S8 and S8/S8 should be in the same ballpark performance-wise, but it is not unusual for one to be faster than the other for specific input shapes and selected tile sizes.
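
To put rough numbers on the bandwidth side (back-of-envelope, not a measurement): a 128x64 K-slice of the 4-bit operand is 128 * 64 = 8192 elements, i.e. 4 KiB packed at S4 versus 8 KiB at S8 - half the memory traffic for that operand, paid for by the extra register-level conversion instructions.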

Contributor

@manishucsd left a comment


Thank you for working on this. Apologies for the delayed review. LGTM.
Over to NVIDIA/CUTLASS (cc: @hwu36) for merging this.

Comment on lines +2868 to +2901
TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc),
TileDescription([256, 64, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc),
TileDescription([ 32, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc),
TileDescription([128, 128, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 128, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc),
TileDescription([256, 64, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc),
TileDescription([256, 32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc),
TileDescription([ 32, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc),
TileDescription([128, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([ 64, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc),
Contributor

How did you come up with this TileDescription list for S8 x S4? I guess you carried these over from S8 x S8. Please make sure all of these pass verification. You can follow steps similar to here to instantiate all the tile shapes listed here, by using -DCUTLASS_LIBRARY_KERNELS="s8_s4,s4_s8". By default the build process only instantiates the 128x128 tile shape.

Contributor

Can you please run the verification and profiling on --m=3456 --n=4096 --k=2048 on an A100?

Please compile using -DCUTLASS_LIBRARY_KERNELS="s8_s4,s4_s8,s4,s8" to also have s4 x s4 and s8 x s8 kernels in the runs.

Contributor Author

The tile selection is described in a comment above; also, as mentioned in this comment, I did the verification. I will repeat the verification procedure, together with profiling, and report the outcome here.

Contributor Author

There was a build issue after rebasing on the latest main: basically, OpMultiplyAddSaturate for MmaTensorOpPolicy in the specialization of struct DefaultMmaTensorOp (in include/cutlass/gemm/warp/default_mma_tensor_op_sm80.h) seems to be obligatory now, as the build fails if OpMultiplyAdd is used. The branch is updated accordingly.

I've configured the build with the -DCUTLASS_NVCC_ARCHS=80 -DCUTLASS_LIBRARY_KERNELS="s8_s4,s4_s8,s4,s8" CMake options, and then verified that the cutlass_test_unit_gemm_device_mixed_input_tensorop_sm80 unit test passes. Then, I did the profiler runs as follows:

./build/tools/profiler/cutlass_profiler --operation=gemm --m=3456 --n=4096 --k=2048 --A=s8:row --B=s8:column >& s8_s8.txt
./build/tools/profiler/cutlass_profiler --operation=gemm --m=3456 --n=4096 --k=2048 --A=s4:row --B=s8:column >& s4_s8.txt
./build/tools/profiler/cutlass_profiler --operation=gemm --m=3456 --n=4096 --k=2048 --A=s8:row --B=s4:column >& s8_s4.txt

The corresponding profiler outputs are here:
s8_s8.txt
s4_s8.txt
s8_s4.txt

The disposition values for the mixed data type cases with s8 accumulator are still incorrect. Also, the timings are somewhat slower than for the corresponding s8xs8 cases (with the same configurations: tile sizes etc.).

Contributor

Thank you for running and sharing these results.

The accumulator for all of these runs should be S32, as shown at the bottom of the output in csv format with accum type = S32. The Incorrect disposition with mixed-input happens only for S8 output, i.e., when the accumulators are S32 but the output is downcast to S8.

We do not see incorrect results for S8xS8 with S32 accumulators and S8 output. Can you pick one row of an incorrect run from

(elementD/elementC type) <= (elementA type) x (elementB type) + (accum type)
S8  <= S8  x S4  + S32

and compare the same kernel configuration against

S8 <= S8 x S8 + S32

to find where the difference is?

I believe it has to do with the initialization of the operands during profiling, or with the S32-to-S8 quantization inside the kernel epilogue.

Also, you can just upload the csv that can be produced by adding --output=filename.csv to the profiler runs.

Contributor Author

Here is what I found so far regarding the incorrect cases:

First, I made the following change in the input-generating code, so that the profiler generates the same inputs for the S4xS8 and S8xS8 cases:

diff --git a/tools/profiler/src/device_context.cu b/tools/profiler/src/device_context.cu
index 2cbfa5d2..7b488fe8 100644
--- a/tools/profiler/src/device_context.cu
+++ b/tools/profiler/src/device_context.cu
@@ -105,7 +105,7 @@ DeviceAllocation *DeviceContext::allocate_tensor(
           data_distribution.set_uniform(-1, 1, 0);
           break;
         case library::NumericTypeID::kS4:
-          data_distribution.set_uniform(-2, 2, 0);
+          data_distribution.set_uniform(-3, 3, 0);
           break;
         case library::NumericTypeID::kU2:
           data_distribution.set_uniform(0, 2, 0);

I used the following profiler runs to compare the S4xS8 and S8xS8 cases (BTW, I found that a smaller input shape that still reproduces the problem is --m=32 --n=64 --k=512):

cutlass_profiler --operation=gemm --gemm_kind=universal --m=3456 --n=4096 --k=2048 --A=s8:row --B=s8:column --C=s8:column --D=s8:column --accum=s32 --cta_m=256 --cta_n=128 --cta_k=64 --stages=3 --save-workspace=always
cutlass_profiler --operation=gemm --gemm_kind=universal --m=3456 --n=4096 --k=2048 --A=s4:row --B=s8:column --C=s8:column --D=s8:column --accum=s32 --cta_m=256 --cta_n=128 --cta_k=64 --stages=3 --save-workspace=always

By comparing the saved .mat files, I verified that the input matrices A, B and C are the same, and also that the output matrices D are the same. What actually differ are the Reference matrices, which means that the reference results calculated for the S4xS8 case are wrong. If I understood it correctly, cuBLAS is used for the reference calculations, so I'll check what's going on there...

Contributor
@manishucsd commented Aug 17, 2024

I am not sure that cuBLAS is called for the reference check here. The output should show which references are called; you can use --verification-providers=cublas,host,device to run them all. Is there a host reference for this that you should check, under /tools/library/src/reference?

Contributor Author

Indeed - the device provider is actually used for the reference check here. I posted an update with a fix in the reference calculations, so for most of the cases with S8 output, cutlass_profiler reports success now. However, there are still a couple of cases where incorrect is reported; I'm looking into this...

Contributor Author

Pushed another update - the problem with the remaining incorrect cases was that I hadn't copied the C operand alignment update from the S8xS8 case in the generator code. Everything is reported as passed by the profiler now; the output files are attached below. I believe this one should be ready for merging now.

s8_s8.gemm.csv
s4_s8.gemm.csv
s8_s4.gemm.csv

Contributor

@alexsamardzic thank you for digging through it. LGTM!

@hwu36, @thakkarV, @IonThruster, can you please help merge it?

Contributor Author
@alexsamardzic left a comment

Thanks for the fix!

@hwu36 hwu36 merged commit e1976da into NVIDIA:main Aug 30, 2024
@hwu36
Collaborator

hwu36 commented Aug 30, 2024

while we are at this, i think we can improve the int4->int8 upcasting. right now we use 11 instructions to upcast 8 elements - quite a lot. we used a look-up-table method to do int->fp8 upcasting (https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/numeric_conversion.h#L2983-L3027), and i think we may be able to use the same here.
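
to make the idea concrete, here is a scalar model of the table such a conversion would index (the vectorized version would use prmt byte selects like the fp8 path does; this is an illustration, not the actual patch):

#include <cstdint>

// A 4-bit nibble indexes a 16-entry table of sign-extended int8 values.
static const int8_t kS4ToS8[16] = {
    0, 1, 2, 3, 4, 5, 6, 7,           // nibbles 0x0..0x7 -> 0..7
    -8, -7, -6, -5, -4, -3, -2, -1};  // nibbles 0x8..0xF -> -8..-1

inline int8_t upcast_s4(uint8_t nibble) { return kS4ToS8[nibble & 0xF]; }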

@alexsamardzic, do you want to give it a try? i am setting things up now so it won't take me months to merge your code.

cc @rhenry-nv

@alexsamardzic alexsamardzic deleted the add-mixed-4bit-8bit-gemm branch August 30, 2024 12:05
@alexsamardzic
Contributor Author

@alexsamardzic, do you want to give it a try? i am setting things up now so it won't take me months to merge your code.

Sure. Below is a patch that implements the look-up table method for int4->int8 (pretty much the same as the existing int4->fp8 code), along with the profiler outputs for the original and patched versions. It seems that the look-up table method is slower.

I ran the profiler in both cases as follows:

cutlass_profiler --operation=Gemm --m=1024 --n=1024 --k=1024 --A=s4:row --B=s8:column

and here are the files mentioned:
patch.txt
original.csv
patched.csv

I was least happy about the conversion code in this PR, but this is the best I was able to come up with...

ucassjy pushed a commit to ucassjy/cutlass that referenced this pull request Sep 4, 2024
* Add support for mixed 4-bit/8-bit data types GEMM

* fix ( and )

---------

Co-authored-by: Aleksandar Samardžić <asamardzic@matf.bg.ac.rs>
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>