FP8 groupwise scaling along M
zl committed Jan 21, 2025
1 parent b78588d commit df73dd0
Showing 8 changed files with 1,418 additions and 48 deletions.
@@ -123,7 +123,7 @@ using ArchTag = cutlass::arch::Sm90; // T
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
using TileShape = Shape<_128,_128,_128>; // Threadblock-level tile size
using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum;
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<>;
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;

using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
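The `<>` added above turns the FP8 block-scaled schedule into a template. The snippet below is an illustrative sketch, not part of this commit, of how the new parameter could be used to request groupwise scaling along M instead of one scale per tile; the granularity value 64 is assumed for illustration, and the aliases mirror the example configuration shown above (same includes and namespaces as the example file):

using TileShape        = Shape<_128,_128,_128>;   // Threadblock-level tile size
using ClusterShape     = Shape<_1,_2,_1>;         // Shape of the threadblocks in a cluster
// One scale value per 64 rows of the 128-row tile, i.e. two scales per tile along M.
// Leaving the argument empty (<>) is expected to fall back to one scale per tile,
// matching the original blockwise behaviour.
using KernelSchedule   = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<64>;
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;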

Large diffs are not rendered by default.

@@ -30,3 +30,8 @@ cutlass_example_add_executable(
67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling
67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu
)

cutlass_example_add_executable(
67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling
67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu
)

Large diffs are not rendered by default.

45 changes: 36 additions & 9 deletions include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl
@@ -84,6 +84,28 @@ compute_stage_count_or_override(StageCountAutoCarveout<carveout_bytes_> stage_co
return (capacity_bytes - carveout_bytes) / stage_bytes;
}

// Returns the maximum number of smem tiles that can be used with a given smem capacity in a GEMM with blockwise/groupwise scaling.
template<int capacity_bytes_, class ElementA, class ElementB, class ElementBlockScale, class TileShapeMNK, int ScaleMsPerTile, int carveout_bytes_, int alignment = 128>
constexpr int
compute_stage_count_with_blockwise_scale(StageCountAutoCarveout<carveout_bytes_> stage_count) {
constexpr auto mainloop_pipeline_bytes = sizeof(typename cutlass::PipelineTmaAsync<1>::SharedStorage);
constexpr auto a_bits = cute::sizeof_bits_v<ElementA>;
constexpr auto b_bits = cute::sizeof_bits_v<ElementB>;
constexpr auto scale_bits = cute::sizeof_bits_v<ElementBlockScale>;
constexpr int stage_bytes_ =
cutlass::bits_to_bytes(a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) +
cutlass::bits_to_bytes(b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) +
cutlass::bits_to_bytes(scale_bits * ScaleMsPerTile) + // scale of tensor A
cutlass::bits_to_bytes(scale_bits * 1); // scale of tensor B

constexpr int stage_bytes = cutlass::round_up(stage_bytes_, alignment) +
static_cast<int>(mainloop_pipeline_bytes);
constexpr int carveout_bytes = cutlass::round_up(carveout_bytes_, alignment);
constexpr int capacity_bytes = capacity_bytes_ / alignment * alignment;

return (capacity_bytes - carveout_bytes) / stage_bytes;
}

// Returns the maximum number of smem tiles that can be used with a given smem capacity (with an optional scale matrix), or overrides with manual count.
template<int capacity_bytes, class ElementA, class ElementB, class ElementScale, class ElementZero, class TileShapeMNK, int stages, int alignment = 128>
constexpr int
@@ -1009,7 +1031,7 @@ template <
class TileShape_MNK,
class ClusterShape_MNK,
class StageCountType,
class KernelScheduleType
int ScaleGranularityM_
>
struct CollectiveBuilder<
arch::Sm90,
@@ -1024,12 +1046,12 @@ struct CollectiveBuilder<
TileShape_MNK,
ClusterShape_MNK,
StageCountType,
KernelScheduleType,
KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<ScaleGranularityM_>,
cute::enable_if_t<
(cute::is_any_of_v<KernelScheduleType,
KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum>) &&
not detail::is_use_rmem_A<ElementA, GmemLayoutATag, ElementB, GmemLayoutBTag>()>
not detail::is_use_rmem_A<ElementA, GmemLayoutATag, ElementB, GmemLayoutBTag>()>
> {
using KernelScheduleType = KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<ScaleGranularityM_>;

static_assert(is_static<TileShape_MNK>::value);
static_assert(is_static<ClusterShape_MNK>::value);
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
@@ -1048,14 +1070,15 @@ struct CollectiveBuilder<
// For fp32 types, map to tf32 MMA value type
using ElementAMma = cute::conditional_t<cute::is_same_v<ElementA, float>, tfloat32_t, ElementA>;
using ElementBMma = cute::conditional_t<cute::is_same_v<ElementB, float>, tfloat32_t, ElementB>;
using ElementBlockScale = ElementAccumulator;

static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A<ElementAMma, GmemLayoutATag>();
static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B<ElementBMma, GmemLayoutBTag>();

static constexpr bool IsCooperative = cute::is_any_of_v<KernelScheduleType,
KernelTmaWarpSpecializedCooperative,
KernelPtrArrayTmaWarpSpecializedCooperative,
KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum>;
KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<ScaleGranularityM_>>;
using AtomLayoutMNK = cute::conditional_t<IsCooperative,
Layout<Shape<_2,_1,_1>>, Layout<Shape<_1,_1,_1>>>;

@@ -1073,9 +1096,13 @@
static constexpr size_t TensorMapStorage = IsArrayOfPointersGemm ? sizeof(cute::TmaDescriptor) * 2 /* for A and B */ : 0;
static constexpr int KernelSmemCarveout = static_cast<int>(TensorMapStorage);

static constexpr int PipelineStages = detail::compute_stage_count_or_override<detail::sm90_smem_capacity_bytes - KernelSmemCarveout,
ElementAMma, ElementBMma, TileShape_MNK>(StageCountType{});
using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8<PipelineStages, ClusterShape_MNK, KernelScheduleType>;
static constexpr int ScaleGranularityM = ScaleGranularityM_ == 0 ? size<0>(TileShape_MNK{}) : ScaleGranularityM_;
static constexpr int ScaleMsPerTile = size<0>(TileShape_MNK{}) / ScaleGranularityM;
static_assert((size<0>(TileShape_MNK{}) % ScaleGranularityM) == 0, "FP8 scaling granularity must evenly divide tile shape along M.");

static constexpr int PipelineStages = detail::compute_stage_count_with_blockwise_scale<detail::sm90_smem_capacity_bytes - KernelSmemCarveout,
ElementAMma, ElementBMma, ElementBlockScale, TileShape_MNK, ScaleMsPerTile>(StageCountType{});
using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8<PipelineStages, ClusterShape_MNK, KernelScheduleType, ScaleGranularityM_>;

using SmemCopyAtomA = void;
using SmemCopyAtomB = void;
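To make the builder's new granularity arithmetic concrete, here is a small self-contained sketch (not CUTLASS code; the names are invented for illustration) of how ScaleGranularityM_ maps to ScaleMsPerTile for a 128x128x128 tile:

// Standalone illustration of the granularity arithmetic added to the builder.
// A granularity of 0 defaults to the full tile extent along M, i.e. one scale
// per tile (the original blockwise behaviour).
constexpr int TileM = 128;  // stands in for size<0>(TileShape_MNK{})

constexpr int scale_ms_per_tile(int scale_granularity_m) {
  int const g = (scale_granularity_m == 0) ? TileM : scale_granularity_m;
  return TileM / g;
}

static_assert(scale_ms_per_tile(0)  == 1,   "blockwise: one scale covers the whole tile");
static_assert(scale_ms_per_tile(64) == 2,   "two scale values per tile along M");
static_assert(scale_ms_per_tile(1)  == 128, "per-row scaling");
// A granularity that does not divide the tile's M extent (e.g. 96) is rejected
// by the static_assert added in the builder above.

Each of those ScaleMsPerTile values costs sizeof(ElementBlockScale) bytes of shared memory per stage for the A-tensor scales, plus one value for B, which is why the builder now sizes the pipeline with compute_stage_count_with_blockwise_scale instead of compute_stage_count_or_override.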
24 changes: 20 additions & 4 deletions include/cutlass/gemm/collective/fp8_accumulation.hpp
@@ -75,12 +75,22 @@ struct GmmaFP8Accumulation {
}

// `multiply` scale the partial accumulators and `add` to main accumulator (FFMA).
template <
class EngineScale,
class LayoutScale>
CUTLASS_DEVICE
void scale_core(ElementAccumulator const& scale) {
void scale_core(const cute::Tensor<EngineScale, LayoutScale> &scale) {
using TensorScale = cute::Tensor<EngineScale, LayoutScale>;

static_assert(is_static<LayoutScale>::value, "Scale Layout should be static");
static_assert(is_rmem<TensorScale>::value, "Scale tensor must be rmem resident.");

static_assert(LayoutAccum{}.shape() == LayoutScale{}.shape(), "Accumulator and scale must have same shape.");

warpgroup_wait<0>();
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(accum_); ++i) {
accum_(i) += accum_temp_(i) * scale;
accum_(i) += accum_temp_(i) * scale(i);
}
}

@@ -142,8 +152,11 @@ struct GmmaFP8Accumulation {
//

/// scale (multiply_add) the results from the MMA accumulators to main accumulator if needed.
template <
class EngineScale,
class LayoutScale>
CUTLASS_DEVICE
void scale_if_needed(ElementAccumulator const& scale) {
void scale_if_needed(const cute::Tensor<EngineScale, LayoutScale> &scale) {
mma_count_ += mma_count_per_mainloop_iteration_;
reset_accum_flag_ = __shfl_sync(0xffffffff, mma_count_ == accum_promotion_interval_, 0);
if (reset_accum_flag_) {
@@ -153,8 +166,11 @@
}

/// scale (multiply_add) the residue results from the MMA accumulators to main accumulator if needed.
template <
class EngineScale,
class LayoutScale>
CUTLASS_DEVICE
void scale_residue_if_needed(ElementAccumulator const& scale) {
void scale_residue_if_needed(const cute::Tensor<EngineScale, LayoutScale> &scale) {
if (__shfl_sync(0xffffffff, mma_count_ > 0, 0)) {
scale_core(scale);
}
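The change above replaces the scalar scale in GmmaFP8Accumulation with a register-resident tensor of scales, one per accumulator element, so rows that fall into different scale groups along M are promoted with different factors. Below is a minimal standalone analogue of that interface change, not the CUTLASS type; the array-based fragment and names are invented for illustration:

#include <array>
#include <cstddef>

// Toy analogue of the scalar -> per-element scale change in GmmaFP8Accumulation.
template <std::size_t N>
struct ToyFP8Accumulation {
  std::array<float, N> accum{};       // main accumulator, kept in higher precision
  std::array<float, N> accum_temp{};  // partial accumulator produced by the FP8 MMAs

  // Before this commit: a single scale promotes the whole fragment.
  void promote(float scale) {
    for (std::size_t i = 0; i < N; ++i) { accum[i] += accum_temp[i] * scale; }
  }

  // After this commit: one scale per accumulator element, so elements belonging to
  // different groups along M can use different scale factors.
  void promote(std::array<float, N> const& scale) {
    for (std::size_t i = 0; i < N; ++i) { accum[i] += accum_temp[i] * scale[i]; }
  }
};

In the real kernel the scale tensor must live in registers and share the accumulator's static shape, which is what the new static_asserts in scale_core enforce.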
