diff --git a/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc b/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc index 99055cebf2dc..6c7b0f649cfe 100644 --- a/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc +++ b/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc @@ -447,9 +447,13 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator { // - leaf stmt(offset = 0) // - end of scope(offset < 0) // In both cases, we need to handle the kill event correctly + auto is_leaf_alloc = [&](const VarNode* var) { + return seq[i].scope_pair_offset == 0 && + std::find(it->second.gen.begin(), it->second.gen.end(), var) != it->second.gen.end(); + }; if (it != event_map_.end() && seq[i].scope_pair_offset <= 0) { for (const VarNode* var : it->second.kill) { - this->Free(var); + if (!is_leaf_alloc(var)) this->Free(var); } } // scope_pair_offset >= 0 means it is either @@ -464,6 +468,11 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator { alloc_map_[var] = dst_entry; } } + if (it != event_map_.end() && seq[i].scope_pair_offset <= 0) { + for (const VarNode* var : it->second.kill) { + if (is_leaf_alloc(var)) this->Free(var); + } + } } } /*! @@ -510,6 +519,7 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator { StorageEntry* e = it->second; e->const_nbits = std::max(const_nbits, e->const_nbits); const_free_map_.erase(it); + it->second->allocs.push_back({op->buffer_var.get()}); return e; } // Then start looking at smaller buffers. diff --git a/tests/python/tir-transform/test_tir_transform_merge_dynamic_shared_memory_allocations.py b/tests/python/tir-transform/test_tir_transform_merge_dynamic_shared_memory_allocations.py index 37372059a296..5dc5ec863fc1 100644 --- a/tests/python/tir-transform/test_tir_transform_merge_dynamic_shared_memory_allocations.py +++ b/tests/python/tir-transform/test_tir_transform_merge_dynamic_shared_memory_allocations.py @@ -376,7 +376,6 @@ def func( C_local[0] = T.float32(0) for i in range(64): - A_sh[threadIdx_y * 16 + threadIdx_x] = A_flat[ blockIdx_y * 16384 + threadIdx_y * 1024 + i * 16 + threadIdx_x ] @@ -454,5 +453,65 @@ def func( return func +class TestSimpleAllocNoReuse(tvm.testing.CompareBeforeAfter): + """Test alloc and free within the same scope.""" + + transform = tvm.tir.transform.MergeDynamicSharedMemoryAllocations() + + def before(self): + @T.prim_func + def func(): + threadIdx_x = T.launch_thread("threadIdx.x", 128) + A_sh_data = T.allocate([128], "float32", "shared.dyn") + B_sh_data = T.allocate([128], "float32", "shared.dyn") + A_sh = T.decl_buffer([128], data=A_sh_data, scope="shared.dyn") + B_sh = T.decl_buffer([128], data=B_sh_data, scope="shared.dyn") + B_sh[threadIdx_x] = A_sh[threadIdx_x] + + return func + + def expected(self): + @T.prim_func + def func(): + threadIdx_x = T.launch_thread("threadIdx.x", 128) + buf_dyn_shmem = T.allocate([1024], "uint8", "shared.dyn") + A_sh = T.decl_buffer((128,), data=buf_dyn_shmem, scope="shared.dyn") + B_sh = T.decl_buffer((128,), data=buf_dyn_shmem, scope="shared.dyn") + B_sh[threadIdx_x + 128] = A_sh[threadIdx_x] + + return func + + +class TestSimpleAllocReuse(tvm.testing.CompareBeforeAfter): + """Test alloc and free within the same scope with a reuse chance.""" + + transform = tvm.tir.transform.MergeDynamicSharedMemoryAllocations() + + def before(self): + @T.prim_func + def func(): + threadIdx_x = T.launch_thread("threadIdx.x", 128) + A_sh_data = T.allocate([128], "float32", "shared.dyn") + B_sh_data = T.allocate([128], "float32", "shared.dyn") + A_sh = T.decl_buffer([128], data=A_sh_data, scope="shared.dyn") + B_sh = T.decl_buffer([128], data=B_sh_data, scope="shared.dyn") + A_sh[threadIdx_x] = 0 + B_sh[threadIdx_x] = 0 + + return func + + def expected(self): + @T.prim_func + def func(): + threadIdx_x = T.launch_thread("threadIdx.x", 128) + buf_dyn_shmem = T.allocate([512], "uint8", "shared.dyn") + A_sh = T.decl_buffer((128,), data=buf_dyn_shmem, scope="shared.dyn") + B_sh = T.decl_buffer((128,), data=buf_dyn_shmem, scope="shared.dyn") + A_sh[threadIdx_x] = 0 + B_sh[threadIdx_x] = 0 + + return func + + if __name__ == "__main__": tvm.testing.main()