FP8 groupwise scaling along M
zl committed Jan 13, 2025
1 parent 902dff3 commit 9d997ce
Showing 8 changed files with 1,390 additions and 47 deletions.
@@ -131,7 +131,7 @@ using ArchTag = cutlass::arch::Sm90;    // T
 using OperatorClass = cutlass::arch::OpClassTensorOp;    // Operator class tag
 using TileShape = Shape<_128,_128,_128>;    // Threadblock-level tile size
 using ClusterShape = Shape<_1,_2,_1>;    // Shape of the threadblocks in a cluster
-using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum;
+using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<>;
 using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;

 using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
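In the example, the kernel schedule tag is now a class template parameterized by the scaling granularity along M. A minimal sketch of the two configurations, assuming ScaleGranularityM is the template parameter and that its default reproduces the old per-tile (blockwise) behavior (which is why the example above only needed a trailing <>), and that the granularity must evenly divide the tile's M extent:

  // Blockwise (default): one scale per CTA tile along M.
  using BlockwiseSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<>;

  // Groupwise: one scale per 64 rows of M, i.e. two scale groups per 128-row tile.
  using GroupwiseSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<64>;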

Large diffs are not rendered by default.

@@ -30,3 +30,8 @@ cutlass_example_add_executable(
   65_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling
   65_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu
 )
+
+cutlass_example_add_executable(
+  65_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling
+  65_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu
+)
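With this target registered, the new example builds like the neighboring ones. A sketch of the usual CUTLASS flow, assuming the target name matches the first argument of cutlass_example_add_executable and that 90a is the appropriate architecture for these Hopper-only kernels:

  mkdir build && cd build
  cmake .. -DCUTLASS_NVCC_ARCHS=90a
  make 65_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling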

Large diffs are not rendered by default.

14 changes: 7 additions & 7 deletions include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl
@@ -1008,7 +1008,7 @@ template <
   class TileShape_MNK,
   class ClusterShape_MNK,
   class StageCountType,
-  class KernelScheduleType
+  int ScaleGranularityM
 >
 struct CollectiveBuilder<
     arch::Sm90,
@@ -1023,12 +1023,12 @@ struct CollectiveBuilder<
     TileShape_MNK,
     ClusterShape_MNK,
     StageCountType,
-    KernelScheduleType,
+    KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<ScaleGranularityM>,
     cute::enable_if_t<
-      (cute::is_any_of_v<KernelScheduleType,
-                         KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum>) &&
-      not detail::is_use_rmem_A<ElementA, GmemLayoutATag, ElementB, GmemLayoutBTag>()>
+      not detail::is_use_rmem_A<ElementA, GmemLayoutATag, ElementB, GmemLayoutBTag>()>
 > {
+  using KernelScheduleType = KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<ScaleGranularityM>;

   static_assert(is_static<TileShape_MNK>::value);
   static_assert(is_static<ClusterShape_MNK>::value);
 #ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
@@ -1054,7 +1054,7 @@ struct CollectiveBuilder<
   static constexpr bool IsCooperative = cute::is_any_of_v<KernelScheduleType,
       KernelTmaWarpSpecializedCooperative,
       KernelPtrArrayTmaWarpSpecializedCooperative,
-      KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum>;
+      KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<ScaleGranularityM>>;
   using AtomLayoutMNK = cute::conditional_t<IsCooperative,
       Layout<Shape<_2,_1,_1>>, Layout<Shape<_1,_1,_1>>>;

@@ -1074,7 +1074,7 @@

   static constexpr int PipelineStages = detail::compute_stage_count_or_override<detail::sm90_smem_capacity_bytes - KernelSmemCarveout,
       ElementAMma, ElementBMma, TileShape_MNK>(StageCountType{});
-  using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8<PipelineStages, ClusterShape_MNK, KernelScheduleType>;
+  using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8<PipelineStages, ClusterShape_MNK, KernelScheduleType, ScaleGranularityM>;

   using SmemCopyAtomA = void;
   using SmemCopyAtomB = void;
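Downstream code reaches this specialization by naming the parameterized schedule directly. A sketch under the same assumptions (argument order as in the other Sm90 CollectiveBuilder specializations; the element types, layouts, and alignments here are illustrative, not taken from this commit):

  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      cutlass::float_e4m3_t, cutlass::layout::RowMajor,    16,  // A: FP8, 16-element aligned
      cutlass::float_e4m3_t, cutlass::layout::ColumnMajor, 16,  // B: FP8, 16-element aligned
      float,                                                    // accumulator element
      Shape<_128,_128,_128>, Shape<_1,_2,_1>,                   // CTA tile and cluster shapes
      cutlass::gemm::collective::StageCountAuto,
      cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<64>  // 64-row groups
    >::CollectiveOp;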
26 changes: 21 additions & 5 deletions include/cutlass/gemm/collective/fp8_accumulation.hpp
@@ -37,7 +37,7 @@
 //////////////////////////////////////////////////////////////////////////////
 ///////////////////////////////////FP8 Accumulation///////////////////////////
 //////////////////////////////////////////////////////////////////////////////
-/// This calss provides API to promote (add) or scale (multiply_add) the results
+/// This class provides API to promote (add) or scale (multiply_add) the results
 /// from the tensor core accumulators to the main accumulators when the number
 /// of MMAs reaches the max number of MMA interval specified by user, after that
 /// the tensor core accumulators are zeroed.
@@ -75,12 +75,22 @@ struct GmmaFP8Accumulation {
   }

   // `multiply` scale the partial accumulators and `add` to main accumulator (FFMA).
+  template <
+    class EngineScale,
+    class LayoutScale>
   CUTLASS_DEVICE
-  void scale_core(ElementAccumulator const& scale) {
+  void scale_core(const cute::Tensor<EngineScale, LayoutScale> &scale) {
+    using TensorScale = cute::Tensor<EngineScale, LayoutScale>;
+
+    static_assert(is_static<LayoutScale>::value, "Scale Layout should be static");
+    static_assert(is_rmem<TensorScale>::value , "Scale tensor must be rmem resident.");
+
+    static_assert(LayoutAccum{}.shape() == LayoutScale{}.shape(), "Accumulator and scale must have same shape.");
+
     warpgroup_wait<0>();
     CUTLASS_PRAGMA_UNROLL
     for (int i = 0; i < size(accum_); ++i) {
-      accum_(i) += accum_temp_(i) * scale;
+      accum_(i) += accum_temp_(i) * scale(i);
     }
   }

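The scalar scale argument becomes a register-resident cute::Tensor congruent with the accumulator fragment, which is what lets different M-groups within one tile carry different scales. A hypothetical helper (not part of this commit) showing a fragment that satisfies the three static_asserts above; the uniform fill is illustrative, while the real mainloop loads per-group values:

  template <class ElementAccumulator, class FragAccum>
  CUTLASS_DEVICE auto
  make_uniform_scale_fragment(FragAccum const& accum, ElementAccumulator group_scale) {
    // Owning register tensor with the accumulator fragment's static layout:
    // rmem resident, static layout, same shape -- all three asserts hold.
    auto scale = cute::make_tensor<ElementAccumulator>(accum.layout());
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < cute::size(scale); ++i) {
      scale(i) = group_scale;  // groupwise code would vary this per M-group
    }
    return scale;
  }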
@@ -142,8 +152,11 @@ struct GmmaFP8Accumulation {
   //

   /// scale (multiply_add) the results from the MMA accumulators to main accumulator if needed.
+  template <
+    class EngineScale,
+    class LayoutScale>
   CUTLASS_DEVICE
-  void scale_if_needed(ElementAccumulator const& scale) {
+  void scale_if_needed(const cute::Tensor<EngineScale, LayoutScale> &scale) {
     mma_count_ += mma_count_per_mainloop_iteration_;
     reset_accum_flag_ = __shfl_sync(0xffffffff, mma_count_ == accum_promotion_interval_, 0);
     if (reset_accum_flag_) {
@@ -153,8 +166,11 @@ struct GmmaFP8Accumulation {
   }

   /// scale (multiply_add) the residue results from the MMA accumulators to main accumulator if needed.
+  template <
+    class EngineScale,
+    class LayoutScale>
   CUTLASS_DEVICE
-  void scale_residue_if_needed(ElementAccumulator const& scale) {
+  void scale_residue_if_needed(const cute::Tensor<EngineScale, LayoutScale> &scale) {
     if (__shfl_sync(0xffffffff, mma_count_ > 0, 0)) {
       scale_core(scale);
     }
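Taken together, the intended call pattern in a mainloop is roughly the following sketch (variable names and the constructor arguments are hypothetical; the real loop lives in the blockwise/groupwise collective mainloop, one of the files not rendered above):

  // One accumulation helper per output fragment; the promotion interval and
  // MMAs-per-iteration come from the dispatch policy.
  GmmaFP8Accumulation accumulation(accum, AccumPromotionInterval, size<2>(tCrA));
  for (int k_tile = 0; k_tile < k_tile_count; ++k_tile) {
    // ... issue the WGMMAs for this k-tile into the partial accumulators ...
    accumulation.scale_if_needed(tCrScale);        // tCrScale: per-element scale fragment
  }
  accumulation.scale_residue_if_needed(tCrScale);  // flush any partial interval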
