Improve mixed dtype GEMM #1972

Merged · 2 commits · Dec 6, 2024
425 changes: 226 additions & 199 deletions examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_bf16_gemm.cu

Large diffs are not rendered by default.

314 changes: 66 additions & 248 deletions examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu

Large diffs are not rendered by default.

224 changes: 16 additions & 208 deletions examples/55_hopper_mixed_dtype_gemm/55_hopper_mixed_dtype_gemm.cu
@@ -102,19 +102,12 @@
#include "cutlass/util/reference/device/tensor_compare.h"

#include "helper.h"
#include "unfused_weight_dequantize.hpp"
#include "mixed_dtype_utils.hpp"

using namespace cute;

#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)

-// This is just an example, so we use a regular enum so we can compare directly to the command-line int.
-enum GemmMode {
-ConvertOnly,
-ScaleOnly,
-ScaleWithZeroPoint
-};
-
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations
/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -157,7 +150,7 @@ using ArchTag = cutlass::arch::Sm90; // T
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
using TileShape = Shape<_128,_128,cute::Int<TileShapeK>>; // Threadblock-level tile size
using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster
-using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput; // Kernel to launch based on the default setting in the Collective Builder
+using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperative; // Kernel to launch based on the default setting in the Collective Builder
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;

@@ -284,178 +277,14 @@ cutlass::DeviceAllocation<typename GemmScaleWithZeroPoint::EpilogueOutputOp::Ele
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////

-// Command line options parsing
-struct Options {
-
-bool help = false;
-
-float alpha = 1.0f;
-float beta = 0.0f;
-int iterations = 10;
-int mode = 2;
-int m = 5120, n = 4096, k = 4096;
-int g = 128;
-int l = 1;
-
-// Parses the command line
-void parse(int argc, char const **args) {
-cutlass::CommandLine cmd(argc, args);
-
-if (cmd.check_cmd_line_flag("help")) {
-help = true;
-return;
-}
-
-cmd.get_cmd_line_argument("m", m);
-cmd.get_cmd_line_argument("n", n);
-cmd.get_cmd_line_argument("k", k);
-cmd.get_cmd_line_argument("l", l);
-cmd.get_cmd_line_argument("g", g);
-cmd.get_cmd_line_argument("mode", mode);
-cmd.get_cmd_line_argument("alpha", alpha, 1.f);
-cmd.get_cmd_line_argument("beta", beta, 0.f);
-cmd.get_cmd_line_argument("iterations", iterations);
-}
-
-/// Prints the usage statement.
-std::ostream & print_usage(std::ostream &out) const {
-
-out << "55_hopper_warp_specialized_gemm\n\n"
-<< " Hopper FP32 GEMM using a Warp Specialized kernel.\n\n"
-<< "Options:\n\n"
-<< " --help If specified, displays this usage statement\n\n"
-<< " --m=<int> Sets the M extent of the GEMM\n"
-<< " --n=<int> Sets the N extent of the GEMM\n"
-<< " --k=<int> Sets the K extent of the GEMM\n"
-<< " --l=<int> The number of independent gemm problems with mnk shape\n"
-<< " --g=<int> The size of each group for the scales and zeros. To broadcast a vector of scales or zeros, set the group size to K.\n"
-<< " --mode=<int> The mode to run the gemm. 0 does (A @ B), 1 means A @ (scale * B), 2 means A @ (scale * B + zero-point).\n"
-<< " --alpha=<f32> Epilogue scalar alpha\n"
-<< " --beta=<f32> Epilogue scalar beta\n\n"
-<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
-
-out
-<< "\n\nExamples:\n\n"
-<< "$ " << "55_hopper_warp_specialized_gemm" << " --m=1024 --n=512 --k=1024 -g 0 --l=10 --alpha=2 --mode=2 --beta=0.707 \n\n";
-
-return out;
-}
-
-/// Compute performance in GFLOP/s
-double gflops(double runtime_s) const
-{
-// Two flops per multiply-add
-uint64_t flop = uint64_t(2) * m * n * k * l;
-double gflop = double(flop) / double(1.0e9);
-return gflop / runtime_s;
-}
-};
-
-/// Result structure
-struct Result
-{
-double avg_runtime_ms = 0.0;
-double gflops = 0.0;
-cutlass::Status status = cutlass::Status::kSuccess;
-cudaError_t error = cudaSuccess;
-bool passed = false;
-
-};
-
-#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
-
-/////////////////////////////////////////////////////////////////////////////////////////////////
-/// GEMM setup and evaluation
-/////////////////////////////////////////////////////////////////////////////////////////////////
-
-/// Helper to initialize a block of device data
-template <class Element>
-bool initialize_tensor(
-cutlass::DeviceAllocation<Element>& block,
-uint64_t seed=2023) {
-
-double scope_max, scope_min;
-int bits_input = cutlass::sizeof_bits<Element>::value;
-int bits_output = cutlass::sizeof_bits<Element>::value;
-
-if (bits_input == 1) {
-scope_max = 2;
-scope_min = 0;
-}
-else if (bits_input <= 8) {
-scope_max = 2;
-scope_min = -2;
-}
-else if (bits_output == 16) {
-scope_max = 5;
-scope_min = -5;
-}
-else {
-scope_max = 8;
-scope_min = -8;
-}
-cutlass::reference::device::BlockFillRandomUniform(
-block.get(), block.size(), seed, Element(scope_max), Element(scope_min));
-
-return true;
-}
-
-template <typename Element>
-bool initialize_quant_tensor(
-cutlass::DeviceAllocation<Element>& block,
-uint64_t seed=2023) {
-
-float scope_min = float(cutlass::platform::numeric_limits<Element>::lowest());
-float scope_max = float(cutlass::platform::numeric_limits<Element>::max());
-
-cutlass::reference::device::BlockFillRandomUniform(
-block.get(), block.size(), seed, Element(scope_max), Element(scope_min));
-
-return true;
-}
-
-template <class Element>
-bool initialize_scale(
-cutlass::DeviceAllocation<Element>& block,
-Options const& options) {
-
-if (options.mode == GemmMode::ConvertOnly) {
-// No scales, so just initialize with 1 so we can use the same kernel to dequantize the data.
-std::vector<Element> stage(block.size(), Element(1.0f));
-block.copy_from_host(stage.data());
-}
-else {
-float elt_max_f = float(cutlass::platform::numeric_limits<QuantType>::max());
-const float max_dequant_val = 4.f;
-const float min_dequant_val = 0.5f;
-
-float scope_max(max_dequant_val / elt_max_f);
-float scope_min(min_dequant_val / elt_max_f);
-
-cutlass::reference::device::BlockFillRandomUniform(
-block.get(), block.size(), seed, Element(scope_max), Element(scope_min));
-}
-return true;
-}
-
-template <class Element>
-bool initialize_zero(
-cutlass::DeviceAllocation<Element>& block,
-Options const& options) {
-
-if (options.mode == GemmMode::ScaleWithZeroPoint) {
-cutlass::reference::device::BlockFillRandomUniform(
-block.get(), block.size(), seed, Element(2.0f), Element(-2.0f));
-} else {
-// No bias, so just initialize with 1 so we can use the same kernel to dequantize the data.
-std::vector<Element> stage(block.size(), Element(0.0f));
-block.copy_from_host(stage.data());
-}
-return true;
-}
-
/// Initialize operands to be used in the GEMM and reference GEMM
-void initialize(Options const& options) {
+void initialize(MixedDtypeOptions const& options) {

auto shape_b = cute::make_shape(options.n, options.k, options.l);
int const scale_k = (options.k + options.g - 1) / options.g;
@@ -500,26 +329,26 @@ void initialize(Options const& options) {

/// Populates a Gemm::Arguments structure from the given commandline options
template <typename Args>
-Args args_from_options(Options const& options)
+Args args_from_options(MixedDtypeOptions const& options)
{
// Swap the A and B tensors, as well as problem shapes here.
-if (options.mode == GemmMode::ConvertOnly) {
+if (options.mode == MixedDtypeGemmMode::ConvertOnly) {
return Args {
cutlass::gemm::GemmUniversalMode::kGemm,
{options.n, options.m, options.k, options.l},
{block_B.get(), stride_B, block_A.get(), stride_A},
{{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}
};
}
-else if (options.mode == GemmMode::ScaleOnly) {
+else if (options.mode == MixedDtypeGemmMode::ScaleOnly) {
return Args {
cutlass::gemm::GemmUniversalMode::kGemm,
{options.n, options.m, options.k, options.l},
{block_B.get(), stride_B, block_A.get(), stride_A, block_scale.get(), stride_S, options.g},
{{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}
};
}
-else if (options.mode == GemmMode::ScaleWithZeroPoint) {
+else if (options.mode == MixedDtypeGemmMode::ScaleWithZeroPoint) {
return Args {
cutlass::gemm::GemmUniversalMode::kGemm,
{options.n, options.m, options.k, options.l},
@@ -532,7 +361,7 @@ Args args_from_options(Options const& options)
}
}

-bool verify(const Options &options) {
+bool verify(MixedDtypeOptions const& options) {
//
// Compute reference output
//
@@ -598,7 +427,7 @@ bool verify(const Options &options) {

/// Execute a given example GEMM computation
template <typename Gemm>
-int run(Options &options)
+int run(MixedDtypeOptions &options)
{
initialize(options);

@@ -624,35 +453,14 @@ int run(Options &options)
CUTLASS_CHECK(gemm.run());

// Check if output from CUTLASS kernel and reference kernel are equal or not
-Result result;
+MixedDtypeResult result;
result.passed = verify(options);

+mixed_dtype_profiling(gemm, options, result);
std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;

if (!result.passed) {
exit(-1);
}

-// Run profiling loop
-if (options.iterations > 0)
-{
-GpuTimer timer;
-timer.start();
-for (int iter = 0; iter < options.iterations; ++iter) {
-CUTLASS_CHECK(gemm.run());
-}
-timer.stop();
-
-// Compute average runtime and GFLOPs.
-float elapsed_ms = timer.elapsed_millis();
-result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
-result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
-
-std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl;
-std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl;
-std::cout << " GFLOPS: " << result.gflops << std::endl;
-}

return 0;
}

@@ -685,7 +493,7 @@ int main(int argc, char const **args) {
// Parse options
//

-Options options;
+MixedDtypeOptions options;

options.parse(argc, args);

@@ -699,19 +507,19 @@
//

#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
-if (options.mode == GemmMode::ConvertOnly) {
+if (options.mode == MixedDtypeGemmMode::ConvertOnly) {
std::cout << "Running in no scale mode." << std::endl;
run<GemmConvertOnly>(options);
}
-else if (options.mode == GemmMode::ScaleOnly) {
+else if (options.mode == MixedDtypeGemmMode::ScaleOnly) {
if (options.g == options.k) {
std::cout << "Running in per-column scale mode." << std::endl;
} else {
std::cout << "Running in group scale mode." << std::endl;
}
run<GemmScaleOnly>(options);
}
-else if (options.mode == GemmMode::ScaleWithZeroPoint) {
+else if (options.mode == MixedDtypeGemmMode::ScaleWithZeroPoint) {
if (options.g == options.k) {
std::cout << "Running in per-column scale and zero mode." << std::endl;
} else {
@@ -724,4 +532,4 @@
return 0;
}

-/////////////////////////////////////////////////////////////////////////////////////////////////
\ No newline at end of file
+/////////////////////////////////////////////////////////////////////////////////////////////////
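Note on the refactor above: the deleted Options, GemmMode, Result, and profiling code now come from the shared mixed_dtype_utils.hpp header included at the top of this file. That header is not shown in this diff, so what follows is only a minimal sketch of the interface the remaining code appears to depend on, reconstructed from the call sites above; the actual declarations in the PR may differ in detail.

// Sketch (assumption): approximate shape of mixed_dtype_utils.hpp, inferred
// from how this example uses it. Not the PR's actual header.
#pragma once

#include <iostream>
#include <cuda_runtime_api.h>   // cudaError_t
#include "cutlass/cutlass.h"    // cutlass::Status

// Unscoped enum so it compares directly against the command-line --mode int,
// exactly as the old per-example GemmMode did.
enum MixedDtypeGemmMode {
  ConvertOnly,        // mode 0: A @ B (upcast only)
  ScaleOnly,          // mode 1: A @ (scale * B)
  ScaleWithZeroPoint  // mode 2: A @ (scale * B + zero-point)
};

// Command-line options shared by the 55_hopper_* mixed dtype examples.
struct MixedDtypeOptions {
  bool help = false;
  float alpha = 1.0f;
  float beta = 0.0f;
  int iterations = 10;
  int mode = 2;                     // defaults to ScaleWithZeroPoint
  int m = 5120, n = 4096, k = 4096;
  int g = 128;                      // group size for scales/zero-points
  int l = 1;                        // batch count

  void parse(int argc, char const **args);
  std::ostream & print_usage(std::ostream &out) const;
  double gflops(double runtime_s) const;   // 2*m*n*k*l flops over runtime
};

// Result bookkeeping, as used by run() above.
struct MixedDtypeResult {
  double avg_runtime_ms = 0.0;
  double gflops = 0.0;
  cutlass::Status status = cutlass::Status::kSuccess;
  cudaError_t error = cudaSuccess;
  bool passed = false;
};

// Timing loop replacing the per-example profiling code deleted above;
// fills result.avg_runtime_ms and result.gflops and prints them.
template <class Gemm>
void mixed_dtype_profiling(Gemm& gemm, MixedDtypeOptions const& options, MixedDtypeResult& result);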
4 changes: 2 additions & 2 deletions examples/55_hopper_mixed_dtype_gemm/README.md
@@ -3,7 +3,7 @@ This example shows how to do mixed types GEMMs in CUTLASS.
## High level overview
This example shows how to perform GEMMs on Hopper when A and B have different types. This implementation always passes the type with fewer bits through the register file and upcasts to the type with the higher bit count.

-When relying on `KernelScheduleAuto`, the main loop supporting different A and B types will be selected whenever the bit count of A is not equal to the bit count of B. Users can manually select the mixed type main loop and explicitly choose the scheduling policy by specifying one of the following schedules to the `CollectiveBuilder`: `KernelTmaWarpSpecializedMixedInput`, `KernelTmaWarpSpecializedPingpongMixedInput` or `KernelTmaWarpSpecializedCooperativeMixedInput`.
+When relying on `KernelScheduleAuto`, the main loop supporting different A and B types will be selected whenever the bit count of A is not equal to the bit count of B. Users can manually select the mixed type main loop and explicitly choose the scheduling policy by specifying one of the following schedules to the `CollectiveBuilder`: `KernelTmaWarpSpecialized`, `KernelTmaWarpSpecializedPingpong` or `KernelTmaWarpSpecializedCooperative`.

This first version only supports mixed type GEMMs using TMA.
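To make the schedule selection above concrete, here is a minimal sketch of a CollectiveBuilder instantiation that requests the cooperative schedule for a mixed-input mainloop. It mirrors the kind of type aliases used in 55_hopper_mixed_dtype_gemm.cu earlier in this PR; the element types, layouts, and alignments below are illustrative assumptions, not the example's exact configuration.

// Sketch only: explicit schedule choice for a mixed dtype mainloop.
// Assumes `using namespace cute;`, as in the example sources.
using MmaType   = cutlass::half_t;    // wider compute type
using QuantType = cutlass::int4b_t;   // narrow quantized type

using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
    cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
    // The narrow operand goes in the A slot (the example swaps A and B) so
    // the low-bit data flows through the register file and is upcast.
    QuantType, cutlass::layout::ColumnMajor, 32,   // 32 int4 = 16B alignment
    MmaType,   cutlass::layout::RowMajor,    8,    // 8 halfs = 16B alignment
    float,                                         // accumulator
    Shape<_128,_128,_64>,                          // CTA tile
    Shape<_1,_1,_1>,                               // cluster shape
    cutlass::gemm::collective::StageCountAuto,
    // Any of the three schedules named above works here; with
    // KernelScheduleAuto the mixed-input mainloop would still be selected
    // because the operand bit counts differ.
    cutlass::gemm::KernelTmaWarpSpecializedCooperative
  >::CollectiveOp;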

Expand Down Expand Up @@ -36,4 +36,4 @@ We are currently optimizing the following cases:

* Optimizations for memory bound cases.

-* Optimizations for scale and zero-point loading when the group size is not equal to the threadblock-k size.
\ No newline at end of file
+* Optimizations for scale and zero-point loading when the group size is not equal to the threadblock-k size.
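As a point of reference for the group-size cases above: the number of scale/zero-point groups along K is the ceiling division that the example's initialize() routine already uses (see the diff above), so setting the group size g equal to K collapses the scales to one broadcast value per column.

// From initialize() in this PR: one scale and zero-point entry per
// g-sized group along K; g == k gives scale_k == 1 (per-column scaling).
int const scale_k = (options.k + options.g - 1) / options.g;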