
Support for Mixed Input TensorOp #1084

Merged
hwu36 merged 21 commits into NVIDIA:main on Sep 27, 2023

Conversation

manishucsd
Contributor

@manishucsd commented Sep 11, 2023

This PR introduces support for mixed-input data types on the NVIDIA A100. The PR is motivated by the rising trend of pairing activations in 16-bit floating point (F16 or BF16) with weights in 8-bit integer.

Supported GEMM datatype configurations

F16 with 8b Integers

Upcast on Operand B

  • [F16|F32] <= F16 * S8 + [F16|F32]
  • [F16|F32] <= F16 * U8 + [F16|F32]

Upcast on Operand A

  • [F16|F32] <= S8 * F16 + [F16|F32]
  • [F16|F32] <= U8 * F16 + [F16|F32]

BF16 with 8b Integers

Upcast on Operand B

  • [BF16|F32] <= BF16 * S8 + F32
  • [BF16|F32] <= BF16 * U8 + F32

Upcast on Operand A

  • [BF16|F32] <= S8 * BF16 + F32
  • [BF16|F32] <= U8 * BF16 + F32
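Device-level usage sketch (illustrative): the snippet below shows one plausible way to instantiate the F16 <= F16 * S8 + F16 configuration (with F32 accumulation), written in the style of the new SM80 mixed-input device-level unit tests. The specific tile shapes, stage count, alignments, and the mixed-input math-operator tag are assumptions on my part; treat the unit tests and profiler kernels added by this PR as the authoritative reference.

// Sketch only: F16 <= F16 * S8 + F16 (F32 accumulation), upcast on operand B, TN layout.
#include "cutlass/cutlass.h"
#include "cutlass/arch/arch.h"
#include "cutlass/arch/mma.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/gemm/device/gemm_universal.h"

using ElementA   = cutlass::half_t;   // wider operand (F16)
using ElementB   = int8_t;            // narrow operand (S8), upcast in the mainloop
using ElementC   = cutlass::half_t;
using ElementAcc = float;

using Gemm = cutlass::gemm::device::GemmUniversal<
    ElementA, cutlass::layout::RowMajor,      // A is RowMajor    (T)
    ElementB, cutlass::layout::ColumnMajor,   // B is ColumnMajor (N)
    ElementC, cutlass::layout::RowMajor,
    ElementAcc,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 128, 64>,   // threadblock tile (assumed)
    cutlass::gemm::GemmShape<64, 64, 64>,     // warp tile (assumed)
    cutlass::gemm::GemmShape<16, 8, 16>,      // mma.sync instruction shape
    cutlass::epilogue::thread::LinearCombination<
        ElementC, 8, ElementAcc, ElementAcc>,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    3,                                        // stages (assumed)
    8,                                        // alignment of A, in F16 elements
    16,                                       // alignment of B, in S8 elements
    cutlass::arch::OpMultiplyAddMixedInputUpcast>;  // mixed-input math operator tag

The resulting Gemm type is then launched through the usual GemmUniversal Arguments/run flow; swapping ElementA and ElementB (and their layouts and alignments) gives the "upcast on operand A" variants listed above.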

Key Features of the Mixed Input Computation:

  • Memory Layout: TN layout, which implies RowMajor for matrix A and ColumnMajor for matrix B when they are in Global Memory.

  • Mainloop Execution Blueprint:

    1. Data Loading: Copies data from Global Memory to Shared Memory (using either cp.async or cp.async.bulk; this PR implements the cp.async path for A100).
    2. Shared Memory to Register: Loads data from Shared Memory to Registers through the ldmatrix instruction.
    3. Layout Adaptation for mma Operation:
      • Incorporates the FragmentShuffler component, which efficiently reorganizes data within a warp to align with the thread-value mapping needed by the mma operation.
      • Given that ldmatrix operates on 8-bit values (S8 or U8) and mma.sync works on 16-bit values (F16 or BF16), thread ownership after ldmatrix does not match what mma.sync.[f16|f32].f16.f16.[f16|f32] or mma.sync.f32.bf16.bf16.f32 expects. The FragmentShuffler addresses this: because all the necessary data already resides within the warp (32 threads), it shuffles the fragments into a thread layout that conforms to mma.sync.
    4. Data Type Conversion: Converts data to the wider of the two data types, making it suitable for the math instruction. This is handled by the FragmentConverter, which leverages the efficient, newly added FastNumericArrayConverter for each of the four conversion combinations [4xS8 | 4xU8] -> [4xF16 | 4xBF16]. (A standalone sketch of the underlying bit trick follows this list.)
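As a side note on step 4: the snippet below is a small standalone illustration (plain host C++, hypothetical helper names, not the CUTLASS code) of the classic bit trick that fast 8-bit-to-F16 converters are built on. The four-wide converters in this PR vectorize this kind of idea at the PTX level over packed values; the sketch only verifies the scalar arithmetic identity.

// Demonstrates why no int-to-float instruction is needed for U8/S8 -> F16:
// 0x6400 is the binary16 bit pattern for 1024.0f, and OR-ing an 8-bit value
// into the low mantissa bits yields exactly 1024 + u, so one f16 subtract
// completes the conversion.
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Decode an IEEE binary16 bit pattern to float (normal numbers only,
// which is all this demo produces).
float decode_f16(uint16_t bits) {
  int sign     = (bits >> 15) & 0x1;
  int exponent = (bits >> 10) & 0x1F;
  int mantissa = bits & 0x3FF;
  float value  = std::ldexp(1.0f + mantissa / 1024.0f, exponent - 15);
  return sign ? -value : value;
}

int main() {
  // Unsigned 8-bit: (0x6400 | u) interpreted as f16 equals 1024 + u.
  for (int u = 0; u < 256; ++u) {
    assert(decode_f16(uint16_t(0x6400 | u)) - 1024.0f == float(u));
  }
  // Signed 8-bit: flip the sign bit (u8 = s8 + 128) and subtract 1152 instead.
  for (int s = -128; s < 128; ++s) {
    assert(decode_f16(uint16_t(0x6400 | uint8_t(s ^ 0x80))) - 1152.0f == float(s));
  }
  std::printf("all 512 patterns verified\n");
  return 0;
}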

Performance Profiling

GEMM Performance with Mixed-Input Datatype on NVIDIA A100 40GB SXM4

The top commit on this PR achieves the performance shown in the graph above for the various mixed-input data types on an NVIDIA A100 40GB SXM4 GPU.

Reproducing Performance Results

In addition to the implementation, the kernels are also added to the profiler. To reproduce the above performance results, follow these steps:

# cmake to generate only the kernels that we want to profile
build $ cmake ../cutlass/ -DCUTLASS_NVCC_ARCHS='80' -DCUTLASS_LIBRARY_KERNELS="\
s16816gemm_f16_s8_128x128_64x3,\
s16816gemm_s8_f16_128x128_64x3,\
s16816gemm_u8_f16_128x128_64x3,\
s16816gemm_f16_u8_128x128_64x3,\
s16816gemm_bf16_s8_128x128_64x3,\
s16816gemm_s8_bf16_128x128_64x3,\
s16816gemm_bf16_u8_128x128_64x3,\
s16816gemm_u8_bf16_128x128_64x3,\
tensorop_s16816gemm_f16_128x128_64x*_tn_align8,\
tensorop_s16816gemm_bf16_128x128_64x*_tn_align8" -DCUTLASS_NVCC_KEEP=OFF -DCUTLASS_LIBRARY_IGNORE_KERNELS="gemm_grouped*,gemm_planar*,f16_s16816gemm_f16_128x128_64x4_tn_align8"

# build cutlass_profiler and the targets affected by this PR
build $ make cutlass_test_unit_core cutlass_test_unit_gemm_warp cutlass_test_unit_gemm_device_mixed_input_tensorop_sm80 cutlass_profiler -j

# On NVIDIA A100-SXM4-40GB - SM 8.0, 108 SMs @ 1410 MHz, L2 cache: 40 MB, Global Memory: 39 GB
build $ ./tools/profiler/cutlass_profiler --op_class=tensorop --m=3456 --n=4096 --k=8192 --output=cutlass_data.csv

Breakdown of Performance Improvements

Mixed Input Data Type (F16 <= F16 * U8 + F16) Performance Progression with Increasing Mainloop Optimizations

The figure above shows the performance improvements for one mixed-input data type with increasing levels of mainloop optimization. The idea is to avoid long-latency memory instructions and long sequences of arithmetic conversion instructions inside the mainloop.

Functional Testing

  • Added device level tests.

  • Added warp-level tests for each variant. These exercise the FragmentShuffler and FragmentConverter.

build $ make cutlass_test_unit_gemm_warp -j

$ ./test/unit/gemm/warp/cutlass_test_unit_gemm_warp --gtest_filter=SM80_warp_gemm_mixed_input_tensor_op*
Note: Google Test filter = SM80_warp_gemm_mixed_input_tensor_op*
[==========] Running 5 tests from 3 test suites.
[----------] Global test environment set-up.
[----------] 2 tests from SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_i8
[ RUN      ] SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_i8.128x128x64_64x64x64_16x8x16
[       OK ] SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_i8.128x128x64_64x64x64_16x8x16 (155 ms)
[ RUN      ] SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_i8.64x64x64_64x64x64_16x8x16
[       OK ] SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_i8.64x64x64_64x64x64_16x8x16 (3 ms)
[----------] 2 tests from SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_i8 (158 ms total)

[----------] 2 tests from SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_u8
[ RUN      ] SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_u8.64x64x64_64x64x64_16x8x16
[       OK ] SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_u8.64x64x64_64x64x64_16x8x16 (3 ms)
[ RUN      ] SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_u8.128x128x64_64x64x64_16x8x16
[       OK ] SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_u8.128x128x64_64x64x64_16x8x16 (4 ms)
[----------] 2 tests from SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_u8 (7 ms total)

[----------] 1 test from SM80_warp_gemm_mixed_input_tensor_op_crosswise_bf16_u8
[ RUN      ] SM80_warp_gemm_mixed_input_tensor_op_crosswise_bf16_u8.64x64x64_64x64x64_16x8x16
[       OK ] SM80_warp_gemm_mixed_input_tensor_op_crosswise_bf16_u8.64x64x64_64x64x64_16x8x16 (2 ms)
[----------] 1 test from SM80_warp_gemm_mixed_input_tensor_op_crosswise_bf16_u8 (2 ms total)

[----------] Global test environment tear-down
[==========] 5 tests from 3 test suites ran. (169 ms total)
[  PASSED  ] 5 tests.

If this fails, the issue is most likely in the warp-level components, specifically in the FragmentShuffler. Note that any bugs in the FragmentConverter will be caught by the newly added test/core test described in the bullet below.

  • Added a new test for FastNumericConverter. For 8-bit integer to 16-bit floating point, it tests all 8-bit integer patterns, which are not that many :).
build $ ./test/unit/core/cutlass_test_unit_core --gtest_filter=FastNumericConversion.*
Note: Google Test filter = FastNumericConversion.*
[==========] Running 5 tests from 1 test suite.
[----------] Global test environment set-up.
[----------] 5 tests from FastNumericConversion
[ RUN      ] FastNumericConversion.s32_to_f32
[       OK ] FastNumericConversion.s32_to_f32 (152 ms)
[ RUN      ] FastNumericConversion.s8_to_f16_array
[       OK ] FastNumericConversion.s8_to_f16_array (0 ms)
[ RUN      ] FastNumericConversion.u8_to_f16_array
[       OK ] FastNumericConversion.u8_to_f16_array (0 ms)
[ RUN      ] FastNumericConversion.u8_to_bf16_array
[       OK ] FastNumericConversion.u8_to_bf16_array (0 ms)
[ RUN      ] FastNumericConversion.s8_to_bf16_array
[       OK ] FastNumericConversion.s8_to_bf16_array (0 ms)
[----------] 5 tests from FastNumericConversion (153 ms total)

[----------] Global test environment tear-down
[==========] 5 tests from 1 test suite ran. (153 ms total)
[  PASSED  ] 5 tests.

Small Improvements in This PR (Pushing as we Review)

  • FragmentShuffler needs to be specialized for matrix A and matrix B. It currently works when the 8-bit integers come from matrix B. (A toy illustration of the intra-warp exchange it performs follows below.)
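For readers unfamiliar with intra-warp data exchange, here is a toy CUDA sketch (made-up kernel, not the CUTLASS component) of the mechanism FragmentShuffler builds on: each thread hands its value to another lane with __shfl_sync, changing thread ownership without touching shared memory. The real shuffler remaps the ldmatrix thread-value layout of 8-bit fragments onto the layout mma.sync expects for 16-bit operands; the permutation below is only an example.

#include <cstdio>
#include <cuda_runtime.h>

// Toy kernel: every lane i reads one element and then takes ownership of the
// element originally held by lane i ^ 16 (exchange between half-warps).
__global__ void shuffle_demo(int const* in, int* out) {
  int lane     = threadIdx.x & 31;
  int value    = in[lane];
  int shuffled = __shfl_sync(0xffffffffu, value, lane ^ 16);
  out[lane] = shuffled;
}

int main() {
  int h_in[32], h_out[32];
  for (int i = 0; i < 32; ++i) h_in[i] = i;

  int *d_in = nullptr, *d_out = nullptr;
  cudaMalloc(&d_in, sizeof(h_in));
  cudaMalloc(&d_out, sizeof(h_out));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);

  shuffle_demo<<<1, 32>>>(d_in, d_out);
  cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);

  for (int i = 0; i < 32; ++i) std::printf("lane %d now holds %d\n", i, h_out[i]);

  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}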

Future Improvements (To be taken up in separate new PRs)

I believe there is a more optimal sequence of PTX instructions to achieve the following:

  • The S8->BF16 conversion can be improved, and I am working on it. This component is the partial specialization struct FastNumericArrayConverter<cutlass::bfloat16_t, int8_t, 4, Round> in numeric_converter.h. (A naive baseline for this conversion is sketched after this list.)
  • The number of SHFL instructions needed for the layout conversion in FragmentShuffler could potentially be reduced at the cost of slightly higher register pressure.
    We can take up these potential improvements in later PRs.
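For reference on the first bullet, a naive baseline (plain host C++, not the PR's converter): convert S8 through F32 and keep the upper 16 bits of the F32 pattern. This is exact for int8 because every int8 value fits in BF16's 8-bit significand; the bullet above is about finding a shorter PTX sequence over four packed values than straightforward per-element conversions of this shape.

#include <cassert>
#include <cstdint>
#include <cstring>

// Naive S8 -> BF16: go through F32 and truncate to the upper 16 bits.
uint16_t s8_to_bf16_bits(int8_t s) {
  float f = static_cast<float>(s);  // exact for int8
  uint32_t f32_bits;
  std::memcpy(&f32_bits, &f, sizeof(f32_bits));
  return static_cast<uint16_t>(f32_bits >> 16);  // bf16 = upper half of f32
}

// Decode bf16 bits back to float by zero-padding the low 16 bits.
float bf16_bits_to_float(uint16_t b) {
  uint32_t f32_bits = static_cast<uint32_t>(b) << 16;
  float f;
  std::memcpy(&f, &f32_bits, sizeof(f));
  return f;
}

int main() {
  for (int s = -128; s < 128; ++s) {
    assert(bf16_bits_to_float(s8_to_bf16_bits(static_cast<int8_t>(s))) == float(s));
  }
  return 0;
}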

@rhenry-nv commented Sep 11, 2023

Hi @manishucsd,

Thanks for the awesome work! One high level question - do you plan to handle the scales in the epilogue?

The FT implementation does:

half_t converted_val = i2f(val)
half_t mma_input_a = scale * converted_val
acc += mma(mma_input_a, mma_input_b)

Even with FP32 accum, we saw that model accuracy can degrade if we scale in the epilogue. I think it is because the input range of mma_input_a is much larger if we do mma with it unscaled (since scale is usually much smaller than 1).

@manishucsd
Contributor Author

Hi @manishucsd,

Thanks for the awesome work! One high level question - do you plan to handle the scales in the epilogue?

The FT implementation does:

half_t converted_val = i2f(val)
half_t mma_input_a = scale * converted_val
acc += mma(mma_input_a, mma_input_b)

Even with FP32 accum, we saw that model accuracy can degrade if we scale in the epilogue. I think it is because the input range of mma_input_a is much larger if we do mma with it unscaled (since scale is usually much smaller than 1).

Thanks for bringing this to my attention. I am not planning to handle this requirement in the current PR, just the mixed-input upcast GEMMs. It should just be a matter of passing an additional scale argument to the mainloop; let us handle it in another PR.

@manishucsd force-pushed the mixed_input_tensor_op_sm80 branch 2 times, most recently from 022500d to 0bee609 on September 14, 2023
@manishucsd force-pushed the mixed_input_tensor_op_sm80 branch from e1af683 to a1284f4 on September 21, 2023
@manishucsd force-pushed the mixed_input_tensor_op_sm80 branch from fcc2bf7 to da44e43 on September 26, 2023
@manishucsd changed the title from Support for Mixed Input TensorOp ([F16 | BF16] * [S8 | U8]) to Support for Mixed Input TensorOp ([F16 | BF16] * [S8 | U8]) OR ([S8 | U8] * [F16 | BF16]) on Sep 26, 2023
@manishucsd changed the title to Support for Mixed Input TensorOp on Sep 26, 2023
@hwu36 merged commit 7d8317a into NVIDIA:main on Sep 27, 2023
@alexsamardzic
Contributor

Hi @manishucsd,
Thanks for the awesome work! One high level question - do you plan to handle the scales in the epilogue?
The FT implementation does:

half_t converted_val = i2f(val)
half_t mma_input_a = scale * converted_val
acc += mma(mma_input_a, mma_input_b)

Even with FP32 accum, we saw that model accuracy can degrade if we scale in the epilogue. I think it is because the input range of mma_input_a is much larger if we do mma with it unscaled (since scale is usually much smaller than 1).

Thanks for bringing this to my attention. I am not planning to handle this requirement in the current PR, just the mixed-input upcast GEMMs. It should just be a matter of passing an additional scale argument to the mainloop; let us handle it in another PR.

Thanks for adding mixed GEMM support! I have the same question as the OP: are there any follow-ups on implementing scales, and how about (u)int4 support (packed two to an 8-bit integer)?
