Skip to content

Commit

Permalink
[PTX] ldmatrix builtin to accelerate copying data from shared memory to warp memory (apache#10855)
Browse files Browse the repository at this point in the history

We already have PTX mma and mma.sp builtin support in apache#9909 and apache#10339. However, we have not yet supported the corresponding data-movement builtins for these mma instructions, so the data movement would not be as fast as wmma.

This PR brings the `ldmatrix` builtin, which is a native PTX warp-level instruction (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-ldmatrix), and we can use it to load several (1/2/4) 8x8 matrices from shared memory to warp memory.
  • Loading branch information
yzh119 authored and pfk-beta committed Apr 11, 2022
1 parent cd0113b commit c2b2333
Show file tree
Hide file tree
Showing 6 changed files with 263 additions and 40 deletions.
9 changes: 9 additions & 0 deletions include/tvm/tir/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -623,6 +623,15 @@ TVM_DLL const Op& ptx_mma();
*/
TVM_DLL const Op& ptx_mma_sp();

/*!
* \brief tvm intrinsic for ptx load matrix from shared memory.
*
* void ptx_ldmatrix(Bool trans, IntImm num, StringImm type,
* Var local_ptr, Expr local_offset,
* Var smem_ptr, Expr smem_offset);
*/
TVM_DLL const Op& ptx_ldmatrix();

// TODO(tvm-team) replace the usage of the vector operations by Shuffle.
/*!
* \brief Get the high level half of the vector
Expand Down
26 changes: 22 additions & 4 deletions src/target/source/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
#include <vector>

#include "literal/cuda_half_t.h"
#include "ptx_mma.h"
#include "ptx.h"

namespace tvm {
namespace codegen {
Expand Down Expand Up @@ -772,11 +772,11 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
// arg 3: A precision: fp16, fp32, ...
// arg 4: B precision: fp16, fp32, ...
// arg 5: C precision: fp16, fp32, ...
// arg 6: A multiplicand
// arg 6: A multiplicand pointer
// arg 7: A multiplicand index
// arg 8: B multiplicand
// arg 8: B multiplicand pointer
// arg 9: B multiplicand index
// arg 10: C accumulator
// arg 10: C accumulator pointer
// arg 11: C accumulator index
// arg 12: metadata
// arg 13: metadata index
Expand All @@ -803,6 +803,24 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_offset, b_ref, b_offset,
c_ref, c_offset, metadata, metadata_offset, sparse_selector, "", true, saturate);
this->stream << asm_code;
} else if (op->op.same_as(builtin::ptx_ldmatrix())) {
// arg 0: whether the matrix is loaded in column major format or not.
// arg 1: number of matrices to load.
// arg 2: The data type in the matrix, .b16 is the only accepted data type.
// arg 3: pointer to local buffer.
// arg 4: The offset of the element to store in the local buffer.
// arg 5: pointer to the shared memory buffer to load.
// arg 6: The offset of the start element of the row to load in shared memory.
ICHECK_EQ(op->args.size(), 7U);
bool trans = Downcast<Bool>(op->args[0])->value;
int num = Downcast<Integer>(op->args[1])->value;
std::string type = Downcast<StringImm>(op->args[2])->value;
std::string local_ptr = this->PrintExpr(op->args[3]);
std::string local_elem_offset = this->PrintExpr(op->args[4]);
std::string smem_ptr = this->PrintExpr(op->args[5]);
std::string smem_elem_offset = this->PrintExpr(op->args[6]);
this->stream << PrintLoadMatrixAssembly(trans, num, type, local_ptr, local_elem_offset,
smem_ptr, smem_elem_offset);
} else {
CodeGenC::VisitExpr_(op, os);
}
Expand Down
126 changes: 101 additions & 25 deletions src/target/source/ptx_mma.cc → src/target/source/ptx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
*/

/*!
* \file ptx_mma.cc
* \file ptx.cc
*/

#include "ptx_mma.h"
#include "ptx.h"

#include <algorithm>
#include <string>
Expand Down Expand Up @@ -60,13 +60,18 @@ enum class DataType : int {
kFloat32 = 13,
kTensorFloat32 = 14,
kFloat64 = 15,
kBit1 = 16
kBit1 = 16,
kBit8 = 17,
kBit16 = 18,
kBit32 = 19,
kBit64 = 20,
};

static const char* dtype_str[] = {".s4", ".u4", ".s8", ".u8", ".s16", ".u16",
".s32", ".u32", ".s64", ".u64", ".f16", ".bf16",
".f16x2", ".f32", ".tf32", ".f64", ".b1"};
static const uint32_t num_bits[] = {4, 4, 8, 8, 16, 16, 32, 32, 64, 64, 16, 16, 32, 32, 32, 64, 1};
static const char* dtype_str[] = {".s4", ".u4", ".s8", ".u8", ".s16", ".u16", ".s32",
".u32", ".s64", ".u64", ".f16", ".bf16", ".f16x2", ".f32",
".tf32", ".f64", ".b1", ".b8", ".b16", ".b32", ".b64"};
static const uint32_t num_bits[] = {4, 4, 8, 8, 16, 16, 32, 32, 64, 64, 16,
16, 32, 32, 32, 64, 1, 8, 16, 32, 64};

/*!
* \brief Create PTX data type from string.
Expand Down Expand Up @@ -106,6 +111,14 @@ inline DataType DTypeFromString(const std::string str) {
return DataType::kFloat64;
} else if (str == "int1" || str == ".b1") {
return DataType::kBit1;
} else if (str == ".b8") {
return DataType::kBit8;
} else if (str == ".b16") {
return DataType::kBit16;
} else if (str == ".b32") {
return DataType::kBit32;
} else if (str == ".b64") {
return DataType::kBit64;
} else {
LOG(FATAL) << "Unrecognized PTX data type " << str;
return DataType(0);
Expand Down Expand Up @@ -360,6 +373,7 @@ inline FragAttrs GetFragAttrs(DataType dtype) {
case DataType::kUInt4:
case DataType::kInt8:
case DataType::kUInt8:
case DataType::kBit16:
case DataType::kFloat16: // .f16x2 register
case DataType::kBFloat16:
case DataType::kTensorFloat32:
Expand Down Expand Up @@ -508,9 +522,9 @@ inline std::tuple<std::string, std::string, std::string> GetMMAOperands(int m, i
std::string PrintMMAAssembly(const std::string& shape, const std::string& A_layout,
const std::string& B_layout, const std::string& A_dtype,
const std::string& B_dtype, const std::string& C_dtype,
const std::string& a_ref, const std::string& a_offset,
const std::string& b_ref, const std::string& b_offset,
const std::string& c_ref, const std::string& c_offset,
const std::string& a_ptr, const std::string& a_elem_offset,
const std::string& b_ptr, const std::string& b_elem_offset,
const std::string& c_ptr, const std::string& c_elem_offset,
const std::string& metadata, const std::string& metadata_offset,
const std::string& sparsity_selector, const std::string& bit_op,
bool sparse, bool saturate) {
Expand All @@ -525,7 +539,7 @@ std::string PrintMMAAssembly(const std::string& shape, const std::string& A_layo
std::string asm_code = R"(
{
__asm__ __volatile__(
"mma{sparse}.sync.aligned.{shape}.{alayout}.{blayout}{saturate}{dtype}{atype}{btype}{ctype}{bitop}"
"mma{.sparse}.sync.aligned{.shape}{.alayout}{.blayout}{.saturate}{.dtype}{.atype}{.btype}{.ctype}{.bitop}"
"{templates};\n"
: {outputs}
: {inputs});
Expand All @@ -537,30 +551,92 @@ std::string PrintMMAAssembly(const std::string& shape, const std::string& A_layo

// replace patterns
Replacer replacer;
replacer.register_rule("{sparse}", sparse ? ".sp" : "");
replacer.register_rule("{shape}", shape);
replacer.register_rule("{saturate}", saturate ? ".satfinite" : "");
replacer.register_rule("{alayout}", A_layout);
replacer.register_rule("{blayout}", B_layout);
replacer.register_rule("{atype}", ptx::DTypeToString(dtype_a));
replacer.register_rule("{btype}", ptx::DTypeToString(dtype_b));
replacer.register_rule("{ctype}", ptx::DTypeToString(dtype_c));
replacer.register_rule("{dtype}", ptx::DTypeToString(dtype_c));
replacer.register_rule("{bitop}", bit_op.empty() ? "" : "." + bit_op + ".popc");
replacer.register_rule("{.sparse}", sparse ? ".sp" : "");
replacer.register_rule("{.shape}", "." + shape);
replacer.register_rule("{.saturate}", saturate ? ".satfinite" : "");
replacer.register_rule("{.alayout}", "." + A_layout);
replacer.register_rule("{.blayout}", "." + B_layout);
replacer.register_rule("{.atype}", ptx::DTypeToString(dtype_a));
replacer.register_rule("{.btype}", ptx::DTypeToString(dtype_b));
replacer.register_rule("{.ctype}", ptx::DTypeToString(dtype_c));
replacer.register_rule("{.dtype}", ptx::DTypeToString(dtype_c));
replacer.register_rule("{.bitop}", bit_op.empty() ? "" : "." + bit_op + ".popc");
replacer.register_rule("{templates}", templates_str);
replacer.register_rule("{outputs}", outputs_str);
replacer.register_rule("{inputs}", inputs_str);
asm_code = replacer.rewrite(asm_code);
replacer.empty_rules();
replacer.register_rule("A", a_ref + " + " + a_offset);
replacer.register_rule("B", b_ref + " + " + b_offset);
replacer.register_rule("C", c_ref + " + " + c_offset);
replacer.register_rule("D", c_ref + " + " + c_offset);
replacer.register_rule("A", a_ptr + " + " + a_elem_offset);
replacer.register_rule("B", b_ptr + " + " + b_elem_offset);
replacer.register_rule("C", c_ptr + " + " + c_elem_offset);
replacer.register_rule("D", c_ptr + " + " + c_elem_offset);
replacer.register_rule("E", metadata + " + " + metadata_offset);
replacer.register_rule("F", sparsity_selector);
asm_code = replacer.rewrite(asm_code);
return asm_code;
}

inline std::tuple<std::string, std::string> GetLoadMatrixOperands(
    int num, const std::string& local_ptr, const std::string& local_elem_offset) {
  // Build the inline-asm operand template "{%0, %1, ...}, [%k]" and the
  // matching "=r" output-constraint list for an ldmatrix instruction that
  // loads `num` fragments into the local buffer.
  std::stringstream tmpl_ss, out_ss;
  // Placeholders %0..%(num-1) are the destination fragment registers; the
  // trailing placeholder %num is the shared-memory address operand.
  for (int reg = 0; reg < num; ++reg) {
    tmpl_ss << (reg == 0 ? "{%" : ", %") << reg;
  }
  tmpl_ss << "}, [%" << num << "]";
  // Each fragment is stored through an `unsigned*` view of the local buffer,
  // offset by the requested element offset.
  const std::string ptr_type = "(unsigned *)";
  for (int reg = 0; reg < num; ++reg) {
    if (reg != 0) {
      out_ss << ", ";
    }
    out_ss << "\"=r\"((" << ptr_type << "(" << local_ptr << " + " << local_elem_offset << "))["
           << reg << "])";
  }
  return std::make_tuple(tmpl_ss.str(), out_ss.str());
}

// Print a CUDA inline-assembly snippet that issues a PTX `ldmatrix` warp-level
// load of `num` 8x8 .b16 matrices from shared memory (`smem_ptr` +
// `smem_elem_offset`) into per-thread registers backing the local buffer
// (`local_ptr` + `local_elem_offset`). `trans` selects the .trans (column
// major) variant.
std::string PrintLoadMatrixAssembly(bool trans, int num, const std::string& type,
                                    const std::string& local_ptr,
                                    const std::string& local_elem_offset,
                                    const std::string& smem_ptr,
                                    const std::string& smem_elem_offset) {
  // ldmatrix supports loading exactly 1, 2 or 4 matrices (.x1/.x2/.x4).
  CHECK(num == 1 || num == 2 || num == 4) << "ldmatrix only accept loading 1/2/4 matrices.";
  ptx::DataType data_type = ptx::DTypeFromString(type);
  // The instruction only moves 16-bit elements (.b16).
  CHECK(data_type == ptx::DataType::kBit16) << "ldmatrix only accept matrix with type .b16.";
  // Template: the first asm block converts the generic shared-memory pointer
  // to a 32-bit shared-state-space address (cvta.to.shared.u64 + cvt.u32.u64);
  // the second issues ldmatrix with the fragment registers and that address.
  std::string asm_code = R"(
  {
    unsigned int addr;
    __asm__ __volatile__(
      "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
      : "=r"(addr)
      : "l"((void *)({smem_addr}))
    );
    __asm__ __volatile__(
      "ldmatrix.sync.aligned{.shape}{.num}{.trans}{.ss}{.type}"
      "{templates};\n"
      : {outputs}
      : "r"(addr)
    );
  }
)";
  std::string templates_str, outputs_str;
  std::tie(templates_str, outputs_str) = GetLoadMatrixOperands(num, local_ptr, local_elem_offset);

  // Substitute the placeholders in the template above.
  Replacer replacer;
  replacer.register_rule("{.shape}", ".m8n8");  // m8n8 is the only ldmatrix shape
  replacer.register_rule("{.num}", ".x" + std::to_string(num));
  replacer.register_rule("{.trans}", trans ? ".trans" : "");
  replacer.register_rule("{.ss}", ".shared");  // source state space is shared memory
  replacer.register_rule("{.type}", ptx::DTypeToString(data_type));
  replacer.register_rule("{smem_addr}", smem_ptr + " + " + smem_elem_offset);
  replacer.register_rule("{templates}", templates_str);
  replacer.register_rule("{outputs}", outputs_str);
  asm_code = replacer.rewrite(asm_code);
  return asm_code;
}

} // namespace codegen
} // namespace tvm
38 changes: 27 additions & 11 deletions src/target/source/ptx_mma.h → src/target/source/ptx.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@
*/

/*!
* \file ptx_mma.h
* \brief MMA code generation with inlined PTX code.
* \file ptx.h
* \brief Code generation with inlined PTX code.
*/
#ifndef TVM_TARGET_SOURCE_PTX_MMA_H_
#define TVM_TARGET_SOURCE_PTX_MMA_H_
#ifndef TVM_TARGET_SOURCE_PTX_H_
#define TVM_TARGET_SOURCE_PTX_H_

#include <tvm/runtime/logging.h>

Expand All @@ -40,11 +40,11 @@ namespace codegen {
* \param A_dtype The data type of multiplicand A.
* \param B_dtype The data type of multiplicand B.
* \param C_dtype The data type of multiplicand C.
* \param a_ref Pointer to buffer A.
* \param a_ptr Pointer to buffer A.
* \param a_offset The offset of element in A.
* \param b_ref Pointer to buffer B.
* \param b_ptr Pointer to buffer B.
* \param b_offset The offset of element in B.
* \param c_ref Pointer to buffer C.
* \param c_ptr Pointer to buffer C.
* \param c_offset The offset of element in C.
* \param metadata Pointer to metadata buffer (only used for sparse mma).
* \param metadata_offset The offset of element in metadata.
Expand All @@ -56,14 +56,30 @@ namespace codegen {
std::string PrintMMAAssembly(const std::string& shape, const std::string& A_layout,
const std::string& B_layout, const std::string& A_dtype,
const std::string& B_dtype, const std::string& C_dtype,
const std::string& a_ref, const std::string& a_offset,
const std::string& b_ref, const std::string& b_offset,
const std::string& c_ref, const std::string& c_offset,
const std::string& a_ptr, const std::string& a_offset,
const std::string& b_ptr, const std::string& b_offset,
const std::string& c_ptr, const std::string& c_offset,
const std::string& metadata, const std::string& metadata_offset,
const std::string& sparsity_selector, const std::string& bit_op,
bool sparse, bool saturate);

/*!
* \brief Print ldmatrix assembly string given parameters.
* \param trans: whether the matrix is loaded in column major format or not.
* \param num: number of matrices to load.
* \param type: The data type in the matrix, .b16 is the only accepted data type.
* \param local_ptr: pointer to local buffer.
* \param local_elem_offset: The offset of the element to store in the local buffer.
* \param smem_ptr: pointer to the shared memory buffer to load.
* \param smem_elem_offset: The offset of the start element of the row to load in shared memory.
*/
std::string PrintLoadMatrixAssembly(bool trans, int num, const std::string& type,
const std::string& local_ptr,
const std::string& local_elem_offset,
const std::string& smem_ptr,
const std::string& smem_elem_offset);

} // namespace codegen
} // namespace tvm

#endif // TVM_TARGET_SOURCE_PTX_MMA_H_
#endif // TVM_TARGET_SOURCE_PTX_H_
3 changes: 3 additions & 0 deletions src/tir/op/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,9 @@ TIR_DEFINE_BUILTIN_FUNC(ptx_mma).set_attr<TCallEffectKind>("TCallEffectKind",
TIR_DEFINE_BUILTIN_FUNC(ptx_mma_sp)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));

TIR_DEFINE_BUILTIN_FUNC(ptx_ldmatrix)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));

TIR_DEFINE_BUILTIN_FUNC(vectorhigh)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure));

Expand Down
Loading

0 comments on commit c2b2333

Please sign in to comment.