From 3218facf100b0dfc55715acfd1cee156764129ba Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 18 May 2022 14:04:56 +0900 Subject: [PATCH] some clean up --- python/tvm/tir/tensor_intrin/cuda.py | 17 ++++++++--------- src/target/source/codegen_cuda.cc | 2 ++ 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/python/tvm/tir/tensor_intrin/cuda.py b/python/tvm/tir/tensor_intrin/cuda.py index bbd26681f20b8..59e0e54c63fff 100644 --- a/python/tvm/tir/tensor_intrin/cuda.py +++ b/python/tvm/tir/tensor_intrin/cuda.py @@ -79,15 +79,17 @@ def get_ldmatrix_intrin(k_dim, dtype, is_b, transposed): shared_offset = lambda tx, stride: stride * (tx % HALF_WARP_expr) + 8 * ( tx // HALF_WARP_expr ) - - elif k_dim == 32: - assert dtype == "int8" + else: + assert ( + k_dim == 32 and dtype == "int8" + ), "Only k_dim == 16 (float16) or k_dim == 32 (int8) supported for now" if ldmatrix_col_major: index_map = shared_32x16_to_ldmatrix_32x16_layout - shared_offset = ( - lambda _, stride: stride - ) # dummy offset, ldmatrix cannot be used for int8 + trans case + # A dummy offset, ldmatrix cannot be used for int8 + trans case. + # We still use the ldmatrix intrinsic, but lower it to a manual loop in the codegen. + # Only the stride information is required. + shared_offset = lambda _, stride: stride elif is_b and transposed: index_map = shared_16x32_to_ldmatrix_32x16_layout shared_offset = ( @@ -99,9 +101,6 @@ def get_ldmatrix_intrin(k_dim, dtype, is_b, transposed): index_map = shared_16x32_to_ldmatrix_32x16_layout shared_offset = lambda tx, stride: stride * (tx % 16) + 16 * (tx // 16) - else: - assert False, "Unsupported k dim" - assert index_map and shared_offset if is_b and not transposed: diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 17dfa9cb5876c..438acd5082f63 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -821,6 +821,8 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { std::string smem_ptr = this->PrintExpr(op->args[5]); if (trans && op->dtype.bits() == 8) { + // Since ldmatrix assumes that a matrix element is 16 bit, it cannot properly transpose an + // int8 matrix. std::string smem_stride = this->PrintExpr(op->args[6]); ICHECK(num == 4); os << "for (int i = 0; i < 16; ++i) {\n";