
Add couple configs into generator.py for mixed input MM #1350

Merged
merged 4 commits into NVIDIA:main from alexsamardzic:add-mixed-input-configs on Aug 16, 2024

Conversation

@alexsamardzic
Contributor

I'm adding (PR here) CUTLASS kernels as an auto-tune option for the PyTorch compiler, and it would be nice to have these additional configurations available. This is not urgent, and more changes like this may be desired later, so if it's OK to make changes of this kind, this PR could be kept open for a while and I'll make further additions to it as needed.

@manishucsd : Would it make sense for GenerateSM80_TensorOp_16816_mixed_input_upcast_a and GenerateSM80_TensorOp_16816_mixed_input_upcast_b to be symmetric w.r.t. math_instructions and tile_descriptions? I can make that change in this PR too.

@manishucsd
Contributor

to be symmetric w.r.t. math_instructions and tile_descriptions.

What do you mean by symmetric (the same)? The Tensor Core math_instruction shape for both upcast_a and upcast_b is 16816. The supported tile_descriptions (more precisely, tile shapes) may need to be different for upcast_a vs. upcast_b.

@alexsamardzic
Contributor Author

By symmetry, I meant the math_instructions lists within the given generator methods: I was thinking that if the GenerateSM80_SparseTensorOp_16832 method has, for example, the DataType.f16, DataType.f16, DataType.f32 combination listed there, then the upcast_a method should have DataType.s8, DataType.f16, DataType.f32 and DataType.u8, DataType.f16, DataType.f32, and the upcast_b method should have DataType.f16, DataType.s8, DataType.f32 and DataType.f16, DataType.u8, DataType.f32; and likewise for the other elements of this list in the GenerateSM80_SparseTensorOp_16832 method. I've updated the PR with all the changes I think should be made in that regard.
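For illustration, here is a minimal sketch of the symmetry I mean, in the style of generator.py's MathInstruction lists (the list names are made up for this example, and the import path and MathOperation name follow my understanding of the existing mixed-input upcast generators; the PR itself is authoritative):

# Sketch only: symmetric math_instructions for the two upcast generators,
# mirroring a (f16, f16, f32) entry of the reference method.
# The import path may differ between CUTLASS versions.
from cutlass_library.library import (
    MathInstruction, DataType, OpcodeClass, MathOperation)

upcast_a_math_instructions = [
    MathInstruction(
        [16, 8, 16],                              # 16816 Tensor Core shape
        DataType.s8, DataType.f16, DataType.f32,  # upcast A: s8 x f16 -> f32
        OpcodeClass.TensorOp,
        MathOperation.multiply_add_mixed_input_upcast),
    MathInstruction(
        [16, 8, 16],
        DataType.u8, DataType.f16, DataType.f32,  # upcast A: u8 x f16 -> f32
        OpcodeClass.TensorOp,
        MathOperation.multiply_add_mixed_input_upcast),
]

upcast_b_math_instructions = [
    MathInstruction(
        [16, 8, 16],
        DataType.f16, DataType.s8, DataType.f32,  # upcast B: f16 x s8 -> f32
        OpcodeClass.TensorOp,
        MathOperation.multiply_add_mixed_input_upcast),
    MathInstruction(
        [16, 8, 16],
        DataType.f16, DataType.u8, DataType.f32,  # upcast B: f16 x u8 -> f32
        OpcodeClass.TensorOp,
        MathOperation.multiply_add_mixed_input_upcast),
]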

As far as the tile_descriptions lists are concerned, I thought most of them should be the same between GenerateSM80_SparseTensorOp_16832 and the upcast_a and upcast_b methods - my reasoning was that the multiplication itself is 16-bit in each case. On the other hand, less shared memory is used in the mixed data-types case, so some configurations may differ; but I'd at least expect some kind of "symmetry" between upcast_a and upcast_b, in the sense that I don't see why the first one has only 9 elements in its tile_descriptions list while the other one has 16.

@manishucsd
Contributor

manishucsd commented Feb 21, 2024

For math_instructions, that makes sense. Yes, we should have support for the combinations you listed. Once you add those, please ensure the corresponding references are also in place, and run verification to ensure the kernels run and verify correctly.

For tile_descriptions: the 8-bit operand needs to be loaded from GMEM to SMEM, and this puts some restrictions on which tile_descriptions (shapes) are currently supported. These may not be the same for upcast_a and upcast_b.
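For reference, a single entry of a tile_descriptions list in generator.py looks roughly like the sketch below (the shapes and warp layouts shown are illustrative only, not the exact supported sets; TileDescription and math_inst are the usual generator.py names):

# Sketch: TileDescription(threadblock tile MxNxK, pipeline stages,
# warp layout, math instruction, min/max compute capability).
# math_inst would be one of the MathInstruction entries discussed above.
tile_descriptions = [
    TileDescription([128, 128, 64], 3, [2, 2, 1], math_inst, 80, 1024),
    TileDescription([128, 32, 64], 3, [2, 1, 1], math_inst, 80, 1024),
]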

@alexsamardzic
Contributor Author

Thanks for the clarification. I've updated gemm_fp_mixed_input.cu in my PR. W.r.t. verification - is there an "official" way to do it? I've checked that, on A100, whenever there is an item in gemm_fp_mixed_input.cu like, for example:

make_gemm_real_canonical_layouts<
    uint8_t,
    half_t,
    half_t,
    half_t,
    half_t
  >(manifest);

the matching cutlass_profiler run, in this case:

cutlass_profiler --A=u8 --B=f16 --C=f16 --accum=f16

produces at least one line of profiling output, which should mean that the kernel compiled and ran successfully.

As a matter of fact, in these tests of mine cutlass_profiler always produces exactly one line of output, and if I try to force (using the cta_m/cta_n/cta_k etc. command-line arguments) a tile description different from the one cutlass_profiler printed out (but still one listed in generator.py, for upcast_a in this particular case), it prints nothing. cutlass_profiler gives no information about the tile descriptions that it tried but that didn't work; on the other hand, in the PyTorch mixed data-types GEMM auto-tuning context mentioned in my first comment, more information about the compilation is printed, and I noticed that kernels generated by cutlass_library would fail to compile for some tile descriptions. This is another reason I'm interested in proper verification.
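For completeness, here is a minimal sketch of the check I'm doing (my own assumptions: cutlass_profiler is on PATH, and "at least one non-empty line of output" is an adequate success signal):

#!/usr/bin/env python3
# Run cutlass_profiler for a few mixed data-type combinations and flag
# those that produce no profiling output (i.e. no kernel compiled and ran).
import subprocess

combos = [
    ("u8", "f16", "f16", "f16"),
    ("s8", "f16", "f16", "f16"),
    ("f16", "u8", "f16", "f16"),
    ("f16", "s8", "f16", "f16"),
]

for a, b, c, accum in combos:
    result = subprocess.run(
        ["cutlass_profiler",
         f"--A={a}", f"--B={b}", f"--C={c}", f"--accum={accum}"],
        capture_output=True, text=True)
    lines = [ln for ln in result.stdout.splitlines() if ln.strip()]
    status = "ok" if lines else "NO OUTPUT (kernel missing or failed?)"
    print(f"A={a} B={b} C={c} accum={accum}: {status}")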

@alexsamardzic
Contributor Author

Asking again: how do I properly run verification after my changes?

@manishucsd
Contributor

manishucsd commented Mar 5, 2024

  1. For your mixed-input case, add a device-level unit test. Use a similar existing unit test (from here) as a template.

  2. You should also test that the profiler works, with verification, for your mixed-input case. Tips on achieving that:

  • Use cmake flags to compile only the kernels you are interested in. Use the cmake command below as an example to create your own:
cmake --no-warn-unused-cli \
  -DCMAKE_BUILD_TYPE:STRING=Release \
  -DCUTLASS_NVCC_ARCHS:STRING=80 \
  -DCUTLASS_NVCC_KEEP:STRING=OFF \
  -DCUTLASS_ENABLE_F16C:STRING=ON \
  -DCUTLASS_LIBRARY_KERNELS:STRING=f16_s16816gemm_f16_s8_128x128_64x*,f16_s16816gemm_s8_f16_128x128_64x*,f16_s16816gemm_u8_f16_128x128_64x*,f16_s16816gemm_f16_u8_128x128_64x*,bf16_s16816gemm_bf16_s8_128x128_64x*,bf16_s16816gemm_s8_bf16_128x128_64x*,bf16_s16816gemm_bf16_u8_128x128_64x*,bf16_s16816gemm_u8_bf16_128x128_64x*,f16_s16816gemm_f16_128x128_64x*_tn_align8,bf16_s16816gemm_bf16_128x128_64x*_tn_align8 \
  -DCUTLASS_LIBRARY_IGNORE_KERNELS:STRING=gemm_grouped*,gemm_planar* \
  -DCMAKE_EXPORT_COMPILE_COMMANDS:BOOL=TRUE \
  -DCMAKE_C_COMPILER:FILEPATH=/usr/bin/gcc \
  -DCMAKE_CXX_COMPILER:FILEPATH=/usr/bin/g++ \
  -S/mnt/disks/gcloud_workspace/repos/cutlass/cutlass_tree_2/cutlass \
  -B/mnt/disks/gcloud_workspace/repos/cutlass/cutlass_tree_2/build \
  -G Ninja

The cmake flags to play with are:

        // "CUTLASS_LIBRARY_KERNELS": "tensorop_s16816fprop_optimized_f16_256x128_32x3_nhwc_align8,s16816gemm_bf16_128x128_64x3_tn_align8,s16816gemm_f16_128x128_64x3_tn_align8,h16816gemm_128x128_64x3_tn_align8",
        // Upcast on OperandA and OperandB
        "CUTLASS_LIBRARY_KERNELS": "f16_s16816gemm_f16_s8_128x128_64x*,f16_s16816gemm_s8_f16_128x128_64x*,f16_s16816gemm_u8_f16_128x128_64x*,f16_s16816gemm_f16_u8_128x128_64x*,bf16_s16816gemm_bf16_s8_128x128_64x*,bf16_s16816gemm_s8_bf16_128x128_64x*,bf16_s16816gemm_bf16_u8_128x128_64x*,bf16_s16816gemm_u8_bf16_128x128_64x*,f16_s16816gemm_f16_128x128_64x*_tn_align8,bf16_s16816gemm_bf16_128x128_64x*_tn_align8",
        // Upcast on OperandB only        
        // "CUTLASS_LIBRARY_KERNELS": "s16816gemm_f16_s8_*,s16816gemm_bf16_s8_*,s16816gemm_bf16_128x128_64x*_tn_align8,s16816gemm_f16_128x128_64x*_tn_align8",
        "CUTLASS_LIBRARY_IGNORE_KERNELS": "gemm_grouped*,gemm_planar*"
  • Compile cutlass_profiler and make sure the kernels you are interested in are generated and compiled.

  • Use ./cutlass_profiler --kernels="kernel_name" to run the kernel you are interested in.

Apologies for the delayed response. I have been OOO for the last few weeks.

@alexsamardzic
Contributor Author

alexsamardzic commented Mar 7, 2024

Thanks for the clarifications.

The PR is updated with the suggested changes: I added a number of tests, so everything should now be consistent between the tests, generator.py, and gemm_fp_mixed_input.cu. I also fixed several unrelated typos in the generator and tests.

Script used to validate that cutlass_profiler generates kernels for mixed data-types
#!/bin/bash
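# Run cutlass_profiler for each mixed data-type combination listed below,
# and for the same combination with A and B swapped, pausing after each
# run so the output can be inspected.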

IFS=","

for cfg in \
    s8,f16,f32,f32 \
    u8,f16,f32,f32 \
    s8,bf16,f32,f32 \
    u8,bf16,f32,f32 \
    s8,f16,f16,f32 \
    u8,f16,f16,f32 \
    s8,bf16,bf16,f32 \
    u8,bf16,bf16,f32 \
    s8,f16,f16,f16 \
    u8,f16,f16,f16
do
    set -- $cfg
    ./tools/profiler/cutlass_profiler \
        --operation=gemm --op_class=tensorop \
        --A=$1 --B=$2 --C=$3 --accum=$4
    read -n1 -s -r -p $"A=$1 B=$2 C=$3 accum=$4 done - Press any key to continue..." key
    ./tools/profiler/cutlass_profiler \
        --operation=gemm --op_class=tensorop \
        --A=$2 --B=$1 --C=$3 --accum=$4
    read -n1 -s -r -p $"A=$2 B=$1 C=$3 accum=$4 done - Press any key to continue..." key
done


This PR has been labeled inactive-30d due to no recent activity in the past 30 days. Please close this PR if it is no longer required. Otherwise, please respond with a comment indicating any updates. This PR will be labeled inactive-90d if there is no activity in the next 60 days.

@alexsamardzic
Contributor Author

@hwu36: Thanks for the test fix! The problem with the configurations added in your commit is that they won't work - one can change, for example, exactly the test you fixed to one of these configurations (I'm attaching an example in the patch below) and then run make cutlass_test_unit_gemm_device_mixed_input_tensorop_sm80 - it won't build, and the first error reported is:

include/cutlass/epilogue/threadblock/output_tile_thread_map.h(262): error: static assertion failed with "Iteration Count Row must be > 0"

(The reason I removed some configurations from the generator in my commit is exactly that: I tried them all and removed those that don't compile.)

patch.txt

@hwu36
Collaborator

hwu36 commented Jul 11, 2024

I did not change the unit test.

The reason the profiler cannot do 128x32 is epilogue alignment. I fixed that, so 128x32 is fine now.

@alexsamardzic
Contributor Author

alexsamardzic commented Jul 11, 2024

I did not change the unit test.

I meant the fix to the test name :-)

The reason the profiler cannot do 128x32 is epilogue alignment. I fixed that, so 128x32 is fine now.

Ah, I see - indeed it works this way.

So, do you think this PR could be merged now? And, unrelated: any chance of having PR 1190 reviewed by someone from the team? It has been stalled for a very long time, and this functionality would still be very useful (at least for PyTorch).

@hwu36
Collaborator

hwu36 commented Jul 11, 2024

i will merge this pr after 3.5.1 pr is merged.

@manishucsd is reviewing pr1190. that one changed mainloop, it will take a while.

@manishucsd
Contributor

Thank you for the change. Overall it looks good. Can you do the following?

  1. CUTLASS profiler output for all the mixed input GEMMs:
build $ cmake ../cutlass/ -DCUTLASS_NVCC_ARCHS="90a" -DCUTLASS_ENABLE_F16C=ON -DCMAKE_BUILD_TYPE=Release -DCUTLASS_LIBRARY_KERNELS="u8_f16,s8_f16,f16_u8,f16_s8,u8_bf16,s8_bf16,bf16_u8,bf16_s8,"

build $ ./tools/profiler/cutlass_profiler --output=mixed_input.csv

Dump the output CSV here and make sure the Disposition column for all the mixed input variants says Passed - a checking sketch follows this list.

  2. I am hoping you have run the newer unit tests cutlass_test_unit_gemm_device_mixed_input_tensorop_sm80 and everything passes there.

  3. Can you also run the full test suite, at least on Ampere? This should help @hwu36 in merging it.

cmake ../cutlass/ -DCUTLASS_NVCC_ARCHS='80'
make all -j32
make test ARGS='-j32 --output-on-failure'
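Here is the checking sketch referenced in item 1 (assuming the profiler appends a .gemm suffix per operation kind, giving mixed_input.gemm.csv as in the attachments below, and comparing the Disposition case-insensitively):

# Verify that every row of the profiler CSV has Disposition "passed".
import csv

with open("mixed_input.gemm.csv", newline="") as f:
    rows = list(csv.DictReader(f))

failed = [r for r in rows if r.get("Disposition", "").lower() != "passed"]
print(f"{len(rows)} rows, {len(failed)} not passed")
for r in failed:
    print(r.get("Operation"), r.get("Disposition"))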

@hwu36
Collaborator

hwu36 commented Jul 11, 2024

@manishucsd , do you mean this pr or 1190? i did all the testing for this pr myself.


@manishucsd
Contributor

manishucsd commented Jul 11, 2024

@manishucsd , do you mean this pr or 1190? i did all the testing for this pr myself.

I meant this one. However, this one only adds unit tests and more instances to cutlass_library, so if you have already tested it, it is good to go.

It will be good to know the following two items:

  1. Reason behind the failing unit test here.
  2. I do see some unverified rows in the output of cutlass profiler for mixed-input runs.

@alexsamardzic
Contributor Author

2. I do see some unverified rows in the output of cutlass profiler for mixed-input runs.

I've built and run the profiler according to the instructions you provided above, and the Disposition is passed for each row in the produced .csv file. Can you clarify which "unverified rows" you found?

@alexsamardzic
Contributor Author

i will merge this pr after 3.5.1 pr is merged.

@manishucsd is reviewing pr1190. that one changed mainloop, it will take a while.

@manishucsd: Maybe you could review PR 1413 before 1190? PR 1413 is the third, and last, of my PRs awaiting review; its changes closely follow the int8/float16 functionality you implemented initially, and the only thing that may need changes is the conversion. So it should be possible to review/update/merge it relatively quickly, and this particular mixed data-types combination is also highly requested (at least for PyTorch).

PR 1190 will need more time. It's not ready for merging; it just has a couple of commits providing proof of concept in slightly different ways, each with its own disadvantages. It may need input from other members of the team, will certainly take more time to review, and will require additional work from me to complete.

@manishucsd
Contributor

2. I do see some unverified rows in the output of cutlass profiler for mixed-input runs.

I've built and run the profiler according to the instructions you provided above, and the Disposition is passed for each row in the produced .csv file. Can you clarify which "unverified rows" you found?

Here is my run
mixed_input.gemm.csv

Comment on lines 262 to 267
# Following test could be created by making a copy of
# gemm_universal_u8t_f16n_f16t_mixed_input_tensor_op_f32_sm80.cu,
# and then replacing occurrences of "half_t" in the code with
# "bfloat16_t". Such a test would fail, but with only a single value
# differing from the reference.
#gemm_universal_u8t_bf16n_bf16t_mixed_input_tensor_op_f32_sm80.cu
Contributor


Thank you for sharing this. There is nothing in the instance that I could point to as causing the issue. Can you try a smaller range for initializing operand B of BF16? The current range is from -5 to 5; can you try -3 to 3?

@alexsamardzic
Contributor Author

alexsamardzic commented Jul 11, 2024

Here is mine:
mixed_input.gemm.csv - all fine here. Not sure what could be causing the difference... I tested on an A100; the only difference may be that I rebased my branch on CUTLASS main.

Edit: I rebuilt without rebasing the branch on the latest main, and all rows in the .csv file still have passed as the Disposition.

@manishucsd
Contributor

LGTM. @hwu36 and CUTLASS team, can you please merge this?

cc: @alexsamardzic

@manishucsd
Contributor

Checking the status of this reviewed PR. Has it already been merged?

@alexsamardzic
Contributor Author

Checking the status of this reviewed PR. Has it already been merged?

It doesn't seem to be merged yet.

@hwu36 hwu36 merged commit 3f084f7 into NVIDIA:main Aug 16, 2024
@alexsamardzic alexsamardzic deleted the add-mixed-input-configs branch August 30, 2024 12:05
ucassjy pushed a commit to ucassjy/cutlass that referenced this pull request Sep 4, 2024
* Add couple configs into generator.py for mixed input MM

* change one unit test name; reenable 128x32 in the profiler

* Added U8/BF16 tests.

---------

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
Co-authored-by: Haicheng Wu <57973641+hwu36@users.noreply.github.com>