diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h index c42d44fd9727..b166b16b7721 100644 --- a/include/tvm/tir/builtin.h +++ b/include/tvm/tir/builtin.h @@ -623,6 +623,15 @@ TVM_DLL const Op& ptx_mma(); */ TVM_DLL const Op& ptx_mma_sp(); +/*! + * \brief tvm intrinsic for ptx load matrix from shared memory. + * + * void ptx_ldmatrix(Bool trans, IntImm num, StringImm type, + * Var local_ptr, Expr local_offset, + * Var smem_ptr, Expr smem_offset); + */ +TVM_DLL const Op& ptx_ldmatrix(); + // TODO(tvm-team) replace the usage of the vector operations by Shuffle. /*! * \brief Get the high level half of the vector diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index f74d5cf484b9..d4ec536fb001 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -33,7 +33,7 @@ #include #include "literal/cuda_half_t.h" -#include "ptx_mma.h" +#include "ptx.h" namespace tvm { namespace codegen { @@ -772,11 +772,11 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { // arg 3: A precision: fp16, fp32, ... // arg 4: B precision: fp16, fp32, ... // arg 5: C precision: fp16, fp32, ... - // arg 6: A multiplicand + // arg 6: A multiplicand pointer // arg 7: A multiplicand index - // arg 8: B multiplicand + // arg 8: B multiplicand pointer // arg 9: B multiplicand index - // arg 10: C accumulator + // arg 10: C accumulator pointer // arg 11: C accumulator index // arg 12: metadata // arg 13: metadata index @@ -803,6 +803,24 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_offset, b_ref, b_offset, c_ref, c_offset, metadata, metadata_offset, sparse_selector, "", true, saturate); this->stream << asm_code; + } else if (op->op.same_as(builtin::ptx_ldmatrix())) { + // arg 0: whether the matrix is loaded in column major format or not. + // arg 1: number of matrices to load. + // arg 2: The data type in the matrix, .b16 is the only accepted data type. + // arg 3: pointer to local buffer. + // arg 4: The offset of the element to store in the local buffer. + // arg 5: pointer to the shared memory buffer to load. + // arg 6: The offset of the start element of the row to load in shared memory. + ICHECK_EQ(op->args.size(), 7U); + bool trans = Downcast(op->args[0])->value; + int num = Downcast(op->args[1])->value; + std::string type = Downcast(op->args[2])->value; + std::string local_ptr = this->PrintExpr(op->args[3]); + std::string local_elem_offset = this->PrintExpr(op->args[4]); + std::string smem_ptr = this->PrintExpr(op->args[5]); + std::string smem_elem_offset = this->PrintExpr(op->args[6]); + this->stream << PrintLoadMatrixAssembly(trans, num, type, local_ptr, local_elem_offset, + smem_ptr, smem_elem_offset); } else { CodeGenC::VisitExpr_(op, os); } diff --git a/src/target/source/ptx_mma.cc b/src/target/source/ptx.cc similarity index 81% rename from src/target/source/ptx_mma.cc rename to src/target/source/ptx.cc index d04c01896ed7..02a98ffbbabd 100644 --- a/src/target/source/ptx_mma.cc +++ b/src/target/source/ptx.cc @@ -18,10 +18,10 @@ */ /*! - * \file ptx_mma.cc + * \file ptx.cc */ -#include "ptx_mma.h" +#include "ptx.h" #include #include @@ -60,13 +60,18 @@ enum class DataType : int { kFloat32 = 13, kTensorFloat32 = 14, kFloat64 = 15, - kBit1 = 16 + kBit1 = 16, + kBit8 = 17, + kBit16 = 18, + kBit32 = 19, + kBit64 = 20, }; -static const char* dtype_str[] = {".s4", ".u4", ".s8", ".u8", ".s16", ".u16", - ".s32", ".u32", ".s64", ".u64", ".f16", ".bf16", - ".f16x2", ".f32", ".tf32", ".f64", ".b1"}; -static const uint32_t num_bits[] = {4, 4, 8, 8, 16, 16, 32, 32, 64, 64, 16, 16, 32, 32, 32, 64, 1}; +static const char* dtype_str[] = {".s4", ".u4", ".s8", ".u8", ".s16", ".u16", ".s32", + ".u32", ".s64", ".u64", ".f16", ".bf16", ".f16x2", ".f32", + ".tf32", ".f64", ".b1", ".b8", ".b16", ".b32", ".b64"}; +static const uint32_t num_bits[] = {4, 4, 8, 8, 16, 16, 32, 32, 64, 64, 16, + 16, 32, 32, 32, 64, 1, 8, 16, 32, 64}; /*! * \brief Create PTX data type from string. @@ -106,6 +111,14 @@ inline DataType DTypeFromString(const std::string str) { return DataType::kFloat64; } else if (str == "int1" || str == ".b1") { return DataType::kBit1; + } else if (str == ".b8") { + return DataType::kBit8; + } else if (str == ".b16") { + return DataType::kBit16; + } else if (str == ".b32") { + return DataType::kBit32; + } else if (str == ".b64") { + return DataType::kBit64; } else { LOG(FATAL) << "Unrecognized PTX data type " << str; return DataType(0); @@ -360,6 +373,7 @@ inline FragAttrs GetFragAttrs(DataType dtype) { case DataType::kUInt4: case DataType::kInt8: case DataType::kUInt8: + case DataType::kBit16: case DataType::kFloat16: // .f16x2 register case DataType::kBFloat16: case DataType::kTensorFloat32: @@ -508,9 +522,9 @@ inline std::tuple GetMMAOperands(int m, i std::string PrintMMAAssembly(const std::string& shape, const std::string& A_layout, const std::string& B_layout, const std::string& A_dtype, const std::string& B_dtype, const std::string& C_dtype, - const std::string& a_ref, const std::string& a_offset, - const std::string& b_ref, const std::string& b_offset, - const std::string& c_ref, const std::string& c_offset, + const std::string& a_ptr, const std::string& a_elem_offset, + const std::string& b_ptr, const std::string& b_elem_offset, + const std::string& c_ptr, const std::string& c_elem_offset, const std::string& metadata, const std::string& metadata_offset, const std::string& sparsity_selector, const std::string& bit_op, bool sparse, bool saturate) { @@ -525,7 +539,7 @@ std::string PrintMMAAssembly(const std::string& shape, const std::string& A_layo std::string asm_code = R"( { __asm__ __volatile__( - "mma{sparse}.sync.aligned.{shape}.{alayout}.{blayout}{saturate}{dtype}{atype}{btype}{ctype}{bitop}" + "mma{.sparse}.sync.aligned{.shape}{.alayout}{.blayout}{.saturate}{.dtype}{.atype}{.btype}{.ctype}{.bitop}" "{templates};\n" : {outputs} : {inputs}); @@ -537,30 +551,92 @@ std::string PrintMMAAssembly(const std::string& shape, const std::string& A_layo // replace patterns Replacer replacer; - replacer.register_rule("{sparse}", sparse ? ".sp" : ""); - replacer.register_rule("{shape}", shape); - replacer.register_rule("{saturate}", saturate ? ".satfinite" : ""); - replacer.register_rule("{alayout}", A_layout); - replacer.register_rule("{blayout}", B_layout); - replacer.register_rule("{atype}", ptx::DTypeToString(dtype_a)); - replacer.register_rule("{btype}", ptx::DTypeToString(dtype_b)); - replacer.register_rule("{ctype}", ptx::DTypeToString(dtype_c)); - replacer.register_rule("{dtype}", ptx::DTypeToString(dtype_c)); - replacer.register_rule("{bitop}", bit_op.empty() ? "" : "." + bit_op + ".popc"); + replacer.register_rule("{.sparse}", sparse ? ".sp" : ""); + replacer.register_rule("{.shape}", "." + shape); + replacer.register_rule("{.saturate}", saturate ? ".satfinite" : ""); + replacer.register_rule("{.alayout}", "." + A_layout); + replacer.register_rule("{.blayout}", "." + B_layout); + replacer.register_rule("{.atype}", ptx::DTypeToString(dtype_a)); + replacer.register_rule("{.btype}", ptx::DTypeToString(dtype_b)); + replacer.register_rule("{.ctype}", ptx::DTypeToString(dtype_c)); + replacer.register_rule("{.dtype}", ptx::DTypeToString(dtype_c)); + replacer.register_rule("{.bitop}", bit_op.empty() ? "" : "." + bit_op + ".popc"); replacer.register_rule("{templates}", templates_str); replacer.register_rule("{outputs}", outputs_str); replacer.register_rule("{inputs}", inputs_str); asm_code = replacer.rewrite(asm_code); replacer.empty_rules(); - replacer.register_rule("A", a_ref + " + " + a_offset); - replacer.register_rule("B", b_ref + " + " + b_offset); - replacer.register_rule("C", c_ref + " + " + c_offset); - replacer.register_rule("D", c_ref + " + " + c_offset); + replacer.register_rule("A", a_ptr + " + " + a_elem_offset); + replacer.register_rule("B", b_ptr + " + " + b_elem_offset); + replacer.register_rule("C", c_ptr + " + " + c_elem_offset); + replacer.register_rule("D", c_ptr + " + " + c_elem_offset); replacer.register_rule("E", metadata + " + " + metadata_offset); replacer.register_rule("F", sparsity_selector); asm_code = replacer.rewrite(asm_code); return asm_code; } +inline std::tuple GetLoadMatrixOperands( + int num, const std::string& local_ptr, const std::string& local_elem_offset) { + std::stringstream templates, outputs; + int arg_counter = 0; + // generate templates + templates << "{%" << arg_counter++; + for (int i = 1; i < num; ++i) { + templates << ", %" << arg_counter++; + } + templates << "}, [%" << arg_counter++ << "]"; + // generate outputs + std::string ptr_type = "(unsigned *)"; + for (int i = 0; i < num; ++i) { + if (i != 0) { + outputs << ", "; + } + outputs << "\"=r\"((" << ptr_type << "(" << local_ptr << " + " << local_elem_offset << "))[" + << i << "])"; + } + return std::make_tuple(templates.str(), outputs.str()); +} + +std::string PrintLoadMatrixAssembly(bool trans, int num, const std::string& type, + const std::string& local_ptr, + const std::string& local_elem_offset, + const std::string& smem_ptr, + const std::string& smem_elem_offset) { + CHECK(num == 1 || num == 2 || num == 4) << "ldmatrix only accept loading 1/2/4 matrices."; + ptx::DataType data_type = ptx::DTypeFromString(type); + CHECK(data_type == ptx::DataType::kBit16) << "ldmatrix only accept matrix with type .b16."; + std::string asm_code = R"( + { + unsigned int addr; + __asm__ __volatile__( + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" + : "=r"(addr) + : "l"((void *)({smem_addr})) + ); + __asm__ __volatile__( + "ldmatrix.sync.aligned{.shape}{.num}{.trans}{.ss}{.type}" + "{templates};\n" + : {outputs} + : "r"(addr) + ); + } +)"; + std::string templates_str, outputs_str; + std::tie(templates_str, outputs_str) = GetLoadMatrixOperands(num, local_ptr, local_elem_offset); + + Replacer replacer; + replacer.register_rule("{.shape}", ".m8n8"); + replacer.register_rule("{.num}", ".x" + std::to_string(num)); + replacer.register_rule("{.trans}", trans ? ".trans" : ""); + replacer.register_rule("{.ss}", ".shared"); + replacer.register_rule("{.type}", ptx::DTypeToString(data_type)); + replacer.register_rule("{smem_addr}", smem_ptr + " + " + smem_elem_offset); + replacer.register_rule("{templates}", templates_str); + replacer.register_rule("{outputs}", outputs_str); + asm_code = replacer.rewrite(asm_code); + return asm_code; +} + } // namespace codegen } // namespace tvm diff --git a/src/target/source/ptx_mma.h b/src/target/source/ptx.h similarity index 63% rename from src/target/source/ptx_mma.h rename to src/target/source/ptx.h index 728478cdf5fb..c4255d737ad0 100644 --- a/src/target/source/ptx_mma.h +++ b/src/target/source/ptx.h @@ -18,11 +18,11 @@ */ /*! - * \file ptx_mma.h - * \brief MMA code generation with inlined PTX code. + * \file ptx.h + * \brief Code generation with inlined PTX code. */ -#ifndef TVM_TARGET_SOURCE_PTX_MMA_H_ -#define TVM_TARGET_SOURCE_PTX_MMA_H_ +#ifndef TVM_TARGET_SOURCE_PTX_H_ +#define TVM_TARGET_SOURCE_PTX_H_ #include @@ -40,11 +40,11 @@ namespace codegen { * \param A_dtype The data type of multiplicand A. * \param B_dtype The data type of multiplicand B. * \param C_dtype The data type of multiplicand C. - * \param a_ref Pointer to buffer A. + * \param a_ptr Pointer to buffer A. * \param a_offset The offset of element in A. - * \param b_ref Pointer to buffer B. + * \param b_ptr Pointer to buffer B. * \param b_offset The offset of element in B. - * \param c_ref Pointer to buffer C. + * \param c_ptr Pointer to buffer C. * \param c_offset The offset of element in C. * \param metadata Pointer to metadata buffer (only used for sparse mma). * \param metadata_offset The offset of element in metadata. @@ -56,14 +56,30 @@ namespace codegen { std::string PrintMMAAssembly(const std::string& shape, const std::string& A_layout, const std::string& B_layout, const std::string& A_dtype, const std::string& B_dtype, const std::string& C_dtype, - const std::string& a_ref, const std::string& a_offset, - const std::string& b_ref, const std::string& b_offset, - const std::string& c_ref, const std::string& c_offset, + const std::string& a_ptr, const std::string& a_offset, + const std::string& b_ptr, const std::string& b_offset, + const std::string& c_ptr, const std::string& c_offset, const std::string& metadata, const std::string& metadata_offset, const std::string& sparsity_selector, const std::string& bit_op, bool sparse, bool saturate); +/*! + * \brief Print ldmatrix assembly string given parameters. + * \param trans: whether the matrix is loaded in column major format or not. + * \param num: number of matrices to load. + * \param type: The data type in the matrix, .b16 is the only accepted data type. + * \param local_ptr: pointer to local buffer. + * \param local_elem_offset: The offset of the element to store in the local buffer. + * \param smem_ptr: pointer to the shared memory buffer to load. + * \param smem_elem_offset: The offset of the start element of the row to load in shared memory. + */ +std::string PrintLoadMatrixAssembly(bool trans, int num, const std::string& type, + const std::string& local_ptr, + const std::string& local_elem_offset, + const std::string& smem_ptr, + const std::string& smem_elem_offset); + } // namespace codegen } // namespace tvm -#endif // TVM_TARGET_SOURCE_PTX_MMA_H_ +#endif // TVM_TARGET_SOURCE_PTX_H_ diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc index 465428e1e880..4e8d83dd32df 100644 --- a/src/tir/op/builtin.cc +++ b/src/tir/op/builtin.cc @@ -244,6 +244,9 @@ TIR_DEFINE_BUILTIN_FUNC(ptx_mma).set_attr("TCallEffectKind", TIR_DEFINE_BUILTIN_FUNC(ptx_mma_sp) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_BUILTIN_FUNC(ptx_ldmatrix) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); + TIR_DEFINE_BUILTIN_FUNC(vectorhigh) .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)); diff --git a/tests/python/unittest/test_tir_ptx_ldmatrix.py b/tests/python/unittest/test_tir_ptx_ldmatrix.py new file mode 100644 index 000000000000..f718082ff8a1 --- /dev/null +++ b/tests/python/unittest/test_tir_ptx_ldmatrix.py @@ -0,0 +1,101 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +from tvm.script import tir as T +import numpy as np +import tvm.testing + + +@T.prim_func +def ptx_ldmatrix( + A: T.Buffer[(16, 16), "float16"], B: T.Buffer[(16, 16), "float16"], num: T.int32, trans: T.uint8 +) -> None: + T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + bx = T.env_thread("blockIdx.x") + tx = T.env_thread("threadIdx.x") + T.launch_thread(bx, 1) + T.launch_thread(tx, 32) + with T.block(): + A_shared = T.alloc_buffer([16, 16], "float16", scope="shared") + A_local = T.alloc_buffer([8], "float16", scope="local") + + for i in range(8): + A_shared[i * 2 + tx // 16, tx % 16] = A[i * 2 + tx // 16, tx % 16] + + T.evaluate( + T.ptx_ldmatrix( + trans, + num, + ".b16", + A_local.data, + 0, + A_shared.data, + 16 * (tx % 16) + 8 * (tx // 16), + dtype="float16", + ) + ) + + for k in range(2): + for j in range(2): + for i in range(2): + B[8 * j + tx // 4, 8 * k + (tx % 4) * 2 + i] = A_local[4 * k + 2 * j + i] + + +@tvm.testing.requires_cuda +def test_ptx_ldmatrix(): + f = ptx_ldmatrix + _, _, param_num, param_trans = f.params + arch = tvm.contrib.nvcc.get_target_compute_version() + major, minor = tvm.contrib.nvcc.parse_compute_version(arch) + if major * 10 + minor < 75: + # Require at least SM75 + return + for num in [1, 2, 4]: + for trans in [False, True]: + mod = tvm.build(f.specialize({param_num: num, param_trans: trans}), target="cuda") + A_np = np.random.rand(16, 16).astype("float16") + A_mask_np = np.zeros_like(A_np) + if num == 1: + if trans: + A_mask_np[:8, :8] = A_np[:8, :8].T + else: + A_mask_np[:8, :8] = A_np[:8, :8] + elif num == 2: + if trans: + A_mask_np[:8, :8] = A_np[:8, :8].T + A_mask_np[8:16, :8] = A_np[8:16, :8].T + else: + A_mask_np[:16, :8] = A_np[:16, :8] + else: # num == 4 + if trans: + A_mask_np[:8, :8] = A_np[:8, :8].T + A_mask_np[8:16, :8] = A_np[8:16, :8].T + A_mask_np[:8, 8:16] = A_np[:8, 8:16].T + A_mask_np[8:16, 8:16] = A_np[8:16, 8:16].T + else: + A_mask_np[:16, :16] = A_np[:16, :16] + B_np = np.zeros((16, 16)).astype("float16") + dev = tvm.cuda(0) + A_nd = tvm.nd.array(A_np, device=dev) + B_nd = tvm.nd.array(B_np, device=dev) + mod(A_nd, B_nd) + tvm.testing.assert_allclose(B_nd.numpy(), A_mask_np) + + +if __name__ == "__main__": + test_ptx_ldmatrix()