Support for Mixed Input TensorOp #1084
Conversation
Hi @manishucsd, Thanks for the awesome work! One high level question - do you plan to handle the scales in the epilogue? The FT implementation does:
Even with FP32 accum, we saw that model accuracy can degrade if we scale in the epilogue. I think it is because the input range of
Thanks for bringing this to my attention. I am not planning to handle this requirement in the current PR, just mixed-input upcast GEMMs. It should just be a matter of passing an additional scalar argument to the mainloop; can we handle it in another PR?
Thanks for adding mixed GEMM support! I have the same question as OP: any follow-ups on implementing scales, and how about (u)int4 support (two values packed into an 8-bit integer)?
This PR introduces support for mixed input data types for the NVIDIA A100. The PR is motivated by the rising trend of using activations in 16-bit floating point (either F16 or BF16) alongside weights stored as 8-bit integers. An instantiation sketch follows the list of supported configurations below.
Supported GEMM datatype configurations
F16 with 8b Integers
Upcast on Operand B
[F16|F32] <= F16 * S8 + [F16|F32]
[F16|F32] <= F16 * U8 + [F16|F32]
Upcast on Operand A
[F16|F32] <= S8 * F16 + [F16|F32]
[F16|F32] <= U8 * F16 + [F16|F32]
BF16 with 8b Integers
Upcast on Operand B
[BF16|F32] <= BF16 * S8 + F32
[BF16|F32] <= BF16 * U8 + F32
Upcast on Operand A
[BF16|F32] <= S8 * BF16 + F32
[BF16|F32] <= U8 * BF16 + F32
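To make the configurations above concrete, here is a minimal instantiation sketch for the first listed case (F16 <= F16 * S8 + F16, accumulating in F32, upcast on operand B) using the existing GemmUniversal device API. This is a sketch under assumptions, not code from the PR: the tile shapes, stage count, alignments, and the cutlass::arch::OpMultiplyAddMixedInputUpcast operator tag are my guesses at how the support is exposed.

```cpp
// Sketch only: the template arguments below are plausible SM80 choices, not
// the exact configuration added by this PR.
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"

using MixedInputGemm = cutlass::gemm::device::GemmUniversal<
    cutlass::half_t, cutlass::layout::RowMajor,      // A: F16, RowMajor (the "T" in TN)
    int8_t,          cutlass::layout::ColumnMajor,   // B: S8,  ColumnMajor (the "N" in TN)
    cutlass::half_t, cutlass::layout::RowMajor,      // C/D: F16
    float,                                           // accumulate in F32
    cutlass::arch::OpClassTensorOp,                  // Tensor Core mma.sync
    cutlass::arch::Sm80,                             // NVIDIA A100
    cutlass::gemm::GemmShape<128, 128, 64>,          // threadblock tile (assumed)
    cutlass::gemm::GemmShape<64, 64, 64>,            // warp tile (assumed)
    cutlass::gemm::GemmShape<16, 8, 16>,             // F16/BF16 mma.sync shape
    cutlass::epilogue::thread::LinearCombination<
        cutlass::half_t, 8, float, float>,           // D = alpha * acc + beta * C
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    4,                                               // mainloop stages (assumed)
    8,                                               // alignment of A in elements
    16,                                              // alignment of B in elements
    cutlass::arch::OpMultiplyAddMixedInputUpcast>;   // mixed-input operator tag (assumed name)
```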
Key Features of the Mixed Input Computation:
1. Memory Layout: TN layout, which implies RowMajor for matrix A and ColumnMajor for matrix B when they are in Global Memory.
2. Mainloop Execution Blueprint: Tiles are staged from Global Memory to Shared Memory with cp.async (or cp.async.bulk; the cp.async (A100) version is implemented in this PR), loaded from Shared Memory into registers with the ldmatrix instruction, and then consumed by the mma operation.
3. Data Layout Conversion: Handled by the FragmentShuffler component, which efficiently reorganizes data within a warp to align with the thread-value mapping needed by the mma operation. Because ldmatrix operates on 8-bit values (S8 or U8) and mma.sync works on 16-bit values (F16 or BF16), there is a mismatch in thread ownership post-ldmatrix: the data is not aligned with what mma.sync.[f16|f32].f16.f16.[f16|f32] or mma.sync.f32.bf16.bf16.f32 expects. The FragmentShuffler addresses this challenge. As all necessary data is within the warp (32 threads), FragmentShuffler shuffles it to achieve a thread layout that conforms to mma.sync.
4. Data Type Conversion: Adapts data to the wider of the two datatypes, making it suitable for processing by the math instruction. This is facilitated by the FragmentConverter, leveraging the efficient and newly added FastNumericArrayConverter for each of the four conversion combinations [4xS8 | 4xU8] -> [4xF16 | 4xBF16] (see the sketch after this list).
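For context on item 4, here is a minimal sketch of the per-fragment widening step, assuming the FastNumericArrayConverter<Destination, Source, N, Round> parameter order named later in this description; the helper function name and fragment width are illustrative, not code from the PR.

```cpp
#include "cutlass/array.h"
#include "cutlass/numeric_conversion.h"

// Illustrative helper: widens four S8 values held in registers to four F16
// values, the per-fragment step performed before feeding mma.sync.
__device__ cutlass::Array<cutlass::half_t, 4>
upcast_s8_fragment(cutlass::Array<int8_t, 4> const &source) {
  cutlass::FastNumericArrayConverter<
      cutlass::half_t,                                 // destination element (F16)
      int8_t,                                          // source element (S8)
      4,                                               // elements per fragment
      cutlass::FloatRoundStyle::round_to_nearest> converter;
  return converter(source);
}
```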
Performance Profiling
The top commit on this PR achieves the performance shown by the graph above for the various mixed-input datatypes on an NVIDIA A100 40GB SXM4 chip.
Reproducing Performance Results
In addition to the implementation, the kernels are also added to the profiler. To reproduce the above performance results, follow these steps:
Breakdown of Performance Improvements
The figure above shows the performance improvements for one mixed-input datatype with increasing levels of mainloop optimization. The idea is to avoid long-latency memory instructions and long sequences of arithmetic conversion instructions inside the mainloop.
Functional Testing
- Added device-level tests.
- Added warp-level tests for each variant. Tests the FragmentShuffler and FragmentConverter. If this fails, most likely the issue will be in the warp-level components, specifically in the FragmentShuffler. Note that any bugs in FragmentConverter will be caught in the newly added test/core tests described in the bullet below.
- Added test/core tests for FastNumericConverter. For 8b integer to 16b floating point, they test all 8-bit integer patterns, which are not that many :). (A rough sketch of this exhaustive check follows the list.)
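The exhaustive-pattern idea in the last bullet can be sketched roughly as below; this is not the PR's test code, and the kernel and variable names are made up for illustration. It compares the fast array converter against the scalar reference NumericConverter for every 8-bit input value.

```cpp
#include "cutlass/array.h"
#include "cutlass/numeric_conversion.h"

// Runs on a single thread, e.g. verify_s8_to_f16_fast_converter<<<1, 1>>>(d_errors);
// any mismatch between the fast converter and the scalar reference increments
// *error_count, which the host checks is still zero afterwards.
__global__ void verify_s8_to_f16_fast_converter(int *error_count) {
  cutlass::FastNumericArrayConverter<cutlass::half_t, int8_t, 4> fast_converter;
  cutlass::NumericConverter<cutlass::half_t, int8_t> reference_converter;

  for (int value = -128; value < 128; ++value) {
    cutlass::Array<int8_t, 4> source;
    source.fill(static_cast<int8_t>(value));           // broadcast one bit pattern

    cutlass::Array<cutlass::half_t, 4> destination = fast_converter(source);

    for (int i = 0; i < 4; ++i) {
      if (destination[i] != reference_converter(static_cast<int8_t>(value))) {
        atomicAdd(error_count, 1);
      }
    }
  }
}
```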
Small Improvements in This PR (Pushing as we Review)
FragmentShuffler needs to be specialized for matrix A and matrix B. It currently works when the 8b integers are coming from matrix B.
Future Improvements (To be taken up in separate new PRs)
I believe there is a more optimal sequence of PTX instructions to achieve the following:
- S8 -> BF16 conversion can get better, and I am working on it. This particular component is the partial specialization struct FastNumericArrayConverter<cutlass::bfloat16_t, int8_t, 4, Round> in numeric_converter.h. (A baseline version of this conversion is sketched after this list.)
- The number of SHFL instructions used by FragmentShuffler to achieve the layout conversion could potentially be reduced by increasing register pressure a bit.

We can take up these potential improvements in later PRs.
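As a point of reference for the S8 -> BF16 item above, this is a sketch of the straightforward per-element route (S8 -> F32 -> BF16) that a faster PTX sequence would aim to beat; it is illustrative only and is not the partial specialization being optimized.

```cpp
#include "cutlass/array.h"
#include "cutlass/bfloat16.h"
#include "cutlass/numeric_conversion.h"

// Baseline only: each element takes the scalar path S8 -> F32 -> BF16
// (round-to-nearest); the fast partial specialization replaces this with a
// shorter PTX sequence.
CUTLASS_HOST_DEVICE
cutlass::Array<cutlass::bfloat16_t, 4>
baseline_s8_to_bf16(cutlass::Array<int8_t, 4> const &source) {
  cutlass::NumericConverter<cutlass::bfloat16_t, float> to_bf16;
  cutlass::Array<cutlass::bfloat16_t, 4> destination;

  CUTLASS_PRAGMA_UNROLL
  for (int i = 0; i < 4; ++i) {
    destination[i] = to_bf16(static_cast<float>(source[i]));
  }
  return destination;
}
```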