Add int4b_t/uint4b_t support for mixed dtypes GEMM
Aleksandar Samardžić authored and alexsamardzic committed Jul 13, 2024
1 parent 56b46e2 commit b96bd61
Showing 5 changed files with 282 additions and 22 deletions.
36 changes: 25 additions & 11 deletions include/cutlass/gemm/threadblock/mma_multistage.h
@@ -181,6 +181,15 @@ class MmaMultistage :
/// Pair of B fragments used to overlap shared memory loads and math instructions
WarpLoadedFragmentB warp_loaded_frag_B_[2];
WarpTransformedFragmentB warp_transformed_frag_B_[2];

using ElementA = typename WarpLoadedFragmentA::Element;
using ElementB = typename WarpLoadedFragmentB::Element;
static constexpr size_t sizeof_bits_A =
cutlass::sizeof_bits<ElementA>::value;
static constexpr size_t sizeof_bits_B =
cutlass::sizeof_bits<ElementB>::value;
static constexpr bool is_mixed_and_B_4bit =
(sizeof_bits_A != sizeof_bits_B) && (sizeof_bits_B == 4);
};
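For illustration, with hypothetical element pairs, the new constant resolves as follows (a sketch, not part of the diff):

// Illustration only, with hypothetical element pairs:
//   ElementA = half_t (16 bits), ElementB = int4b_t/uint4b_t (4 bits) -> is_mixed_and_B_4bit == true
//   ElementA = half_t (16 bits), ElementB = int8_t/uint8_t   (8 bits) -> is_mixed_and_B_4bit == false
//   ElementA = half_t (16 bits), ElementB = half_t          (16 bits) -> is_mixed_and_B_4bit == false (not mixed)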


@@ -254,7 +263,7 @@ class MmaMultistage :
if (smem_read_stage_idx_ == Base::kStages) {
// Wrap back around to the 'start' of the circular buffer in shared memory
this->warp_tile_iterator_A_.add_tile_offset({0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
this->warp_tile_iterator_B_.add_tile_offset({-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, 0});
this->warp_tile_iterator_B_.add_tile_offset({-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations / (PipeState::is_mixed_and_B_4bit ? 2 : 1), 0});
smem_read_stage_idx_ = 0;
}
}
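The wrap-back offset for B is halved in the 4-bit case because, as the changes further below show, the B warp-tile iterator is advanced only on every other warp_mma_k. A minimal arithmetic sketch, assuming hypothetical values:

// Hypothetical values, for illustration only.
constexpr int kStages = 3;
constexpr int kPartitionsK = 1;
constexpr int kWarpGemmIterations = 8;
// 16-bit A x 16-bit B: the B iterator is incremented once per warp_mma_k, so it
// has moved kStages * kPartitionsK * kWarpGemmIterations tiles when it wraps.
constexpr int wrap_back_b_16bit = -kStages * kPartitionsK * kWarpGemmIterations;      // -24
// 16-bit A x 4-bit B: the B iterator is incremented only every other warp_mma_k.
constexpr int wrap_back_b_4bit  = -kStages * kPartitionsK * kWarpGemmIterations / 2;  // -12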
@@ -510,25 +519,31 @@ class MmaMultistage :
++this->warp_tile_iterator_A_;

// Load the next warp-tile's B fragment from shared memory
this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[(warp_mma_k + 1) % 2]);
++this->warp_tile_iterator_B_;
if constexpr (!PipeState::is_mixed_and_B_4bit) {
this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[(warp_mma_k + 1) % 2]);
++this->warp_tile_iterator_B_;
} else if ((warp_mma_k + 1) % 2 == 0) {
this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k / 2 + 1) % Base::kWarpGemmIterations);
this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[(warp_mma_k / 2 + 1) % 2]);
++this->warp_tile_iterator_B_;
}

// Except for the first warp-tile, all warp-tiles convert their incoming shared memory fragments as necessary
if (warp_mma_k > 0) {
warp_mma_.transform(
pipe_state.warp_transformed_frag_A_[warp_mma_k % 2],
pipe_state.warp_transformed_frag_B_[warp_mma_k % 2],
pipe_state.warp_transformed_frag_B_[PipeState::is_mixed_and_B_4bit ? (warp_mma_k / 2) % 2 : warp_mma_k % 2],
pipe_state.warp_loaded_frag_A_[warp_mma_k % 2],
pipe_state.warp_loaded_frag_B_[warp_mma_k % 2]);
pipe_state.warp_loaded_frag_B_[PipeState::is_mixed_and_B_4bit ? (warp_mma_k / 2) % 2 : warp_mma_k % 2]);
}

// Execute the current warp-tile of MMA operations
if (Detail::kStagedAccumulation) {
warp_mma_(
pipe_state.tmp_accum_,
pipe_state.warp_transformed_frag_A_[warp_mma_k % 2],
pipe_state.warp_transformed_frag_B_[warp_mma_k % 2],
pipe_state.warp_transformed_frag_B_[PipeState::is_mixed_and_B_4bit ? (warp_mma_k / 2) % 2 : warp_mma_k % 2],
pipe_state.tmp_accum_
);

@@ -541,7 +556,7 @@
warp_mma_(
accum,
pipe_state.warp_transformed_frag_A_[warp_mma_k % 2],
pipe_state.warp_transformed_frag_B_[warp_mma_k % 2],
pipe_state.warp_transformed_frag_B_[PipeState::is_mixed_and_B_4bit ? (warp_mma_k / 2) % 2 : warp_mma_k % 2],
accum
);
}
@@ -596,12 +611,11 @@ class MmaMultistage :
// the first warp-tile of the next iteration, if necessary (so we can
// immediately start issuing MMA instructions at the top of the loop )
if (warp_mma_k + 1 == Base::kWarpGemmIterations) {

warp_mma_.transform(
pipe_state.warp_transformed_frag_A_[(warp_mma_k + 1) % 2],
pipe_state.warp_transformed_frag_B_[(warp_mma_k + 1) % 2],
pipe_state.warp_transformed_frag_B_[PipeState::is_mixed_and_B_4bit ? (warp_mma_k / 2 + 1) % 2 : (warp_mma_k + 1) % 2],
pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2],
pipe_state.warp_loaded_frag_B_[(warp_mma_k + 1) % 2]);
pipe_state.warp_loaded_frag_B_[PipeState::is_mixed_and_B_4bit ? (warp_mma_k / 2 + 1) % 2 : (warp_mma_k + 1) % 2]);
}

}
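Taken together, the changes above give the 4-bit B path the following per-stage cadence. A sketch, assuming a hypothetical Base::kWarpGemmIterations of 8 (the real value depends on the warp and instruction tile shapes):

#include <cstdio>

int main() {
  constexpr int kWarpGemmIterations = 8;  // hypothetical value
  for (int warp_mma_k = 0; warp_mma_k < kWarpGemmIterations; ++warp_mma_k) {
    int frag_used = (warp_mma_k / 2) % 2;          // B fragment consumed by the MMA
    bool loads_B = ((warp_mma_k + 1) % 2 == 0);    // B is loaded only every other step
    int frag_loaded = (warp_mma_k / 2 + 1) % 2;    // B fragment being refilled
    std::printf("warp_mma_k=%d uses B frag %d%s\n", warp_mma_k, frag_used,
                loads_B ? (frag_loaded ? ", loads B frag 1" : ", loads B frag 0") : "");
  }
  return 0;
}

Each loaded 4-bit B fragment covers twice the K-depth of a single F16/BF16 mma.sync, which is why one fragment feeds two consecutive warp_mma_k steps before it is replaced.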
3 changes: 2 additions & 1 deletion include/cutlass/gemm/warp/default_mma_tensor_op_sm80.h
@@ -229,7 +229,8 @@ struct DefaultMmaTensorOp<
/////////////////////////////////////////////////////////////////////////////////////////////////

/// Partial Specialization - inputs are mixed types - uses wider datatype internally.
/// (e.g. F16 <= F16 x S8 + F16, F16 <= BF16 x S8 + F32)
/// (e.g. F16 <= F16 x S8 + F16, F16 <= BF16 x S8 + F32,
/// or F16 <= F16 x S4 + F16, F16 <= BF16 x S4 + F32)
template <
/// Shape of one matrix production operation (concept: GemmShape)
typename WarpShape_,
167 changes: 157 additions & 10 deletions include/cutlass/gemm/warp/mma_mixed_input_tensor_op.h
@@ -264,6 +264,117 @@ struct FragmentShuffler <ElementMma_, ElementLoad_,
return result;
}

};
////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for `mma.sync` on 16b (F16/BF16) and `ldmatrix` on 4b (S4/U4)
/// for operand B multiplicand going through upcasting.
template <
/// Element type for the operand in registers for the mma.sync
typename ElementMma_,
/// Element type for the operand in shared memory for ldmatrix
typename ElementLoad_,
/// Number of mma.sync operations performed along rows or columns
int NumMmaInstructions,
/// Number of elements in warp fragment
int NumElementsInWarpFragment,
/// Number of elements in mma fragment
int NumElementsInMmaFragment
>
struct FragmentShuffler <ElementMma_, ElementLoad_,
NumMmaInstructions,
NumElementsInWarpFragment,
NumElementsInMmaFragment,
Operand::kB,
typename platform::enable_if<(sizeof_bits<ElementMma_>::value == 16) &&
(sizeof_bits<ElementLoad_>::value == 4)>::type> {
public:
using ElementMma = ElementMma_;
using ElementLoad = ElementLoad_;

static int const kNumMmaInstructions = NumMmaInstructions;
static int const kNumElementsInWarpFragment = NumElementsInWarpFragment;
static int const kNumElementsInMmaFragment = NumElementsInMmaFragment;
static Operand const kOperand = Operand::kB;

using WarpFragment = Array<ElementLoad, kNumElementsInWarpFragment>;
using MmaFragment = Array<ElementLoad, kNumElementsInMmaFragment>;

private:
int src_lane_0_, src_lane_1_;
uint32_t byte_selector_0_, byte_selector_10_, byte_selector_11_;
int dst_incr_0_, dst_incr_1_;

public:
CUTLASS_DEVICE
FragmentShuffler() {
int lane_id = cutlass::arch::LaneId();
int mul;

src_lane_0_ = lane_id ^ 1;
mul = lane_id & 1;
byte_selector_0_ = mul * 0x3715 + (1 - mul) * 0x6240;

src_lane_1_ = lane_id ^ 2;
mul = (lane_id & 2) >> 1;
byte_selector_10_ = mul * 0x7632 + (1 - mul) * 0x5410;
byte_selector_11_ = mul * 0x5410 + (1 - mul) * 0x7632;
dst_incr_0_ = mul * (WarpFragment::kElements / 16);
dst_incr_1_ = (1 - mul) * (WarpFragment::kElements / 16);
}

CUTLASS_DEVICE
WarpFragment operator()(WarpFragment const &src) {

WarpFragment result;

MmaFragment const* mma_frag_src_ptr = reinterpret_cast<MmaFragment const *>(&src);
MmaFragment* mma_frag_dst_ptr = reinterpret_cast<MmaFragment *>(&result);

uint32_t const* src_ptr = reinterpret_cast<uint32_t const *>(&mma_frag_src_ptr[0]);
uint32_t* dst_ptr = reinterpret_cast<uint32_t *>(&mma_frag_dst_ptr[0]);

// This code assumes that twice as many values as needed for a
// single F16/BF16 MMA are loaded along the contiguous dimension.
// E.g. in the case of a column-major matrix: threads 0-3 would hold
// 32 elements of the first column in the warp fragment, threads 4-7
// 32 elements of the second column, etc.; but only the first 16
// elements of each column will be used for the first MMA operation,
// and the last 16 elements will be used for the follow-up MMA
// operation. This code distributes the input values across threads
// so that, in the corresponding warp fragments, the left (for a
// row-major matrix) or upper (for a column-major matrix) half of
// the values comes first, and the right/lower half comes second.
// The values are also re-distributed between threads so that each
// value ends up on the proper thread for the F16/BF16 MMA that
// takes place after the up-casting.

CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < WarpFragment::kElements / 16; n++) {
// Exchange values with the neighboring thread that loaded values
// from the same half of all values as this thread; then combine
// the values so that the final values can be produced after
// another exchange.
uint32_t tmp0 = __shfl_sync(0xFFFFFFFF, src_ptr[2 * n], src_lane_0_);
uint32_t tmp1 = __shfl_sync(0xFFFFFFFF, src_ptr[2 * n + 1], src_lane_0_);
tmp0 = __byte_perm(src_ptr[2 * n], tmp0, byte_selector_0_);
tmp1 = __byte_perm(src_ptr[2 * n + 1], tmp1, byte_selector_0_);

// Exchange values with the corresponding thread from the same
// quadruple as this thread, but that loaded values from the other
// half of all values. Then combine the values to produce the
// final values held by this thread.
uint32_t mine = __byte_perm(tmp0, tmp1, byte_selector_10_);
uint32_t theirs = __byte_perm(tmp0, tmp1, byte_selector_11_);
theirs = __shfl_sync(0xFFFFFFFF, theirs, src_lane_1_);
dst_ptr[n + dst_incr_0_] = mine;
dst_ptr[n + dst_incr_1_] = theirs;
}

return result;
}

};
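For reference, __byte_perm(a, b, s) treats {a, b} as an 8-byte pool, bytes 0-3 from a and bytes 4-7 from b, and each of the four low nibbles of s, starting from the least significant, selects the source byte for one result byte. The selectors used by this shuffler therefore pick bytes as sketched below (result bytes listed from least to most significant); this is just a reading of the documented intrinsic, not part of the diff:

#include <cstdint>

__device__ void byte_perm_selectors_reference(uint32_t a, uint32_t b) {
  uint32_t lo_halfwords = __byte_perm(a, b, 0x5410);  // {a.b0, a.b1, b.b0, b.b1}
  uint32_t hi_halfwords = __byte_perm(a, b, 0x7632);  // {a.b2, a.b3, b.b2, b.b3}
  uint32_t even_bytes   = __byte_perm(a, b, 0x6240);  // {a.b0, b.b0, a.b2, b.b2}
  uint32_t odd_bytes    = __byte_perm(a, b, 0x3715);  // {b.b1, a.b1, b.b3, a.b3}
  (void)lo_halfwords; (void)hi_halfwords; (void)even_bytes; (void)odd_bytes;
}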

////////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -412,10 +523,21 @@ class MmaMixedInputTensorOp {

public:

// Chosen so we get K=16 for int8 and K=32 for int4.
static constexpr int LoadInstructionM =
(sizeof_bits<ElementB>::value > sizeof_bits<ElementA>::value)
? 8 * sizeof_bits<ElementB>::value / sizeof_bits<ElementA>::value
: InstructionShape::kM;

// Shape for loading the data type from shared memory, accounting,
// when applicable, for a narrower ElementA.
using LoadInstructionShapeA =
GemmShape<LoadInstructionM, InstructionShape::kN, InstructionShape::kK>;

/// Iterates over the A operand in Shared Memory
using IteratorA = MmaTensorOpMultiplicandTileIterator<
MatrixShape<Shape::kM, Shape::kK>, Operand::kA, ElementA, LayoutA,
MatrixShape<ArchMmaOperator::Shape::kM, ArchMmaOperator::Shape::kK>,
MatrixShape<LoadInstructionShapeA::kM, LoadInstructionShapeA::kK>,
Policy::OpDelta::kRow, kThreadCount, kPartitionsK>;

/// Storage for A tile in registers (loaded from Shared Memory)
Expand All @@ -428,10 +550,21 @@ class MmaMixedInputTensorOp {
/// Underlying arch::Mma instruction operand fragement for matrix A
using MmaOperandA = typename ArchMmaOperator::FragmentA;

// Chosen so we get K=16 for int8 and K=32 for int4.
static constexpr int LoadInstructionK =
(sizeof_bits<ElementA>::value > sizeof_bits<ElementB>::value)
? 8 * sizeof_bits<ElementA>::value / sizeof_bits<ElementB>::value
: InstructionShape::kK;

// Shape for loading the data type from shared memory, accounting,
// when applicable, for a narrower ElementB.
using LoadInstructionShapeB =
GemmShape<InstructionShape::kM, InstructionShape::kN, LoadInstructionK>;

/// Iterates over the B operand in Shared Memory
using IteratorB = MmaTensorOpMultiplicandTileIterator<
MatrixShape<Shape::kK, Shape::kN>, Operand::kB, ElementB, LayoutB,
MatrixShape<ArchMmaOperator::Shape::kK, ArchMmaOperator::Shape::kN>,
MatrixShape<LoadInstructionShapeB::kK, LoadInstructionShapeB::kN>,
Policy::OpDelta::kRow, kThreadCount, kPartitionsK>;

/// Storage for B tile in registers (loaded from Shared Memory)
Expand Down Expand Up @@ -492,6 +625,13 @@ class MmaMixedInputTensorOp {
MmaOperandB const *ptr_B = reinterpret_cast<MmaOperandB const *>(&B);
MmaOperandC *ptr_D = reinterpret_cast<MmaOperandC *>(&D);

if constexpr (is_B_4bit) {
if (!transform_B_flag_) {
ptr_B += TransformedFragmentB::kElements / 2 / MmaOperandB::kElements;
}
transform_B_flag_ = !transform_B_flag_;
}

CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < MmaIterations::kRow; ++m) {

@@ -522,15 +662,17 @@
void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B,
FragmentA const &A, FragmentB const &B) const {

// Shuffle data within warp to obtain the mma.sync operand layout
detail::FragmentShuffler<ElementBMma, ElementB, MmaIterations::kColumn,
FragmentB::kElements, MmaOperandB::kElements, Operand::kB> shuffler_B;
FragmentB tmp_B;
tmp_B = shuffler_B(B);
if (!is_B_4bit || transform_B_flag_) {
// Shuffle data within warp to obtain the mma.sync operand layout
detail::FragmentShuffler<ElementBMma, ElementB, MmaIterations::kColumn,
FragmentB::kElements, MmaOperandB::kElements, Operand::kB> shuffler_B;
FragmentB tmp_B;
tmp_B = shuffler_B(B);

// Convert the B operand to the Mma Instruction operand type
detail::FragmentConverter<ElementBMma, ElementB, FragmentB::kElements> convert_B;
dst_B = convert_B(tmp_B);
// Convert the B operand to the Mma Instruction operand type
detail::FragmentConverter<ElementBMma, ElementB, FragmentB::kElements> convert_B;
dst_B = convert_B(tmp_B);
}

FragmentA tmp_A;

Expand All @@ -553,6 +695,11 @@ class MmaMixedInputTensorOp {

ptr_dst_A[1] = convert_A(ptr_tmp_A[1]);
}

private:
static constexpr bool is_B_4bit = cutlass::sizeof_bits<ElementB>::value == 4;
static_assert(!is_B_4bit || FragmentB::kElements % 16 == 0);
mutable bool transform_B_flag_ = true;
};
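The transform_B_flag_ logic above implies the following alternating call pattern for the 4-bit B path (a sketch inferred from the code, not a statement of the public contract):

// transform(#1): transform_B_flag_ == true  -> shuffles and up-converts the full,
//                double-K-depth B fragment into dst_B.
// operator()(#1): consumes the first half of the transformed B fragment and
//                 toggles the flag to false.
// transform(#2): transform_B_flag_ == false -> leaves dst_B untouched.
// operator()(#2): advances ptr_B by TransformedFragmentB::kElements / 2 /
//                 MmaOperandB::kElements, i.e. consumes the second half, and
//                 toggles the flag back to true; the pattern then repeats.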

/////////////////////////////////////////////////////////////////////////////////////////////////
1 change: 1 addition & 0 deletions test/unit/gemm/device/CMakeLists.txt
@@ -259,6 +259,7 @@ cutlass_test_unit_add_executable(
gemm_universal_f16t_u8n_f16t_mixed_input_tensor_op_f16_sm80.cu
gemm_universal_f16t_s8n_f16t_mixed_input_tensor_op_f16_sm80.cu
gemm_universal_bf16t_s8n_bf16t_mixed_input_tensor_op_f32_sm80.cu
gemm_universal_f16t_s4n_f16t_mixed_input_tensor_op_f16_sm80.cu
)
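The new unit test exercises the F16 x S4 path end to end. Below is a hypothetical instantiation, modeled on the neighboring f16t_s8n mixed-input tests; the tile shapes, stage count, alignments, and epilogue are assumptions, not taken from the new test file:

#include "cutlass/numeric_types.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/arch/arch.h"
#include "cutlass/arch/mma.h"
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/epilogue/thread/linear_combination.h"

using GemmF16xS4 = cutlass::gemm::device::GemmUniversal<
    cutlass::half_t, cutlass::layout::RowMajor,        // A: F16, "t"
    cutlass::int4b_t, cutlass::layout::ColumnMajor,    // B: S4, "n", up-cast to F16 in registers
    cutlass::half_t, cutlass::layout::RowMajor,        // C/D: F16
    cutlass::half_t,                                   // accumulator
    cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 128, 64>,            // threadblock tile (assumed)
    cutlass::gemm::GemmShape<64, 64, 64>,              // warp tile (assumed)
    cutlass::gemm::GemmShape<16, 8, 16>,               // F16 tensor-op instruction
    cutlass::epilogue::thread::LinearCombination<
        cutlass::half_t, 8, cutlass::half_t, cutlass::half_t>,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    4,                                                 // stages (assumed)
    8,                                                 // A alignment: 128 bits / 16 bits
    32,                                                // B alignment: 128 bits / 4 bits
    cutlass::arch::OpMultiplyAddMixedInputUpcast>;

As in the existing S8 tests, B stays packed as 4-bit values in global and shared memory and is up-converted to F16 in registers just before the tensor-core MMA.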

cutlass_test_unit_add_executable(
(The diff for the fifth changed file, the new unit test source listed in CMakeLists.txt above, is not rendered.)
