[PTX] ldmatrix builtin to accelerate copying data from shared memory to warp memory #10855

Merged: 3 commits, Apr 3, 2022
9 changes: 9 additions & 0 deletions include/tvm/tir/builtin.h
@@ -623,6 +623,15 @@ TVM_DLL const Op& ptx_mma();
*/
TVM_DLL const Op& ptx_mma_sp();

/*!
* \brief tvm intrinsic for ptx load matrix from shared memory.
*
* void ptx_ldmatrix(Bool trans, IntImm num, StringImm type,
* Var local_ptr, Expr local_offset,
* Var smem_ptr, Expr smem_offset);
*/
TVM_DLL const Op& ptx_ldmatrix();

// TODO(tvm-team) replace the usage of the vector operations by Shuffle.
/*!
* \brief Get the high level half of the vector
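For context, a lowering pass that targets this intrinsic builds an opaque call with exactly the argument order documented above. A minimal sketch of constructing such a call from C++ follows; the helper name, buffer variables, and offsets are hypothetical placeholders, not part of this change.

#include <tvm/ir/expr.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>

// Sketch only: emit a ptx_ldmatrix call that loads four 8x8 .b16 tiles,
// row-major (trans = false). Pointers and offsets are supplied by the caller.
tvm::tir::Stmt MakeLdmatrixCall(tvm::tir::Var local_ptr, tvm::PrimExpr local_offset,
                                tvm::tir::Var smem_ptr, tvm::PrimExpr smem_offset) {
  using namespace tvm;
  using namespace tvm::tir;
  PrimExpr call = Call(DataType::Void(), builtin::ptx_ldmatrix(),
                       {Bool(false), Integer(4), StringImm(".b16"),
                        local_ptr, local_offset, smem_ptr, smem_offset});
  return Evaluate(call);
}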
26 changes: 22 additions & 4 deletions src/target/source/codegen_cuda.cc
@@ -33,7 +33,7 @@
#include <vector>

#include "literal/cuda_half_t.h"
#include "ptx_mma.h"
#include "ptx.h"

namespace tvm {
namespace codegen {
@@ -772,11 +772,11 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
// arg 3: A precision: fp16, fp32, ...
// arg 4: B precision: fp16, fp32, ...
// arg 5: C precision: fp16, fp32, ...
// arg 6: A multiplicand
// arg 6: A multiplicand pointer
// arg 7: A multiplicand index
// arg 8: B multiplicand
// arg 8: B multiplicand pointer
// arg 9: B multiplicand index
// arg 10: C accumulator
// arg 10: C accumulator pointer
// arg 11: C accumulator index
// arg 12: metadata
// arg 13: metadata index
@@ -803,6 +803,24 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_offset, b_ref, b_offset,
c_ref, c_offset, metadata, metadata_offset, sparse_selector, "", true, saturate);
this->stream << asm_code;
} else if (op->op.same_as(builtin::ptx_ldmatrix())) {
// arg 0: whether the matrix is loaded in column major format or not.
// arg 1: number of matrices to load.
// arg 2: The data type in the matrix; .b16 is the only accepted data type.
// arg 3: pointer to local buffer.
// arg 4: The offset of the element to store in the local buffer.
// arg 5: pointer to the shared memory buffer to load.
// arg 6: The offset of the start element of the row to load in shared memory.
ICHECK_EQ(op->args.size(), 7U);
bool trans = Downcast<Bool>(op->args[0])->value;
int num = Downcast<Integer>(op->args[1])->value;
std::string type = Downcast<StringImm>(op->args[2])->value;
std::string local_ptr = this->PrintExpr(op->args[3]);
std::string local_elem_offset = this->PrintExpr(op->args[4]);
std::string smem_ptr = this->PrintExpr(op->args[5]);
std::string smem_elem_offset = this->PrintExpr(op->args[6]);
this->stream << PrintLoadMatrixAssembly(trans, num, type, local_ptr, local_elem_offset,
smem_ptr, smem_elem_offset);
} else {
CodeGenC::VisitExpr_(op, os);
}
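To make the argument comments above concrete, the codegen splices in roughly the following CUDA fragment for a call such as ptx_ldmatrix(false, 4, ".b16", A_local, 0, A_shared, off). The buffer names A_local, A_shared, and off are placeholders; the exact text is produced by PrintLoadMatrixAssembly in ptx.cc below.

{
  unsigned int addr;
  __asm__ __volatile__(
    "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
    : "=r"(addr)
    : "l"((void *)(A_shared + off))
  );
  // With .x4, each thread receives one 32-bit register per 8x8 tile
  // (two .b16 elements each), written into the local buffer.
  __asm__ __volatile__(
    "ldmatrix.sync.aligned.m8n8.x4.shared.b16"
    "{%0, %1, %2, %3}, [%4];\n"
    : "=r"(((unsigned *)(A_local + 0))[0]), "=r"(((unsigned *)(A_local + 0))[1]),
      "=r"(((unsigned *)(A_local + 0))[2]), "=r"(((unsigned *)(A_local + 0))[3])
    : "r"(addr)
  );
}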
126 changes: 101 additions & 25 deletions src/target/source/ptx_mma.cc → src/target/source/ptx.cc
@@ -18,10 +18,10 @@
*/

/*!
* \file ptx_mma.cc
* \file ptx.cc
*/

#include "ptx_mma.h"
#include "ptx.h"

#include <algorithm>
#include <string>
@@ -60,13 +60,18 @@ enum class DataType : int {
kFloat32 = 13,
kTensorFloat32 = 14,
kFloat64 = 15,
kBit1 = 16
kBit1 = 16,
kBit8 = 17,
kBit16 = 18,
kBit32 = 19,
kBit64 = 20,
};

static const char* dtype_str[] = {".s4", ".u4", ".s8", ".u8", ".s16", ".u16",
".s32", ".u32", ".s64", ".u64", ".f16", ".bf16",
".f16x2", ".f32", ".tf32", ".f64", ".b1"};
static const uint32_t num_bits[] = {4, 4, 8, 8, 16, 16, 32, 32, 64, 64, 16, 16, 32, 32, 32, 64, 1};
static const char* dtype_str[] = {".s4", ".u4", ".s8", ".u8", ".s16", ".u16", ".s32",
".u32", ".s64", ".u64", ".f16", ".bf16", ".f16x2", ".f32",
".tf32", ".f64", ".b1", ".b8", ".b16", ".b32", ".b64"};
static const uint32_t num_bits[] = {4, 4, 8, 8, 16, 16, 32, 32, 64, 64, 16,
16, 32, 32, 32, 64, 1, 8, 16, 32, 64};

/*!
* \brief Create PTX data type from string.
@@ -106,6 +111,14 @@ inline DataType DTypeFromString(const std::string str) {
return DataType::kFloat64;
} else if (str == "int1" || str == ".b1") {
return DataType::kBit1;
} else if (str == ".b8") {
return DataType::kBit8;
} else if (str == ".b16") {
return DataType::kBit16;
} else if (str == ".b32") {
return DataType::kBit32;
} else if (str == ".b64") {
return DataType::kBit64;
} else {
LOG(FATAL) << "Unrecognized PTX data type " << str;
return DataType(0);
@@ -360,6 +373,7 @@ inline FragAttrs GetFragAttrs(DataType dtype) {
case DataType::kUInt4:
case DataType::kInt8:
case DataType::kUInt8:
case DataType::kBit16:
case DataType::kFloat16: // .f16x2 register
case DataType::kBFloat16:
case DataType::kTensorFloat32:
@@ -508,9 +522,9 @@ inline std::tuple<std::string, std::string, std::string> GetMMAOperands(int m, i
std::string PrintMMAAssembly(const std::string& shape, const std::string& A_layout,
const std::string& B_layout, const std::string& A_dtype,
const std::string& B_dtype, const std::string& C_dtype,
const std::string& a_ref, const std::string& a_offset,
const std::string& b_ref, const std::string& b_offset,
const std::string& c_ref, const std::string& c_offset,
const std::string& a_ptr, const std::string& a_elem_offset,
const std::string& b_ptr, const std::string& b_elem_offset,
const std::string& c_ptr, const std::string& c_elem_offset,
const std::string& metadata, const std::string& metadata_offset,
const std::string& sparsity_selector, const std::string& bit_op,
bool sparse, bool saturate) {
@@ -525,7 +539,7 @@ std::string PrintMMAAssembly(const std::string& shape, const std::string& A_layo
std::string asm_code = R"(
{
__asm__ __volatile__(
"mma{sparse}.sync.aligned.{shape}.{alayout}.{blayout}{saturate}{dtype}{atype}{btype}{ctype}{bitop}"
"mma{.sparse}.sync.aligned{.shape}{.alayout}{.blayout}{.saturate}{.dtype}{.atype}{.btype}{.ctype}{.bitop}"
"{templates};\n"
: {outputs}
: {inputs});
@@ -537,30 +551,92 @@ std::string PrintMMAAssembly(const std::string& shape, const std::string& A_layo

// replace patterns
Replacer replacer;
replacer.register_rule("{sparse}", sparse ? ".sp" : "");
replacer.register_rule("{shape}", shape);
replacer.register_rule("{saturate}", saturate ? ".satfinite" : "");
replacer.register_rule("{alayout}", A_layout);
replacer.register_rule("{blayout}", B_layout);
replacer.register_rule("{atype}", ptx::DTypeToString(dtype_a));
replacer.register_rule("{btype}", ptx::DTypeToString(dtype_b));
replacer.register_rule("{ctype}", ptx::DTypeToString(dtype_c));
replacer.register_rule("{dtype}", ptx::DTypeToString(dtype_c));
replacer.register_rule("{bitop}", bit_op.empty() ? "" : "." + bit_op + ".popc");
replacer.register_rule("{.sparse}", sparse ? ".sp" : "");
replacer.register_rule("{.shape}", "." + shape);
replacer.register_rule("{.saturate}", saturate ? ".satfinite" : "");
replacer.register_rule("{.alayout}", "." + A_layout);
replacer.register_rule("{.blayout}", "." + B_layout);
replacer.register_rule("{.atype}", ptx::DTypeToString(dtype_a));
replacer.register_rule("{.btype}", ptx::DTypeToString(dtype_b));
replacer.register_rule("{.ctype}", ptx::DTypeToString(dtype_c));
replacer.register_rule("{.dtype}", ptx::DTypeToString(dtype_c));
replacer.register_rule("{.bitop}", bit_op.empty() ? "" : "." + bit_op + ".popc");
replacer.register_rule("{templates}", templates_str);
replacer.register_rule("{outputs}", outputs_str);
replacer.register_rule("{inputs}", inputs_str);
asm_code = replacer.rewrite(asm_code);
replacer.empty_rules();
replacer.register_rule("A", a_ref + " + " + a_offset);
replacer.register_rule("B", b_ref + " + " + b_offset);
replacer.register_rule("C", c_ref + " + " + c_offset);
replacer.register_rule("D", c_ref + " + " + c_offset);
replacer.register_rule("A", a_ptr + " + " + a_elem_offset);
replacer.register_rule("B", b_ptr + " + " + b_elem_offset);
replacer.register_rule("C", c_ptr + " + " + c_elem_offset);
replacer.register_rule("D", c_ptr + " + " + c_elem_offset);
replacer.register_rule("E", metadata + " + " + metadata_offset);
replacer.register_rule("F", sparsity_selector);
asm_code = replacer.rewrite(asm_code);
return asm_code;
}
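A note on the renamed placeholders above: every segment now follows the {.field} convention and supplies its own leading dot, so the template no longer mixes literal dots with dot-carrying substitutions; the emitted opcode is unchanged. For a dense m16n8k16 MMA with row/col layouts, f16 inputs, and an f32 accumulator (example values, not taken from this diff), the substitution resolves to:

  mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32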

inline std::tuple<std::string, std::string> GetLoadMatrixOperands(
int num, const std::string& local_ptr, const std::string& local_elem_offset) {
std::stringstream templates, outputs;
int arg_counter = 0;
// generate templates
templates << "{%" << arg_counter++;
for (int i = 1; i < num; ++i) {
templates << ", %" << arg_counter++;
}
templates << "}, [%" << arg_counter++ << "]";
// generate outputs
std::string ptr_type = "(unsigned *)";
for (int i = 0; i < num; ++i) {
if (i != 0) {
outputs << ", ";
}
outputs << "\"=r\"((" << ptr_type << "(" << local_ptr << " + " << local_elem_offset << "))["
<< i << "])";
}
return std::make_tuple(templates.str(), outputs.str());
}

std::string PrintLoadMatrixAssembly(bool trans, int num, const std::string& type,
const std::string& local_ptr,
const std::string& local_elem_offset,
const std::string& smem_ptr,
const std::string& smem_elem_offset) {
CHECK(num == 1 || num == 2 || num == 4) << "ldmatrix only accepts loading 1/2/4 matrices.";
ptx::DataType data_type = ptx::DTypeFromString(type);
CHECK(data_type == ptx::DataType::kBit16) << "ldmatrix only accepts matrices with type .b16.";
std::string asm_code = R"(
{
unsigned int addr;
__asm__ __volatile__(
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
: "=r"(addr)
: "l"((void *)({smem_addr}))
);
__asm__ __volatile__(
"ldmatrix.sync.aligned{.shape}{.num}{.trans}{.ss}{.type}"
"{templates};\n"
: {outputs}
: "r"(addr)
);
}
)";
std::string templates_str, outputs_str;
std::tie(templates_str, outputs_str) = GetLoadMatrixOperands(num, local_ptr, local_elem_offset);

Replacer replacer;
replacer.register_rule("{.shape}", ".m8n8");
replacer.register_rule("{.num}", ".x" + std::to_string(num));
replacer.register_rule("{.trans}", trans ? ".trans" : "");
replacer.register_rule("{.ss}", ".shared");
replacer.register_rule("{.type}", ptx::DTypeToString(data_type));
replacer.register_rule("{smem_addr}", smem_ptr + " + " + smem_elem_offset);
replacer.register_rule("{templates}", templates_str);
replacer.register_rule("{outputs}", outputs_str);
asm_code = replacer.rewrite(asm_code);
return asm_code;
}

} // namespace codegen
} // namespace tvm
38 changes: 27 additions & 11 deletions src/target/source/ptx_mma.h → src/target/source/ptx.h
@@ -18,11 +18,11 @@
*/

/*!
* \file ptx_mma.h
* \brief MMA code generation with inlined PTX code.
* \file ptx.h
* \brief Code generation with inlined PTX code.
*/
#ifndef TVM_TARGET_SOURCE_PTX_MMA_H_
#define TVM_TARGET_SOURCE_PTX_MMA_H_
#ifndef TVM_TARGET_SOURCE_PTX_H_
#define TVM_TARGET_SOURCE_PTX_H_

#include <tvm/runtime/logging.h>

@@ -40,11 +40,11 @@ namespace codegen {
* \param A_dtype The data type of multiplicand A.
* \param B_dtype The data type of multiplicand B.
* \param C_dtype The data type of multiplicand C.
* \param a_ref Pointer to buffer A.
* \param a_ptr Pointer to buffer A.
* \param a_offset The offset of element in A.
* \param b_ref Pointer to buffer B.
* \param b_ptr Pointer to buffer B.
* \param b_offset The offset of element in B.
* \param c_ref Pointer to buffer C.
* \param c_ptr Pointer to buffer C.
* \param c_offset The offset of element in C.
* \param metadata Pointer to metadata buffer (only used for sparse mma).
* \param metadata_offset The offset of element in metadata.
@@ -56,14 +56,30 @@ namespace codegen {
std::string PrintMMAAssembly(const std::string& shape, const std::string& A_layout,
const std::string& B_layout, const std::string& A_dtype,
const std::string& B_dtype, const std::string& C_dtype,
const std::string& a_ref, const std::string& a_offset,
const std::string& b_ref, const std::string& b_offset,
const std::string& c_ref, const std::string& c_offset,
const std::string& a_ptr, const std::string& a_offset,
const std::string& b_ptr, const std::string& b_offset,
const std::string& c_ptr, const std::string& c_offset,
const std::string& metadata, const std::string& metadata_offset,
const std::string& sparsity_selector, const std::string& bit_op,
bool sparse, bool saturate);

/*!
* \brief Print ldmatrix assembly string given parameters.
* \param trans: whether the matrix is loaded in column major format or not.
* \param num: number of matrices to load.
* \param type: The data type in the matrix; .b16 is the only accepted data type.
* \param local_ptr: pointer to local buffer.
* \param local_elem_offset: The offset of the element to store in the local buffer.
* \param smem_ptr: pointer to the shared memory buffer to load.
* \param smem_elem_offset: The offset of the start element of the row to load in shared memory.
*/
std::string PrintLoadMatrixAssembly(bool trans, int num, const std::string& type,
const std::string& local_ptr,
const std::string& local_elem_offset,
const std::string& smem_ptr,
const std::string& smem_elem_offset);

} // namespace codegen
} // namespace tvm

#endif // TVM_TARGET_SOURCE_PTX_MMA_H_
#endif // TVM_TARGET_SOURCE_PTX_H_
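A minimal sketch of calling the new helper directly; the pointer and offset arguments are the textual expressions the CUDA codegen has already printed, and the names used here are hypothetical.

#include <iostream>
#include <string>

#include "ptx.h"

int main() {
  // Load two 8x8 .b16 tiles, column-major (trans = true), from a shared memory
  // buffer printed as "A_shared" into a local buffer printed as "A_local".
  std::string code = tvm::codegen::PrintLoadMatrixAssembly(
      /*trans=*/true, /*num=*/2, /*type=*/".b16",
      /*local_ptr=*/"A_local", /*local_elem_offset=*/"0",
      /*smem_ptr=*/"A_shared", /*smem_elem_offset=*/"row * 16");
  std::cout << code;  // the inline-asm block that codegen_cuda.cc splices into the kernel
  return 0;
}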
3 changes: 3 additions & 0 deletions src/tir/op/builtin.cc
@@ -244,6 +244,9 @@ TIR_DEFINE_BUILTIN_FUNC(ptx_mma).set_attr<TCallEffectKind>("TCallEffectKind",
TIR_DEFINE_BUILTIN_FUNC(ptx_mma_sp)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));

TIR_DEFINE_BUILTIN_FUNC(ptx_ldmatrix)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));

TIR_DEFINE_BUILTIN_FUNC(vectorhigh)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure));
