[TIR] Require exactly same-dtype matching for Vulkan smem reuse (#16515)
This PR fixes the StorageRewrite pass, which failed to prevent shared-memory reuse across different dtypes for Vulkan.

Since the Vulkan target information is required at lowering time, the `BindTarget` pass needs to be applied before lowering so that the functions carry the correct target attribute. Note that the pass previously checked `Target::Current`, whereas `tvm.build` does not set the current target.

One regression test is added.
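For context, the new regression test exercises this by binding the target to the module before lowering. Below is a minimal sketch of that flow, assuming `func` is an arbitrary TIR PrimFunc and using an illustrative Vulkan target string:

import tvm

# Sketch: bind the Vulkan target to every PrimFunc before lowering so that
# StorageRewrite can read the per-function "target" attribute instead of
# relying on Target::Current (which tvm.build does not set).
target = tvm.target.Target("vulkan")
mod = tvm.IRModule({"main": func})  # `func` is a placeholder TIR PrimFunc
mod = tvm.tir.transform.BindTarget(target)(mod)
lowered = tvm.lower(mod)  # shared buffers of different dtypes stay separate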
MasterJH5574 authored Feb 3, 2024
1 parent a3ec544 commit cdc2303
Showing 3 changed files with 115 additions and 20 deletions.
5 changes: 0 additions & 5 deletions src/tir/transforms/merge_shared_memory_allocations.cc
@@ -662,11 +662,6 @@ namespace transform {
Pass MergeSharedMemoryAllocations() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
bool merge_static_smem = ctx->GetConfig<Bool>("tir.merge_static_smem", Bool(false)).value();
// disable this pass for Vulkan
auto target = Target::Current(true);
if (target.defined() && target->kind->name == "vulkan") {
return f;
}
auto* n = f.CopyOnWrite();
n->body = MergeSharedMemoryAllocations(std::move(n->body), merge_static_smem);
return f;
46 changes: 31 additions & 15 deletions src/tir/transforms/storage_rewrite.cc
@@ -380,13 +380,15 @@ class StoragePlanRewriter : public StmtExprMutator {
using StmtEntry = LinearAccessPatternFinder::StmtEntry;
using AllocEntry = LinearAccessPatternFinder::AllocEntry;

Stmt Rewrite(Stmt stmt, bool detect_inplace, bool enable_reuse = true) {
Stmt Rewrite(Stmt stmt, bool detect_inplace, bool enable_reuse,
bool reuse_require_exact_matched_dtype) {
detect_inplace_ = detect_inplace;
// plan the rewrite
LinearAccessPatternFinder finder;
finder(stmt);
this->LivenessAnalysis(finder.linear_seq_);
this->PlanMemory(finder.linear_seq_, finder.alloc_info_, enable_reuse);
this->PlanMemory(finder.linear_seq_, finder.alloc_info_, enable_reuse,
reuse_require_exact_matched_dtype);
all_buffers_accessed_ = finder.all_buffers_accessed_;
this->PrepareNewAlloc();
// start rewrite
@@ -817,7 +819,7 @@ class StoragePlanRewriter : public StmtExprMutator {
// Memory plan algorithm
void PlanMemory(const std::vector<StmtEntry>& seq,
const std::unordered_map<const VarNode*, AllocEntry>& alloc_info,
bool enable_reuse = true) {
bool enable_reuse, bool reuse_require_exact_matched_dtype) {
std::unordered_set<const VarNode*> inplace_flag;

for (size_t i = 0; i < seq.size(); ++i) {
@@ -864,8 +866,9 @@
}
}
if (dst_entry == nullptr) {
dst_entry = FindAlloc(alloc, thread_scope_, storage_scope,
entry.num_physical_dimensions, enable_reuse);
dst_entry =
FindAlloc(alloc, thread_scope_, storage_scope, entry.num_physical_dimensions,
enable_reuse, reuse_require_exact_matched_dtype);
}
dst_entry->allocs.emplace_back(alloc);
alloc_map_[var] = dst_entry;
@@ -919,7 +922,7 @@

StorageEntry* FindAlloc(const AllocateNode* op, const Object* attach_scope,
const StorageScope& scope, size_t num_physical_dimensions,
bool enable_reuse = true) {
bool enable_reuse, bool reuse_require_exact_matched_dtype) {
ICHECK(op != nullptr);
// skip plan for local variable,
// compiler can do a better job with register allocation.
@@ -958,6 +961,9 @@
if (e->scope != scope) continue;
// when not divided, no reuse, eg, float4 vs float3
if (e->bits_offset % op_elem_bits != 0) continue;
if (reuse_require_exact_matched_dtype && e->elem_type != op->dtype) {
continue;
}
e->const_nbits = std::max(const_nbits, e->const_nbits);
const_free_map_.erase(it);
return e;
@@ -969,6 +975,9 @@
if (e->attach_scope_ != attach_scope) continue;
if (e->scope != scope) continue;
if (e->elem_type != op->dtype.element_of()) continue;
if (reuse_require_exact_matched_dtype && e->elem_type != op->dtype) {
continue;
}
e->const_nbits = std::max(const_nbits, e->const_nbits);
const_free_map_.erase(it);
return e;
@@ -1704,17 +1713,24 @@ namespace transform {

Pass StorageRewrite() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
bool enable_reuse = true;
bool reuse_require_exact_matched_dtype = false;
bool merge_static_smem = ctx->GetConfig<Bool>("tir.merge_static_smem", Bool(false)).value();
// disable merge_static_smem for Vulkan
auto target = Target::Current(true);
if (target.defined() && target->kind->name == "vulkan") {
merge_static_smem = false;
}
// Only enable reuse when we are not merging static shared memory.
// Otherwise we will do it in a separate stage
bool enable_reuse = merge_static_smem ? false : true;
if (merge_static_smem) {
// When `merge_static_smem` is true, we will reuse and merge shared
// memory in a dedicated pass `MergeSharedMemoryAllocations`.
// And so we don't enable reuse in this pass.
enable_reuse = false;
}

Optional<Target> target = f->GetAttr<Target>("target");
if (target.defined() && target.value()->kind->name == "vulkan") {
// Require exactly same-dtype matching in smem reuse for Vulkan
reuse_require_exact_matched_dtype = true;
}
auto* n = f.CopyOnWrite();
n->body = StoragePlanRewriter().Rewrite(std::move(n->body), true, enable_reuse);
n->body = StoragePlanRewriter().Rewrite(std::move(n->body), true, enable_reuse,
reuse_require_exact_matched_dtype);
// Parameters may not be rewritten, but internal allocations may.
// Vectorization of AllocateConst is currently disabled, as it has
// indexing issues for types that include padding (e.g. int8x3
84 changes: 84 additions & 0 deletions tests/python/tir-transform/test_tir_transform_storage_rewrite.py
@@ -15,7 +15,9 @@
# specific language governing permissions and limitations
# under the License.
import sys

import pytest

import tvm
import tvm.testing
from tvm import te
@@ -928,5 +930,87 @@ def expected(A: T.Buffer(16, "float32"), D: T.Buffer(16, "float32")):
D[i] = C[i]


def test_vulkan_smem_reuse():
target = tvm.target.Target(
{
"keys": ["vulkan", "gpu"],
"kind": "vulkan",
"max_num_threads": 256,
"max_threads_per_block": 256,
"supports_float32": T.bool(True),
"supports_int32": T.bool(True),
"tag": "",
"thread_warp_size": 1,
}
)

@T.prim_func(private=True)
def func(A: T.Buffer((4,), "float32"), B: T.Buffer((4,), "float16")):
T.func_attr({"tir.noalias": T.bool(True)})
A_shared = T.allocate([4], "float32", "shared")
A_local = T.allocate([4], "float32", "local")
B_shared = T.allocate([4], "float16", "shared")
A_shared_1 = T.Buffer((4,), data=A_shared, scope="shared")
with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
A_1 = T.Buffer((4,), data=A.data)
A_shared_1[threadIdx_x] = A_1[threadIdx_x]
A_local_1 = T.Buffer((4,), data=A_local, scope="local")
with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
A_local_1[threadIdx_x] = A_shared_1[threadIdx_x]
B_shared_1 = T.Buffer((4,), "float16", data=B_shared, scope="shared")
with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
B_shared_1[threadIdx_x] = T.Cast("float16", A_local_1[threadIdx_x])
threadIdx_x = T.launch_thread("threadIdx.x", 4)
B_1 = T.Buffer((4,), "float16", data=B.data)
B_1[threadIdx_x] = B_shared_1[threadIdx_x]

@T.prim_func(private=True)
def normal_lowering(A: T.Buffer((4,), "float32"), B: T.Buffer((4,), "float16")):
T.func_attr({"tir.noalias": T.bool(True)})
A_shared = T.allocate([4], "float32", "shared")
A_local = T.allocate([4], "float32", "local")
A_shared_1 = T.Buffer((4,), data=A_shared, scope="shared")
with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
A_1 = T.Buffer((4,), data=A.data)
A_shared_1[threadIdx_x] = A_1[threadIdx_x]
A_local_1 = T.Buffer((4,), data=A_local, scope="local")
with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
A_local_1[threadIdx_x] = A_shared_1[threadIdx_x]
A_shared_2 = T.Buffer((4,), "float16", data=A_shared, scope="shared")
with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
A_shared_2[threadIdx_x] = T.Cast("float16", A_local_1[threadIdx_x])
threadIdx_x = T.launch_thread("threadIdx.x", 4)
B_1 = T.Buffer((4,), "float16", data=B.data)
B_1[threadIdx_x] = A_shared_2[threadIdx_x]

@T.prim_func(private=True)
def no_reuse_lowering(A: T.Buffer((4,), "float32"), B: T.Buffer((4,), "float16")):
T.func_attr({"target": target, "tir.noalias": T.bool(True)})
A_shared_1 = T.allocate([4], "float32", "shared")
A_local_1 = T.allocate([4], "float32", "local")
B_shared_1 = T.allocate([4], "float16", "shared")
A_shared_1_1 = T.Buffer((4,), data=A_shared_1, scope="shared")
with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
A_1 = T.Buffer((4,), data=A.data)
A_shared_1_1[threadIdx_x] = A_1[threadIdx_x]
A_local_1_1 = T.Buffer((4,), data=A_local_1, scope="local")
with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
A_local_1_1[threadIdx_x] = A_shared_1_1[threadIdx_x]
B_shared_1_1 = T.Buffer((4,), "float16", data=B_shared_1, scope="shared")
with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
B_shared_1_1[threadIdx_x] = T.Cast("float16", A_local_1_1[threadIdx_x])
threadIdx_x = T.launch_thread("threadIdx.x", 4)
B_1 = T.Buffer((4,), "float16", data=B.data)
B_1[threadIdx_x] = B_shared_1_1[threadIdx_x]

# Reuse shared memory when lowering without target.
mod = tvm.IRModule({"main": func})
tvm.ir.assert_structural_equal(tvm.lower(mod)["main"], normal_lowering)

# No shared memory reuse when lowering with target Vulkan.
mod = tvm.tir.transform.BindTarget(target)(mod)
tvm.ir.assert_structural_equal(tvm.lower(mod)["main"], no_reuse_lowering)


if __name__ == "__main__":
tvm.testing.main()
