From 1e2dea3c514e0a8ce0251b4062a2748f1e6735b2 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Fri, 2 Feb 2024 17:23:33 -0500 Subject: [PATCH] [TIR] Require exactly same-dtype matching for Vulkan smem reuse This PR fixes the StorageRewrite pass which failed to avoid shared memory reuse of different dtypes for Vulkan. Since the Vulkan target information is required at the time of lowering, the pass `BindTarget` needs to apply before lowering, so that the functions have correct target information. Note that previously the pass checks `Target::Current`, while `tvm.build` does not set the current target. One regression test is added. --- .../merge_shared_memory_allocations.cc | 5 -- src/tir/transforms/storage_rewrite.cc | 46 ++++++---- .../test_tir_transform_storage_rewrite.py | 84 +++++++++++++++++++ 3 files changed, 115 insertions(+), 20 deletions(-) diff --git a/src/tir/transforms/merge_shared_memory_allocations.cc b/src/tir/transforms/merge_shared_memory_allocations.cc index 1598d409c5d8..c79b9c1f9399 100644 --- a/src/tir/transforms/merge_shared_memory_allocations.cc +++ b/src/tir/transforms/merge_shared_memory_allocations.cc @@ -662,11 +662,6 @@ namespace transform { Pass MergeSharedMemoryAllocations() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { bool merge_static_smem = ctx->GetConfig("tir.merge_static_smem", Bool(false)).value(); - // disable this pass for Vulkan - auto target = Target::Current(true); - if (target.defined() && target->kind->name == "vulkan") { - return f; - } auto* n = f.CopyOnWrite(); n->body = MergeSharedMemoryAllocations(std::move(n->body), merge_static_smem); return f; diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index 6875523a956d..991c48219b96 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -380,13 +380,15 @@ class StoragePlanRewriter : public StmtExprMutator { using StmtEntry = LinearAccessPatternFinder::StmtEntry; using AllocEntry = LinearAccessPatternFinder::AllocEntry; - Stmt Rewrite(Stmt stmt, bool detect_inplace, bool enable_reuse = true) { + Stmt Rewrite(Stmt stmt, bool detect_inplace, bool enable_reuse, + bool reuse_require_exact_matched_dtype) { detect_inplace_ = detect_inplace; // plan the rewrite LinearAccessPatternFinder finder; finder(stmt); this->LivenessAnalysis(finder.linear_seq_); - this->PlanMemory(finder.linear_seq_, finder.alloc_info_, enable_reuse); + this->PlanMemory(finder.linear_seq_, finder.alloc_info_, enable_reuse, + reuse_require_exact_matched_dtype); all_buffers_accessed_ = finder.all_buffers_accessed_; this->PrepareNewAlloc(); // start rewrite @@ -817,7 +819,7 @@ class StoragePlanRewriter : public StmtExprMutator { // Memory plan algorithm void PlanMemory(const std::vector& seq, const std::unordered_map& alloc_info, - bool enable_reuse = true) { + bool enable_reuse, bool reuse_require_exact_matched_dtype) { std::unordered_set inplace_flag; for (size_t i = 0; i < seq.size(); ++i) { @@ -864,8 +866,9 @@ class StoragePlanRewriter : public StmtExprMutator { } } if (dst_entry == nullptr) { - dst_entry = FindAlloc(alloc, thread_scope_, storage_scope, - entry.num_physical_dimensions, enable_reuse); + dst_entry = + FindAlloc(alloc, thread_scope_, storage_scope, entry.num_physical_dimensions, + enable_reuse, reuse_require_exact_matched_dtype); } dst_entry->allocs.emplace_back(alloc); alloc_map_[var] = dst_entry; @@ -919,7 +922,7 @@ class StoragePlanRewriter : public StmtExprMutator { StorageEntry* FindAlloc(const AllocateNode* op, const Object* attach_scope, const StorageScope& scope, size_t num_physical_dimensions, - bool enable_reuse = true) { + bool enable_reuse, bool reuse_require_exact_matched_dtype) { ICHECK(op != nullptr); // skip plan for local variable, // compiler can do a better job with register allocation. @@ -958,6 +961,9 @@ class StoragePlanRewriter : public StmtExprMutator { if (e->scope != scope) continue; // when not divided, no reuse, eg, float4 vs float3 if (e->bits_offset % op_elem_bits != 0) continue; + if (reuse_require_exact_matched_dtype && e->elem_type != op->dtype) { + continue; + } e->const_nbits = std::max(const_nbits, e->const_nbits); const_free_map_.erase(it); return e; @@ -969,6 +975,9 @@ class StoragePlanRewriter : public StmtExprMutator { if (e->attach_scope_ != attach_scope) continue; if (e->scope != scope) continue; if (e->elem_type != op->dtype.element_of()) continue; + if (reuse_require_exact_matched_dtype && e->elem_type != op->dtype) { + continue; + } e->const_nbits = std::max(const_nbits, e->const_nbits); const_free_map_.erase(it); return e; @@ -1704,17 +1713,24 @@ namespace transform { Pass StorageRewrite() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + bool enable_reuse = true; + bool reuse_require_exact_matched_dtype = false; bool merge_static_smem = ctx->GetConfig("tir.merge_static_smem", Bool(false)).value(); - // disable merge_static_smem for Vulkan - auto target = Target::Current(true); - if (target.defined() && target->kind->name == "vulkan") { - merge_static_smem = false; - } - // Only enable reuse when we are not merging static shared memory. - // Otherwise we will do it in a separate stage - bool enable_reuse = merge_static_smem ? false : true; + if (merge_static_smem) { + // When `merge_static_smem` is true, we will reuse and merge shared + // memory in a dedicated pass `MergeSharedMemoryAllocations`. + // And so we don't enable reuse in this pass. + enable_reuse = false; + } + + Optional target = f->GetAttr("target"); + if (target.defined() && target.value()->kind->name == "vulkan") { + // Require exactly same-dtype matching in smem reuse for Vulkan + reuse_require_exact_matched_dtype = true; + } auto* n = f.CopyOnWrite(); - n->body = StoragePlanRewriter().Rewrite(std::move(n->body), true, enable_reuse); + n->body = StoragePlanRewriter().Rewrite(std::move(n->body), true, enable_reuse, + reuse_require_exact_matched_dtype); // Parameters may not be rewritten, but internal allocations may. // Vectorization of AllocateConst is currently disabled, as it has // indexing issues for types that include padding (e.g. int8x3 diff --git a/tests/python/tir-transform/test_tir_transform_storage_rewrite.py b/tests/python/tir-transform/test_tir_transform_storage_rewrite.py index 197e81818ee3..4b71eb825414 100644 --- a/tests/python/tir-transform/test_tir_transform_storage_rewrite.py +++ b/tests/python/tir-transform/test_tir_transform_storage_rewrite.py @@ -15,7 +15,9 @@ # specific language governing permissions and limitations # under the License. import sys + import pytest + import tvm import tvm.testing from tvm import te @@ -928,5 +930,87 @@ def expected(A: T.Buffer(16, "float32"), D: T.Buffer(16, "float32")): D[i] = C[i] +def test_vulkan_smem_reuse(): + target = tvm.target.Target( + { + "keys": ["vulkan", "gpu"], + "kind": "vulkan", + "max_num_threads": 256, + "max_threads_per_block": 256, + "supports_float32": T.bool(True), + "supports_int32": T.bool(True), + "tag": "", + "thread_warp_size": 1, + } + ) + + @T.prim_func(private=True) + def func(A: T.Buffer((4,), "float32"), B: T.Buffer((4,), "float16")): + T.func_attr({"tir.noalias": T.bool(True)}) + A_shared = T.allocate([4], "float32", "shared") + A_local = T.allocate([4], "float32", "local") + B_shared = T.allocate([4], "float16", "shared") + A_shared_1 = T.Buffer((4,), data=A_shared, scope="shared") + with T.launch_thread("threadIdx.x", 4) as threadIdx_x: + A_1 = T.Buffer((4,), data=A.data) + A_shared_1[threadIdx_x] = A_1[threadIdx_x] + A_local_1 = T.Buffer((4,), data=A_local, scope="local") + with T.launch_thread("threadIdx.x", 4) as threadIdx_x: + A_local_1[threadIdx_x] = A_shared_1[threadIdx_x] + B_shared_1 = T.Buffer((4,), "float16", data=B_shared, scope="shared") + with T.launch_thread("threadIdx.x", 4) as threadIdx_x: + B_shared_1[threadIdx_x] = T.Cast("float16", A_local_1[threadIdx_x]) + threadIdx_x = T.launch_thread("threadIdx.x", 4) + B_1 = T.Buffer((4,), "float16", data=B.data) + B_1[threadIdx_x] = B_shared_1[threadIdx_x] + + @T.prim_func(private=True) + def normal_lowering(A: T.Buffer((4,), "float32"), B: T.Buffer((4,), "float16")): + T.func_attr({"tir.noalias": T.bool(True)}) + A_shared = T.allocate([4], "float32", "shared") + A_local = T.allocate([4], "float32", "local") + A_shared_1 = T.Buffer((4,), data=A_shared, scope="shared") + with T.launch_thread("threadIdx.x", 4) as threadIdx_x: + A_1 = T.Buffer((4,), data=A.data) + A_shared_1[threadIdx_x] = A_1[threadIdx_x] + A_local_1 = T.Buffer((4,), data=A_local, scope="local") + with T.launch_thread("threadIdx.x", 4) as threadIdx_x: + A_local_1[threadIdx_x] = A_shared_1[threadIdx_x] + A_shared_2 = T.Buffer((4,), "float16", data=A_shared, scope="shared") + with T.launch_thread("threadIdx.x", 4) as threadIdx_x: + A_shared_2[threadIdx_x] = T.Cast("float16", A_local_1[threadIdx_x]) + threadIdx_x = T.launch_thread("threadIdx.x", 4) + B_1 = T.Buffer((4,), "float16", data=B.data) + B_1[threadIdx_x] = A_shared_2[threadIdx_x] + + @T.prim_func(private=True) + def no_reuse_lowering(A: T.Buffer((4,), "float32"), B: T.Buffer((4,), "float16")): + T.func_attr({"target": target, "tir.noalias": T.bool(True)}) + A_shared_1 = T.allocate([4], "float32", "shared") + A_local_1 = T.allocate([4], "float32", "local") + B_shared_1 = T.allocate([4], "float16", "shared") + A_shared_1_1 = T.Buffer((4,), data=A_shared_1, scope="shared") + with T.launch_thread("threadIdx.x", 4) as threadIdx_x: + A_1 = T.Buffer((4,), data=A.data) + A_shared_1_1[threadIdx_x] = A_1[threadIdx_x] + A_local_1_1 = T.Buffer((4,), data=A_local_1, scope="local") + with T.launch_thread("threadIdx.x", 4) as threadIdx_x: + A_local_1_1[threadIdx_x] = A_shared_1_1[threadIdx_x] + B_shared_1_1 = T.Buffer((4,), "float16", data=B_shared_1, scope="shared") + with T.launch_thread("threadIdx.x", 4) as threadIdx_x: + B_shared_1_1[threadIdx_x] = T.Cast("float16", A_local_1_1[threadIdx_x]) + threadIdx_x = T.launch_thread("threadIdx.x", 4) + B_1 = T.Buffer((4,), "float16", data=B.data) + B_1[threadIdx_x] = B_shared_1_1[threadIdx_x] + + # Reuse shared memory when lowering without target. + mod = tvm.IRModule({"main": func}) + tvm.ir.assert_structural_equal(tvm.lower(mod)["main"], normal_lowering) + + # No shared memory reuse when lowering with target Vulkan. + mod = tvm.tir.transform.BindTarget(target)(mod) + tvm.ir.assert_structural_equal(tvm.lower(mod)["main"], no_reuse_lowering) + + if __name__ == "__main__": tvm.testing.main()