Add support for mixed 4-bit/8-bit data types GEMM #1413

Merged
2 commits merged on Aug 30, 2024
54 changes: 54 additions & 0 deletions include/cutlass/gemm/device/default_gemm_configuration.h
@@ -793,6 +793,60 @@ struct DefaultGemmConfigurationSm89F8 {
using Operator = arch::OpMultiplyAdd;
};

////////////////////////////////////////////////////////////////////////////////

template <
typename ElementC>
struct DefaultGemmConfiguration<
arch::OpClassTensorOp,
arch::Sm80,
int4b_t,
int8_t,
ElementC,
int32_t> {

static int const kAlignmentA = 128 / sizeof_bits<int4b_t>::value;
static int const kAlignmentB = 128 / sizeof_bits<int8_t>::value;

using ThreadblockShape = GemmShape<128, 256, 64>;
using WarpShape = GemmShape<64, 64, 64>;
using InstructionShape = GemmShape<16, 8, 32>;
static int const kStages = 3;

using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp<
ElementC, 128 / sizeof_bits<ElementC>::value, int32_t, float>;

using Operator = arch::OpMultiplyAddSaturate;
};

////////////////////////////////////////////////////////////////////////////////

template <
typename ElementC>
struct DefaultGemmConfiguration<
arch::OpClassTensorOp,
arch::Sm80,
int8_t,
int4b_t,
ElementC,
int32_t> {

static int const kAlignmentA = 128 / sizeof_bits<int8_t>::value;
static int const kAlignmentB = 128 / sizeof_bits<int4b_t>::value;

using ThreadblockShape = GemmShape<128, 256, 64>;
using WarpShape = GemmShape<64, 64, 64>;
using InstructionShape = GemmShape<16, 8, 32>;
static int const kStages = 3;

using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp<
ElementC, 128 / sizeof_bits<ElementC>::value, int32_t, float>;

using Operator = arch::OpMultiplyAddSaturate;
};

////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for SM89 fe4m3 x fe4m3
template <typename ElementC, typename ElementAccumulator>
struct DefaultGemmConfiguration<
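Aside (not part of the diff): a minimal compile-time sketch of what the new S4 x S8 defaults resolve to. The choice of ElementC = int8_t here is an assumption for illustration; the PR leaves ElementC as a free template parameter.

#include "cutlass/arch/arch.h"
#include "cutlass/arch/mma.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"
#include "cutlass/numeric_types.h"

// Query the specialization added above for int4b_t A, int8_t B, int32_t accumulator.
using Config = cutlass::gemm::device::DefaultGemmConfiguration<
    cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
    cutlass::int4b_t, int8_t, int8_t, int32_t>;

static_assert(Config::kAlignmentA == 32, "A is accessed in 128-bit chunks: 32 x 4-bit elements");
static_assert(Config::kAlignmentB == 16, "B is accessed in 128-bit chunks: 16 x 8-bit elements");
static_assert(Config::kStages == 3, "three-stage software pipeline by default on SM80");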
71 changes: 70 additions & 1 deletion include/cutlass/gemm/warp/default_mma_tensor_op_sm80.h
@@ -268,7 +268,7 @@ struct DefaultMmaTensorOp<
"DefaultMmaTensorOp with arch::OpMultiplyAddMixedInputUpcast ElementA and ElementB cannot be of the same data type");

// Data type used for internal computation - use the wider of the two data types for mma.sync operands
using ElementOperand = typename platform::conditional<(sizeof(ElementA) > sizeof(ElementB)),
using ElementOperand = typename platform::conditional<(sizeof_bits<ElementA>::value > sizeof_bits<ElementB>::value),
ElementA, ElementB>::type;

// Operand datatypes in the internal MMA instruction - use the wider of the two data types
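Aside (not part of the diff): a brief illustration of why the line above switches from sizeof to sizeof_bits once sub-byte types are involved. int4b_t occupies a full byte of storage, so a byte-level comparison cannot distinguish it from int8_t; sizeof_bits compares the logical widths.

#include "cutlass/numeric_types.h"

static_assert(sizeof(cutlass::int4b_t) == sizeof(int8_t), "both occupy one byte of storage");
static_assert(cutlass::sizeof_bits<cutlass::int4b_t>::value == 4, "but int4b_t is logically 4 bits wide");
static_assert(cutlass::sizeof_bits<int8_t>::value == 8, "while int8_t is 8 bits wide");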
@@ -294,6 +294,75 @@
Policy, PartitionsK, AccumulatorsInRowMajor>;
};


/////////////////////////////////////////////////////////////////////////////////////////////////

/// Partial Specialization - inputs are mixed types - uses wider datatype internally.
/// (e.g. S32 <= S4 x S8 + S32, S32 <= S8 x S4 + S32)
template <
/// Shape of one matrix product operation (concept: GemmShape)
typename WarpShape_,
/// Element type of A matrix
typename ElementA,
/// Layout of A matrix (concept: MatrixLayout)
typename LayoutA,
/// Element type of B matrix
typename ElementB,
/// Layout of B matrix (concept: MatrixLayout)
typename LayoutB,
/// Element type of C matrix
typename ElementC,
/// Layout of C matrix (concept: MatrixLayout)
typename LayoutC,
/// Number of partitions along K dimension
int PartitionsK,
/// Store the accumulators in row major or column major. Row major is used
/// when output layout is interleaved.
bool AccumulatorsInRowMajor>
struct DefaultMmaTensorOp<
WarpShape_,
GemmShape<16, 8, 32>, // InstructionShape
ElementA, // Element type of A matrix in Global Memory
LayoutA, // Layout of A matrix in Global Memory
ElementB, // Element type of B matrix in Global Memory
LayoutB, // Layout of B matrix in Global Memory
ElementC, // Element type of C matrix in Global Memory
LayoutC, // Layout of C matrix in Global Memory
arch::OpMultiplyAddMixedInputUpcast, // Tag to indicate mixed-input datatype, where narrower datatype is upcasted to wider datatype
PartitionsK, AccumulatorsInRowMajor> {


// Check if the ElementA and ElementB are of different data types
static_assert(!platform::is_same<ElementA, ElementB>::value,
"DefaultMmaTensorOp with arch::OpMultiplyAddMixedInputUpcast ElementA and ElementB cannot be of the same data type");

// Data type used for internal computation - use the wider of the two data types for mma.sync operands
using ElementOperand = typename platform::conditional<(sizeof_bits<ElementA>::value > sizeof_bits<ElementB>::value),
ElementA, ElementB>::type;

// Operand datatypes in the internal MMA instruction - use the wider of the two data types
using MmaElementA = ElementOperand;
using MmaElementB = ElementOperand;
using MmaElementC = ElementC;

// Policy wrapping the underlying 16x8x32 saturating S8 mma.sync instruction
using Policy = cutlass::gemm::warp::MmaTensorOpPolicy<
cutlass::arch::Mma<
GemmShape<16, 8, 32>,
32,
MmaElementA, cutlass::layout::RowMajor,
MmaElementB, cutlass::layout::ColumnMajor,
MmaElementC, cutlass::layout::RowMajor,
arch::OpMultiplyAddSaturate
>,
cutlass::MatrixShape<1, 1> >;

// Define the warp-level tensor op
using Type = cutlass::gemm::warp::MmaMixedInputTensorOp<
WarpShape_, ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
Policy, PartitionsK, AccumulatorsInRowMajor>;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace warp
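Aside (not part of the diff): for both of the new orderings (S4 x S8 and S8 x S4), ElementOperand resolves to int8_t, so the Policy above wraps the standard SM80 16x8x32 saturating S8 tensor core instruction. A sketch of the resolved instruction, assuming the usual int32_t accumulator:

#include "cutlass/arch/mma_sm80.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/layout/matrix.h"

// Both S32 <= S4 x S8 + S32 and S32 <= S8 x S4 + S32 lower to the same warp-level
// instruction after the narrower operand is upcast to int8_t in registers.
using UnderlyingInstruction = cutlass::arch::Mma<
    cutlass::gemm::GemmShape<16, 8, 32>, 32,      // instruction shape, 32 threads per warp
    int8_t,  cutlass::layout::RowMajor,           // A operand as seen by mma.sync
    int8_t,  cutlass::layout::ColumnMajor,        // B operand as seen by mma.sync
    int32_t, cutlass::layout::RowMajor,           // accumulator
    cutlass::arch::OpMultiplyAddSaturate>;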
14 changes: 10 additions & 4 deletions include/cutlass/gemm/warp/mma_mixed_input_tensor_op.h
@@ -104,6 +104,7 @@ struct FragmentShuffler {
////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for `mma.sync` on 16b (F16/BF16) and `ldmatrix` on 8b (S8/U8)
/// or for `mma.sync` on 8b (S8/U8) and `ldmatrix` on 4b (S4/U4)
/// for operand A multiplicand going through upcasting.
template <
/// Element type for the operand in registers for the mma.sync
@@ -122,8 +123,10 @@ struct FragmentShuffler <ElementMma_, ElementLoad_,
NumElementsInWarpFragment,
NumElementsInMmaFragment,
Operand::kA,
typename platform::enable_if<(sizeof_bits<ElementMma_>::value == 16) &&
(sizeof_bits<ElementLoad_>::value == 8)>::type> {
typename platform::enable_if<((sizeof_bits<ElementMma_>::value == 16) &&
(sizeof_bits<ElementLoad_>::value == 8)) ||
((sizeof_bits<ElementMma_>::value == 8) &&
(sizeof_bits<ElementLoad_>::value == 4))>::type> {
public:
using ElementMma = ElementMma_;
using ElementLoad = ElementLoad_;
@@ -187,6 +190,7 @@ struct FragmentShuffler <ElementMma_, ElementLoad_,
////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for `mma.sync` on 16b (F16/BF16) and `ldmatrix` on 8b (S8/U8)
/// or for `mma.sync` on 8b (S8/U8) and `ldmatrix` on 4b (S4/U4)
/// for operand B multiplicand going through upcasting.
template <
/// Element type for the operand in registers for the mma.sync
@@ -205,8 +209,10 @@ struct FragmentShuffler <ElementMma_, ElementLoad_,
NumElementsInWarpFragment,
NumElementsInMmaFragment,
Operand::kB,
typename platform::enable_if<(sizeof_bits<ElementMma_>::value == 16) &&
(sizeof_bits<ElementLoad_>::value == 8)>::type> {
typename platform::enable_if<((sizeof_bits<ElementMma_>::value == 16) &&
(sizeof_bits<ElementLoad_>::value == 8)) ||
((sizeof_bits<ElementMma_>::value == 8) &&
(sizeof_bits<ElementLoad_>::value == 4))>::type> {
public:
using ElementMma = ElementMma_;
using ElementLoad = ElementLoad_;
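Aside (not part of the diff): the widened enable_if reads as a two-case admission test on the (mma width, load width) pair. A small standalone mirror of that condition; the kNeedsFragmentShuffle name is hypothetical and only illustrates which combinations now select this shuffling specialization.

#include "cutlass/numeric_types.h"

template <typename ElementMma, typename ElementLoad>
constexpr bool kNeedsFragmentShuffle =
    ((cutlass::sizeof_bits<ElementMma>::value == 16) && (cutlass::sizeof_bits<ElementLoad>::value == 8)) ||
    ((cutlass::sizeof_bits<ElementMma>::value == 8)  && (cutlass::sizeof_bits<ElementLoad>::value == 4));

static_assert(kNeedsFragmentShuffle<cutlass::half_t, int8_t>, "F16 mma fed by S8 loads");
static_assert(kNeedsFragmentShuffle<int8_t, cutlass::int4b_t>, "S8 mma fed by S4 loads, added by this PR");
static_assert(!kNeedsFragmentShuffle<cutlass::half_t, cutlass::int4b_t>, "16b mma with 4b loads is not covered here");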
80 changes: 80 additions & 0 deletions include/cutlass/numeric_conversion.h
@@ -2771,6 +2771,86 @@ struct NumericArrayConverter<uint4b_t, int, N, Round> {
}
};

/// Partial specialization for Array<int8_t, 8> <= Array<int4b_t, 8>
template <
FloatRoundStyle Round
>
struct NumericArrayConverter<int8_t, int4b_t, 8, Round> {

using result_type = Array<int8_t, 8>;
using source_type = Array<int4b_t, 8>;
static FloatRoundStyle const round_style = Round;

CUTLASS_HOST_DEVICE
static result_type convert(source_type const & source) {

unsigned const& storage = reinterpret_cast<unsigned const &>(source);
unsigned out[2];

asm volatile(
"{ .reg .u32 tmp0, tmp1, tmp2;"
"shl.b32 tmp0, %2, 4;"
"and.b32 tmp0, tmp0, 0xf0f0f0f0;"
"prmt.b32 tmp1, tmp0, tmp0, 0xba98;"
"and.b32 tmp1, tmp1, 0xf0f0f0f0;"
"shr.u32 tmp0, tmp0, 4;"
"or.b32 tmp2, tmp0, tmp1;"
"and.b32 tmp0, %2, 0xf0f0f0f0;"
"prmt.b32 tmp1, tmp0, tmp0, 0xba98;"
"and.b32 tmp1, tmp1, 0xf0f0f0f0;"
"shr.u32 tmp0, tmp0, 4;"
"or.b32 tmp0, tmp0, tmp1;"
"prmt.b32 %0, tmp2, tmp0, 0x5140;"
"prmt.b32 %1, tmp2, tmp0, 0x7362;"
"}"
: "=r"(out[0]), "=r"(out[1])
: "r"(storage));

return reinterpret_cast<result_type const &>(out);
}

CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) const {
return convert(s);
}
};
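Aside (not part of the diff): a minimal device-side usage sketch for the converter above; the function name and context are illustrative only.

#include "cutlass/array.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/numeric_types.h"

// Widen eight packed int4 values (one 32-bit register) to eight int8 values
// (two 32-bit registers); element order and sign are preserved.
__device__ void upcast_s4_to_s8(cutlass::Array<cutlass::int4b_t, 8> const &src,
                                cutlass::Array<int8_t, 8> &dst) {
  cutlass::NumericArrayConverter<int8_t, cutlass::int4b_t, 8> converter;
  dst = converter(src);
}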

/// Partial specialization for Array<int8_t> <= Array<int4b_t>
template <
int N,
FloatRoundStyle Round
>
struct NumericArrayConverter<int8_t, int4b_t, N, Round> {
static_assert(!(N % 8), "N must be multiple of 8.");

using result_type = Array<int8_t, N>;
using source_type = Array<int4b_t, N>;
static FloatRoundStyle const round_style = Round;

CUTLASS_HOST_DEVICE
static result_type convert(source_type const & source) {

NumericArrayConverter<int8_t, int4b_t, 8, Round> convert_vector_;

result_type result;

Array<int8_t, 8> *result_ptr = reinterpret_cast<Array<int8_t, 8> *>(&result);
Array<int4b_t, 8> const *source_ptr = reinterpret_cast<Array<int4b_t, 8> const *>(&source);

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N / 8; ++i) {
result_ptr[i] = convert_vector_(source_ptr[i]);
}

return result;
}

CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) const {
return convert(s);
}
};

#endif // Conditional guards to enable partial specialization for packed integers

namespace detail {