Groupwise scaling along M for FP8 gemm #2037
Conversation
Hi @hwu36 This PR is from the DeepSeek Team. Could you help review and merge it? The SGLang team wants to implement block-wise FP8 using CUTLASS for DeepSeek V3. This PR is essential for us. Thanks!
Hi @zhyncs This PR looks like an example demo. Has the integration with SGLang been done? Could you post a PR with the SGLang integration code?
@ll2088
Has the CUTLASS-based version in SGLang been submitted as a PR yet? Could you post it here?
Not yet.
And why does ScaleMsPerTile = 128 not work? @soundOfDestiny
The issue of incorrect shared memory size calculation has been present since #1932.
CUDA 12.9 will improve the performance of blockscale/groupscale kernels.
Hi @soundOfDestiny and @hwu36
Hi @jackkosaian, I'm currently working on optimizing this groupwise-GEMM's performance for the Hopper architecture using CUTLASS 3.x and exploring the split-K technique. I've reviewed previous issues related to split-K (#702 (comment)). I initially attempted to implement split-K by directly modifying the code here:

```cpp
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
    Shape<int, int, int, int>,               // Indicates ProblemShape
    CollectiveMainloopWithBlockWiseScaling,
    CollectiveEpilogue
>;

using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
```

I tried replacing ... I'm trying to use ...
The definition of ...: the second template argument should be the epilogue, rather than ...
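For reference, here is a minimal sketch (not code from this PR) of how CUTLASS 3.x typically exposes split-K-style work decomposition, namely through the fourth (tile scheduler) template slot of GemmUniversal. CollectiveMainloopWithBlockWiseScaling and CollectiveEpilogue refer to the types defined in the example above; whether the blockwise/groupwise-scaling mainloop composes correctly with the stream-K scheduler is an assumption, not something verified here.

```cpp
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"

// Sketch only: request stream-K (split-K-style) decomposition by passing a
// tile-scheduler tag as the fourth template argument, after mainloop and epilogue.
// The split/decomposition behavior is then controlled through the kernel's
// scheduler arguments at runtime.
using GemmKernelStreamK = cutlass::gemm::kernel::GemmUniversal<
    cute::Shape<int, int, int, int>,           // ProblemShape
    CollectiveMainloopWithBlockWiseScaling,    // mainloop from the example (assumed compatible)
    CollectiveEpilogue,                        // epilogue from the example
    cutlass::gemm::StreamKScheduler            // tile scheduler tag (default is the persistent scheduler)
>;

using GemmStreamK = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelStreamK>;
```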
Hi @soundOfDestiny and @jackkosaian
In cuBLAS?
@hwu36 BTW, would a newer version of Transformer Engine support generating this kind of groupwise scaling factors? Current TE only supports generating per-tensor scales. Thanks!
Sorry, I don't know.
Background (copied from #1932)
As we adopt narrower datatypes, traditional scaling methods struggle to maintain accuracy, particularly with 8-bit floating-point types (e.g., e5m2_t, e4m3_t). The typical GEMM operation uses tensorwise scaling with $D = \alpha \cdot (A B) + \beta \cdot C$, but narrower datatypes necessitate finer-grained scaling techniques. Before we dive deep into groupwise scaling, below is a glossary of various scaling methods:
Summary
As #1932 adds a blockwise scaling strategy, this PR is a patch based on #1932 that adds a groupwise scaling strategy along M for the A tensor. The scaling granularity along M is made independent of the CTA block configuration; however, the scaling granularities along N and K are still blockwise (i.e. one scaling value per CTA block).
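As a sketch of the intended semantics (my reading of #1932 plus this PR's description, not a formula taken from the PR): with CTA tile sizes $T_M$, $T_N$, $T_K$, a group size $g_M$ along M that divides $T_M$, and the per-k-block promotion used in #1932, the scaled accumulation would be

$$
D_{ij} \;=\; \sum_{b=0}^{\lceil K/T_K \rceil - 1}
  s_A\!\left(\left\lfloor \tfrac{i}{g_M} \right\rfloor,\, b\right)\,
  s_B\!\left(\left\lfloor \tfrac{j}{T_N} \right\rfloor,\, b\right)
  \sum_{k \,\in\, \text{block } b} A_{ik}\, B_{kj}
$$

where $s_A$ holds one value per (M-group, K-block), $s_B$ holds one value per (N-block, K-block), and the usual $\alpha$, $\beta$, $C$ epilogue is omitted.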
This PR restricts the scaling granularity along M to a factor of TILE_SHAPE_M in the CTA block configuration. To get a granularity that is a multiple of TILE_SHAPE_M instead, one can set the GEMM scaling granularity along M to exactly TILE_SHAPE_M (i.e. fall back to the blockwise scaling strategy) and call the repeat_interleave method on the input tensor ScaleA to simulate the coarser granularity (see the sketch below).
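A minimal host-side sketch of that repeat_interleave fallback. The helper name, the row-major [m_groups][k_blocks] layout for ScaleA, and the coarse group size of factor * TILE_SHAPE_M are illustrative assumptions, not taken from the PR.

```cpp
#include <cstddef>
#include <vector>

// Expand coarse A scales (one per factor * TILE_SHAPE_M rows) into one scale per
// TILE_SHAPE_M rows by repeating each M-row of scales `factor` times, mimicking
// torch.repeat_interleave along the M dimension.
std::vector<float> repeat_scale_a_along_m(std::vector<float> const& coarse_scale_a,
                                          int m_groups,   // number of coarse groups along M
                                          int k_blocks,   // number of K blocks
                                          int factor) {   // coarse group size / TILE_SHAPE_M
  std::vector<float> fine_scale_a;
  fine_scale_a.reserve(static_cast<std::size_t>(m_groups) * factor * k_blocks);
  for (int g = 0; g < m_groups; ++g) {
    for (int r = 0; r < factor; ++r) {       // repeat each row of scales `factor` times
      for (int k = 0; k < k_blocks; ++k) {
        fine_scale_a.push_back(coarse_scale_a[static_cast<std::size_t>(g) * k_blocks + k]);
      }
    }
  }
  // Result: [m_groups * factor][k_blocks], i.e. one scale per TILE_SHAPE_M rows,
  // so the kernel runs with blockwise granularity along M while emulating the
  // coarser grouping.
  return fine_scale_a;
}
```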
Groupwise Scaling
In this implementation, we load scaling tensors with more elements than in #1932 into shared memory, since there may be multiple scales along M per CTA block. However, each thread only needs to load at most two scale values for the A tensor and exactly one scale value for the B tensor from shared memory to registers per iteration, because each thread's WGMMA accumulators cover only two rows of the result tensor.
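A schematic of the per-thread scale application described above (not the PR's kernel code): the names are hypothetical, and the even/odd mapping of fragment elements to the thread's two accumulator rows is an assumed layout for illustration only.

```cpp
#include <array>

// Promote one k-block's WGMMA partial sums into the running accumulator, applying
// a groupwise A scale per accumulator row and a single blockwise B scale. Because
// each thread's accumulator covers only two M rows, at most two distinct A scales
// (and one B scale) are needed per iteration.
template <int FragmentSize>
void promote_k_block(std::array<float, FragmentSize>&       acc_promoted, // running scaled result
                     std::array<float, FragmentSize> const& acc_k_block,  // this k-block's partial sums
                     float scale_a_row0,  // A scale for the M-group containing the thread's first row
                     float scale_a_row1,  // A scale for the M-group containing the thread's second row
                     float scale_b) {     // B scale for this (N-block, K-block)
  for (int i = 0; i < FragmentSize; ++i) {
    // Assumed layout: even fragment indices belong to row 0, odd indices to row 1.
    float scale_a = (i % 2 == 0) ? scale_a_row0 : scale_a_row1;
    acc_promoted[i] += scale_a * scale_b * acc_k_block[i];
  }
}
```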
Performance
I haven't observed any performance degradation compared with #1932.
blockwise scaling
groupwise scaling (this PR, setting scaling granularity along M to 64)