diff --git a/src/tir/transforms/merge_shared_memory_allocations.cc b/src/tir/transforms/merge_shared_memory_allocations.cc
index 1598d409c5d8..c79b9c1f9399 100644
--- a/src/tir/transforms/merge_shared_memory_allocations.cc
+++ b/src/tir/transforms/merge_shared_memory_allocations.cc
@@ -662,11 +662,6 @@ namespace transform {
 Pass MergeSharedMemoryAllocations() {
   auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
     bool merge_static_smem = ctx->GetConfig<Bool>("tir.merge_static_smem", Bool(false)).value();
-    // disable this pass for Vulkan
-    auto target = Target::Current(true);
-    if (target.defined() && target->kind->name == "vulkan") {
-      return f;
-    }
     auto* n = f.CopyOnWrite();
     n->body = MergeSharedMemoryAllocations(std::move(n->body), merge_static_smem);
     return f;
diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc
index 6875523a956d..991c48219b96 100644
--- a/src/tir/transforms/storage_rewrite.cc
+++ b/src/tir/transforms/storage_rewrite.cc
@@ -380,13 +380,15 @@ class StoragePlanRewriter : public StmtExprMutator {
   using StmtEntry = LinearAccessPatternFinder::StmtEntry;
   using AllocEntry = LinearAccessPatternFinder::AllocEntry;
 
-  Stmt Rewrite(Stmt stmt, bool detect_inplace, bool enable_reuse = true) {
+  Stmt Rewrite(Stmt stmt, bool detect_inplace, bool enable_reuse,
+               bool reuse_require_exact_matched_dtype) {
     detect_inplace_ = detect_inplace;
     // plan the rewrite
     LinearAccessPatternFinder finder;
     finder(stmt);
     this->LivenessAnalysis(finder.linear_seq_);
-    this->PlanMemory(finder.linear_seq_, finder.alloc_info_, enable_reuse);
+    this->PlanMemory(finder.linear_seq_, finder.alloc_info_, enable_reuse,
+                     reuse_require_exact_matched_dtype);
     all_buffers_accessed_ = finder.all_buffers_accessed_;
     this->PrepareNewAlloc();
     // start rewrite
@@ -817,7 +819,7 @@ class StoragePlanRewriter : public StmtExprMutator {
 
   // Memory plan algorithm
   void PlanMemory(const std::vector<StmtEntry>& seq,
                   const std::unordered_map<const VarNode*, AllocEntry>& alloc_info,
-                  bool enable_reuse = true) {
+                  bool enable_reuse, bool reuse_require_exact_matched_dtype) {
     std::unordered_set<const VarNode*> inplace_flag;
     for (size_t i = 0; i < seq.size(); ++i) {
@@ -864,8 +866,9 @@ class StoragePlanRewriter : public StmtExprMutator {
           }
         }
         if (dst_entry == nullptr) {
-          dst_entry = FindAlloc(alloc, thread_scope_, storage_scope,
-                                entry.num_physical_dimensions, enable_reuse);
+          dst_entry =
+              FindAlloc(alloc, thread_scope_, storage_scope, entry.num_physical_dimensions,
+                        enable_reuse, reuse_require_exact_matched_dtype);
         }
         dst_entry->allocs.emplace_back(alloc);
         alloc_map_[var] = dst_entry;
@@ -919,7 +922,7 @@ class StoragePlanRewriter : public StmtExprMutator {
 
   StorageEntry* FindAlloc(const AllocateNode* op, const Object* attach_scope,
                           const StorageScope& scope, size_t num_physical_dimensions,
-                          bool enable_reuse = true) {
+                          bool enable_reuse, bool reuse_require_exact_matched_dtype) {
     ICHECK(op != nullptr);
     // skip plan for local variable,
     // compiler can do a better job with register allocation.
@@ -958,6 +961,9 @@ class StoragePlanRewriter : public StmtExprMutator {
         if (e->scope != scope) continue;
         // when not divided, no reuse, eg, float4 vs float3
         if (e->bits_offset % op_elem_bits != 0) continue;
+        if (reuse_require_exact_matched_dtype && e->elem_type != op->dtype) {
+          continue;
+        }
         e->const_nbits = std::max(const_nbits, e->const_nbits);
         const_free_map_.erase(it);
         return e;
@@ -969,6 +975,9 @@ class StoragePlanRewriter : public StmtExprMutator {
         if (e->attach_scope_ != attach_scope) continue;
         if (e->scope != scope) continue;
         if (e->elem_type != op->dtype.element_of()) continue;
+        if (reuse_require_exact_matched_dtype && e->elem_type != op->dtype) {
+          continue;
+        }
         e->const_nbits = std::max(const_nbits, e->const_nbits);
         const_free_map_.erase(it);
         return e;
@@ -1704,17 +1713,24 @@ namespace transform {
 
 Pass StorageRewrite() {
   auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+    bool enable_reuse = true;
+    bool reuse_require_exact_matched_dtype = false;
     bool merge_static_smem = ctx->GetConfig<Bool>("tir.merge_static_smem", Bool(false)).value();
-    // disable merge_static_smem for Vulkan
-    auto target = Target::Current(true);
-    if (target.defined() && target->kind->name == "vulkan") {
-      merge_static_smem = false;
-    }
-    // Only enable reuse when we are not merging static shared memory.
-    // Otherwise we will do it in a separate stage
-    bool enable_reuse = merge_static_smem ? false : true;
+    if (merge_static_smem) {
+      // When `merge_static_smem` is true, we will reuse and merge shared
+      // memory in a dedicated pass `MergeSharedMemoryAllocations`.
+      // And so we don't enable reuse in this pass.
+      enable_reuse = false;
+    }
+
+    Optional<Target> target = f->GetAttr<Target>("target");
+    if (target.defined() && target.value()->kind->name == "vulkan") {
+      // Require exactly same-dtype matching in smem reuse for Vulkan
+      reuse_require_exact_matched_dtype = true;
+    }
     auto* n = f.CopyOnWrite();
-    n->body = StoragePlanRewriter().Rewrite(std::move(n->body), true, enable_reuse);
+    n->body = StoragePlanRewriter().Rewrite(std::move(n->body), true, enable_reuse,
+                                            reuse_require_exact_matched_dtype);
     // Parameters may not be rewritten, but internal allocations may.
     // Vectorization of AllocateConst is currently disabled, as it has
     // indexing issues for types that include padding (e.g. int8x3
diff --git a/tests/python/tir-transform/test_tir_transform_storage_rewrite.py b/tests/python/tir-transform/test_tir_transform_storage_rewrite.py
index 197e81818ee3..4b71eb825414 100644
--- a/tests/python/tir-transform/test_tir_transform_storage_rewrite.py
+++ b/tests/python/tir-transform/test_tir_transform_storage_rewrite.py
@@ -15,7 +15,9 @@
 # specific language governing permissions and limitations
 # under the License.
 import sys
+
 import pytest
+
 import tvm
 import tvm.testing
 from tvm import te
@@ -928,5 +930,87 @@ def expected(A: T.Buffer(16, "float32"), D: T.Buffer(16, "float32")):
         D[i] = C[i]
 
 
+def test_vulkan_smem_reuse():
+    target = tvm.target.Target(
+        {
+            "keys": ["vulkan", "gpu"],
+            "kind": "vulkan",
+            "max_num_threads": 256,
+            "max_threads_per_block": 256,
+            "supports_float32": T.bool(True),
+            "supports_int32": T.bool(True),
+            "tag": "",
+            "thread_warp_size": 1,
+        }
+    )
+
+    @T.prim_func(private=True)
+    def func(A: T.Buffer((4,), "float32"), B: T.Buffer((4,), "float16")):
+        T.func_attr({"tir.noalias": T.bool(True)})
+        A_shared = T.allocate([4], "float32", "shared")
+        A_local = T.allocate([4], "float32", "local")
+        B_shared = T.allocate([4], "float16", "shared")
+        A_shared_1 = T.Buffer((4,), data=A_shared, scope="shared")
+        with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
+            A_1 = T.Buffer((4,), data=A.data)
+            A_shared_1[threadIdx_x] = A_1[threadIdx_x]
+        A_local_1 = T.Buffer((4,), data=A_local, scope="local")
+        with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
+            A_local_1[threadIdx_x] = A_shared_1[threadIdx_x]
+        B_shared_1 = T.Buffer((4,), "float16", data=B_shared, scope="shared")
+        with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
+            B_shared_1[threadIdx_x] = T.Cast("float16", A_local_1[threadIdx_x])
+        threadIdx_x = T.launch_thread("threadIdx.x", 4)
+        B_1 = T.Buffer((4,), "float16", data=B.data)
+        B_1[threadIdx_x] = B_shared_1[threadIdx_x]
+
+    @T.prim_func(private=True)
+    def normal_lowering(A: T.Buffer((4,), "float32"), B: T.Buffer((4,), "float16")):
+        T.func_attr({"tir.noalias": T.bool(True)})
+        A_shared = T.allocate([4], "float32", "shared")
+        A_local = T.allocate([4], "float32", "local")
+        A_shared_1 = T.Buffer((4,), data=A_shared, scope="shared")
+        with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
+            A_1 = T.Buffer((4,), data=A.data)
+            A_shared_1[threadIdx_x] = A_1[threadIdx_x]
+        A_local_1 = T.Buffer((4,), data=A_local, scope="local")
+        with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
+            A_local_1[threadIdx_x] = A_shared_1[threadIdx_x]
+        A_shared_2 = T.Buffer((4,), "float16", data=A_shared, scope="shared")
+        with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
+            A_shared_2[threadIdx_x] = T.Cast("float16", A_local_1[threadIdx_x])
+        threadIdx_x = T.launch_thread("threadIdx.x", 4)
+        B_1 = T.Buffer((4,), "float16", data=B.data)
+        B_1[threadIdx_x] = A_shared_2[threadIdx_x]
+
+    @T.prim_func(private=True)
+    def no_reuse_lowering(A: T.Buffer((4,), "float32"), B: T.Buffer((4,), "float16")):
+        T.func_attr({"target": target, "tir.noalias": T.bool(True)})
+        A_shared_1 = T.allocate([4], "float32", "shared")
+        A_local_1 = T.allocate([4], "float32", "local")
+        B_shared_1 = T.allocate([4], "float16", "shared")
+        A_shared_1_1 = T.Buffer((4,), data=A_shared_1, scope="shared")
+        with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
+            A_1 = T.Buffer((4,), data=A.data)
+            A_shared_1_1[threadIdx_x] = A_1[threadIdx_x]
+        A_local_1_1 = T.Buffer((4,), data=A_local_1, scope="local")
+        with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
+            A_local_1_1[threadIdx_x] = A_shared_1_1[threadIdx_x]
+        B_shared_1_1 = T.Buffer((4,), "float16", data=B_shared_1, scope="shared")
+        with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
+            B_shared_1_1[threadIdx_x] = T.Cast("float16", A_local_1_1[threadIdx_x])
+        threadIdx_x = T.launch_thread("threadIdx.x", 4)
+        B_1 = T.Buffer((4,), "float16", data=B.data)
+        B_1[threadIdx_x] = B_shared_1_1[threadIdx_x]
+
+    # Reuse shared memory when lowering without target.
+    mod = tvm.IRModule({"main": func})
+    tvm.ir.assert_structural_equal(tvm.lower(mod)["main"], normal_lowering)
+
+    # No shared memory reuse when lowering with target Vulkan.
+    mod = tvm.tir.transform.BindTarget(target)(mod)
+    tvm.ir.assert_structural_equal(tvm.lower(mod)["main"], no_reuse_lowering)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
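
Beyond comparing against the full expected TVMScript with assert_structural_equal, a quicker way to check the Vulkan behavior is to count how many shared allocations survive lowering. The sketch below is not part of the patch; the helper count_shared_allocs is hypothetical, and it reuses the func and target objects defined in test_vulkan_smem_reuse above.

    # Hypothetical helper (not in the patch): count Allocate nodes whose
    # pointer storage scope is "shared" in a lowered PrimFunc body.
    def count_shared_allocs(prim_func):
        count = 0

        def visit(node):
            nonlocal count
            if isinstance(node, tvm.tir.Allocate):
                if node.buffer_var.type_annotation.storage_scope == "shared":
                    count += 1

        tvm.tir.stmt_functor.post_order_visit(prim_func.body, visit)
        return count

    # With the Vulkan target bound, the float32 and float16 shared buffers are
    # expected to stay separate, so two shared allocations should remain.
    mod = tvm.tir.transform.BindTarget(target)(tvm.IRModule({"main": func}))
    assert count_shared_allocs(tvm.lower(mod)["main"]) == 2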