Skip to content

Commit

Permalink
[BugFix][TIR] Fix dynamic smem merge leaf alloc (#16216)
Browse files Browse the repository at this point in the history
  • Loading branch information
nox-410 authored Dec 14, 2023
1 parent b0e146f commit c8bfdb2
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 2 deletions.
12 changes: 11 additions & 1 deletion src/tir/transforms/merge_dynamic_shared_memory_allocations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -447,9 +447,13 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator {
// - leaf stmt(offset = 0)
// - end of scope(offset < 0)
// In both cases, we need to handle the kill event correctly
auto is_leaf_alloc = [&](const VarNode* var) {
return seq[i].scope_pair_offset == 0 &&
std::find(it->second.gen.begin(), it->second.gen.end(), var) != it->second.gen.end();
};
if (it != event_map_.end() && seq[i].scope_pair_offset <= 0) {
for (const VarNode* var : it->second.kill) {
this->Free(var);
if (!is_leaf_alloc(var)) this->Free(var);
}
}
// scope_pair_offset >= 0 means it is either
Expand All @@ -464,6 +468,11 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator {
alloc_map_[var] = dst_entry;
}
}
if (it != event_map_.end() && seq[i].scope_pair_offset <= 0) {
for (const VarNode* var : it->second.kill) {
if (is_leaf_alloc(var)) this->Free(var);
}
}
}
}
/*!
Expand Down Expand Up @@ -510,6 +519,7 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator {
StorageEntry* e = it->second;
e->const_nbits = std::max(const_nbits, e->const_nbits);
const_free_map_.erase(it);
it->second->allocs.push_back({op->buffer_var.get()});
return e;
}
// Then start looking at smaller buffers.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,6 @@ def func(

C_local[0] = T.float32(0)
for i in range(64):

A_sh[threadIdx_y * 16 + threadIdx_x] = A_flat[
blockIdx_y * 16384 + threadIdx_y * 1024 + i * 16 + threadIdx_x
]
Expand Down Expand Up @@ -454,5 +453,65 @@ def func(
return func


class TestSimpleAllocNoReuse(tvm.testing.CompareBeforeAfter):
"""Test alloc and free within the same scope."""

transform = tvm.tir.transform.MergeDynamicSharedMemoryAllocations()

def before(self):
@T.prim_func
def func():
threadIdx_x = T.launch_thread("threadIdx.x", 128)
A_sh_data = T.allocate([128], "float32", "shared.dyn")
B_sh_data = T.allocate([128], "float32", "shared.dyn")
A_sh = T.decl_buffer([128], data=A_sh_data, scope="shared.dyn")
B_sh = T.decl_buffer([128], data=B_sh_data, scope="shared.dyn")
B_sh[threadIdx_x] = A_sh[threadIdx_x]

return func

def expected(self):
@T.prim_func
def func():
threadIdx_x = T.launch_thread("threadIdx.x", 128)
buf_dyn_shmem = T.allocate([1024], "uint8", "shared.dyn")
A_sh = T.decl_buffer((128,), data=buf_dyn_shmem, scope="shared.dyn")
B_sh = T.decl_buffer((128,), data=buf_dyn_shmem, scope="shared.dyn")
B_sh[threadIdx_x + 128] = A_sh[threadIdx_x]

return func


class TestSimpleAllocReuse(tvm.testing.CompareBeforeAfter):
"""Test alloc and free within the same scope with a reuse chance."""

transform = tvm.tir.transform.MergeDynamicSharedMemoryAllocations()

def before(self):
@T.prim_func
def func():
threadIdx_x = T.launch_thread("threadIdx.x", 128)
A_sh_data = T.allocate([128], "float32", "shared.dyn")
B_sh_data = T.allocate([128], "float32", "shared.dyn")
A_sh = T.decl_buffer([128], data=A_sh_data, scope="shared.dyn")
B_sh = T.decl_buffer([128], data=B_sh_data, scope="shared.dyn")
A_sh[threadIdx_x] = 0
B_sh[threadIdx_x] = 0

return func

def expected(self):
@T.prim_func
def func():
threadIdx_x = T.launch_thread("threadIdx.x", 128)
buf_dyn_shmem = T.allocate([512], "uint8", "shared.dyn")
A_sh = T.decl_buffer((128,), data=buf_dyn_shmem, scope="shared.dyn")
B_sh = T.decl_buffer((128,), data=buf_dyn_shmem, scope="shared.dyn")
A_sh[threadIdx_x] = 0
B_sh[threadIdx_x] = 0

return func


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit c8bfdb2

Please sign in to comment.