[TIR] Require exactly same-dtype matching for Vulkan smem reuse #16515

Merged
5 changes: 0 additions & 5 deletions src/tir/transforms/merge_shared_memory_allocations.cc
@@ -662,11 +662,6 @@ namespace transform {
 Pass MergeSharedMemoryAllocations() {
   auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
     bool merge_static_smem = ctx->GetConfig<Bool>("tir.merge_static_smem", Bool(false)).value();
-    // disable this pass for Vulkan
-    auto target = Target::Current(true);
-    if (target.defined() && target->kind->name == "vulkan") {
-      return f;
-    }
     auto* n = f.CopyOnWrite();
     n->body = MergeSharedMemoryAllocations(std::move(n->body), merge_static_smem);
     return f;
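With the Vulkan early-return removed, this pass is driven purely by the `tir.merge_static_smem` PassContext option it reads above, and is no longer silently skipped for Vulkan. A hedged sketch of how that option is typically set (the module is a placeholder, not part of this PR):

```python
# Minimal sketch, assuming `mod` is an already-scheduled IRModule that uses
# static shared memory; "tir.merge_static_smem" is the option read by the pass above.
import tvm

with tvm.transform.PassContext(config={"tir.merge_static_smem": True}):
    lib = tvm.build(mod, target="vulkan")
```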
46 changes: 31 additions & 15 deletions src/tir/transforms/storage_rewrite.cc
@@ -380,13 +380,15 @@ class StoragePlanRewriter : public StmtExprMutator {
   using StmtEntry = LinearAccessPatternFinder::StmtEntry;
   using AllocEntry = LinearAccessPatternFinder::AllocEntry;
 
-  Stmt Rewrite(Stmt stmt, bool detect_inplace, bool enable_reuse = true) {
+  Stmt Rewrite(Stmt stmt, bool detect_inplace, bool enable_reuse,
+               bool reuse_require_exact_matched_dtype) {
     detect_inplace_ = detect_inplace;
     // plan the rewrite
     LinearAccessPatternFinder finder;
     finder(stmt);
     this->LivenessAnalysis(finder.linear_seq_);
-    this->PlanMemory(finder.linear_seq_, finder.alloc_info_, enable_reuse);
+    this->PlanMemory(finder.linear_seq_, finder.alloc_info_, enable_reuse,
+                     reuse_require_exact_matched_dtype);
     all_buffers_accessed_ = finder.all_buffers_accessed_;
     this->PrepareNewAlloc();
     // start rewrite
@@ -817,7 +819,7 @@ class StoragePlanRewriter : public StmtExprMutator {
   // Memory plan algorithm
   void PlanMemory(const std::vector<StmtEntry>& seq,
                   const std::unordered_map<const VarNode*, AllocEntry>& alloc_info,
-                  bool enable_reuse = true) {
+                  bool enable_reuse, bool reuse_require_exact_matched_dtype) {
     std::unordered_set<const VarNode*> inplace_flag;
 
     for (size_t i = 0; i < seq.size(); ++i) {
@@ -864,8 +866,9 @@ class StoragePlanRewriter : public StmtExprMutator {
             }
           }
           if (dst_entry == nullptr) {
-            dst_entry = FindAlloc(alloc, thread_scope_, storage_scope,
-                                  entry.num_physical_dimensions, enable_reuse);
+            dst_entry =
+                FindAlloc(alloc, thread_scope_, storage_scope, entry.num_physical_dimensions,
+                          enable_reuse, reuse_require_exact_matched_dtype);
           }
           dst_entry->allocs.emplace_back(alloc);
           alloc_map_[var] = dst_entry;
@@ -919,7 +922,7 @@ class StoragePlanRewriter : public StmtExprMutator {
 
   StorageEntry* FindAlloc(const AllocateNode* op, const Object* attach_scope,
                           const StorageScope& scope, size_t num_physical_dimensions,
-                          bool enable_reuse = true) {
+                          bool enable_reuse, bool reuse_require_exact_matched_dtype) {
     ICHECK(op != nullptr);
     // skip plan for local variable,
     // compiler can do a better job with register allocation.
@@ -958,6 +961,9 @@ class StoragePlanRewriter : public StmtExprMutator {
         if (e->scope != scope) continue;
         // when not divided, no reuse, eg, float4 vs float3
         if (e->bits_offset % op_elem_bits != 0) continue;
+        if (reuse_require_exact_matched_dtype && e->elem_type != op->dtype) {
+          continue;
+        }
         e->const_nbits = std::max(const_nbits, e->const_nbits);
         const_free_map_.erase(it);
         return e;
@@ -969,6 +975,9 @@ class StoragePlanRewriter : public StmtExprMutator {
         if (e->attach_scope_ != attach_scope) continue;
         if (e->scope != scope) continue;
         if (e->elem_type != op->dtype.element_of()) continue;
+        if (reuse_require_exact_matched_dtype && e->elem_type != op->dtype) {
+          continue;
+        }
         e->const_nbits = std::max(const_nbits, e->const_nbits);
         const_free_map_.erase(it);
         return e;
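Both new guards compare the stored entry's `elem_type` against the full dtype of the incoming allocation rather than only its element type, so on Vulkan a shared buffer is reused only when the DataType matches exactly, lanes included. A small standalone illustration of what that comparison accepts and rejects (not part of the PR):

```python
# Standalone sketch of the dtype comparison behind the new guard; not PR code.
from tvm.runtime import DataType

existing = DataType("float32")            # elem_type of an already-planned shared entry
print(existing == DataType("float32"))    # True  -> exact match, reuse still allowed
print(existing == DataType("float16"))    # False -> rejected on Vulkan (previously reusable via bit offsets)
print(existing == DataType("float32x4"))  # False -> lanes differ, also rejected by the exact-match check
```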
@@ -1704,17 +1713,24 @@ namespace transform {
 
 Pass StorageRewrite() {
   auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+    bool enable_reuse = true;
+    bool reuse_require_exact_matched_dtype = false;
     bool merge_static_smem = ctx->GetConfig<Bool>("tir.merge_static_smem", Bool(false)).value();
-    // disable merge_static_smem for Vulkan
-    auto target = Target::Current(true);
-    if (target.defined() && target->kind->name == "vulkan") {
-      merge_static_smem = false;
-    }
-    // Only enable reuse when we are not merging static shared memory.
-    // Otherwise we will do it in a separate stage
-    bool enable_reuse = merge_static_smem ? false : true;
+    if (merge_static_smem) {
+      // When `merge_static_smem` is true, we will reuse and merge shared
+      // memory in a dedicated pass `MergeSharedMemoryAllocations`.
+      // And so we don't enable reuse in this pass.
+      enable_reuse = false;
+    }
+
+    Optional<Target> target = f->GetAttr<Target>("target");
+    if (target.defined() && target.value()->kind->name == "vulkan") {
+      // Require exactly same-dtype matching in smem reuse for Vulkan
+      reuse_require_exact_matched_dtype = true;
+    }
     auto* n = f.CopyOnWrite();
-    n->body = StoragePlanRewriter().Rewrite(std::move(n->body), true, enable_reuse);
+    n->body = StoragePlanRewriter().Rewrite(std::move(n->body), true, enable_reuse,
+                                            reuse_require_exact_matched_dtype);
     // Parameters may not be rewritten, but internal allocations may.
     // Vectorization of AllocateConst is currently disabled, as it has
    // indexing issues for types that include padding (e.g. int8x3
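StorageRewrite now derives the Vulkan behaviour from the PrimFunc's "target" attribute rather than Target::Current, so a target has to be attached (for example with BindTarget) before the exact-dtype requirement kicks in. A condensed, hedged sketch of the regression test added below:

```python
# Minimal sketch, condensed from the test below; `mod` is assumed to hold a
# PrimFunc that allocates one float32 and one float16 shared buffer.
import tvm

lowered = tvm.lower(mod)  # no target attribute: the float16 shared buffer may reuse the float32 allocation

vulkan_mod = tvm.tir.transform.BindTarget(tvm.target.Target("vulkan"))(mod)
lowered_vulkan = tvm.lower(vulkan_mod)  # Vulkan bound: dtypes must match exactly, so the buffers stay separate
```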
84 changes: 84 additions & 0 deletions tests/python/tir-transform/test_tir_transform_storage_rewrite.py
@@ -15,7 +15,9 @@
# specific language governing permissions and limitations
# under the License.
import sys

import pytest

import tvm
import tvm.testing
from tvm import te
@@ -928,5 +930,87 @@ def expected(A: T.Buffer(16, "float32"), D: T.Buffer(16, "float32")):
         D[i] = C[i]
 
 
+def test_vulkan_smem_reuse():
+    target = tvm.target.Target(
+        {
+            "keys": ["vulkan", "gpu"],
+            "kind": "vulkan",
+            "max_num_threads": 256,
+            "max_threads_per_block": 256,
+            "supports_float32": T.bool(True),
+            "supports_int32": T.bool(True),
+            "tag": "",
+            "thread_warp_size": 1,
+        }
+    )
+
+    @T.prim_func(private=True)
+    def func(A: T.Buffer((4,), "float32"), B: T.Buffer((4,), "float16")):
+        T.func_attr({"tir.noalias": T.bool(True)})
+        A_shared = T.allocate([4], "float32", "shared")
+        A_local = T.allocate([4], "float32", "local")
+        B_shared = T.allocate([4], "float16", "shared")
+        A_shared_1 = T.Buffer((4,), data=A_shared, scope="shared")
+        with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
+            A_1 = T.Buffer((4,), data=A.data)
+            A_shared_1[threadIdx_x] = A_1[threadIdx_x]
+        A_local_1 = T.Buffer((4,), data=A_local, scope="local")
+        with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
+            A_local_1[threadIdx_x] = A_shared_1[threadIdx_x]
+        B_shared_1 = T.Buffer((4,), "float16", data=B_shared, scope="shared")
+        with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
+            B_shared_1[threadIdx_x] = T.Cast("float16", A_local_1[threadIdx_x])
+        threadIdx_x = T.launch_thread("threadIdx.x", 4)
+        B_1 = T.Buffer((4,), "float16", data=B.data)
+        B_1[threadIdx_x] = B_shared_1[threadIdx_x]
+
+    @T.prim_func(private=True)
+    def normal_lowering(A: T.Buffer((4,), "float32"), B: T.Buffer((4,), "float16")):
+        T.func_attr({"tir.noalias": T.bool(True)})
+        A_shared = T.allocate([4], "float32", "shared")
+        A_local = T.allocate([4], "float32", "local")
+        A_shared_1 = T.Buffer((4,), data=A_shared, scope="shared")
+        with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
+            A_1 = T.Buffer((4,), data=A.data)
+            A_shared_1[threadIdx_x] = A_1[threadIdx_x]
+        A_local_1 = T.Buffer((4,), data=A_local, scope="local")
+        with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
+            A_local_1[threadIdx_x] = A_shared_1[threadIdx_x]
+        A_shared_2 = T.Buffer((4,), "float16", data=A_shared, scope="shared")
+        with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
+            A_shared_2[threadIdx_x] = T.Cast("float16", A_local_1[threadIdx_x])
+        threadIdx_x = T.launch_thread("threadIdx.x", 4)
+        B_1 = T.Buffer((4,), "float16", data=B.data)
+        B_1[threadIdx_x] = A_shared_2[threadIdx_x]
+
+    @T.prim_func(private=True)
+    def no_reuse_lowering(A: T.Buffer((4,), "float32"), B: T.Buffer((4,), "float16")):
+        T.func_attr({"target": target, "tir.noalias": T.bool(True)})
+        A_shared_1 = T.allocate([4], "float32", "shared")
+        A_local_1 = T.allocate([4], "float32", "local")
+        B_shared_1 = T.allocate([4], "float16", "shared")
+        A_shared_1_1 = T.Buffer((4,), data=A_shared_1, scope="shared")
+        with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
+            A_1 = T.Buffer((4,), data=A.data)
+            A_shared_1_1[threadIdx_x] = A_1[threadIdx_x]
+        A_local_1_1 = T.Buffer((4,), data=A_local_1, scope="local")
+        with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
+            A_local_1_1[threadIdx_x] = A_shared_1_1[threadIdx_x]
+        B_shared_1_1 = T.Buffer((4,), "float16", data=B_shared_1, scope="shared")
+        with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
+            B_shared_1_1[threadIdx_x] = T.Cast("float16", A_local_1_1[threadIdx_x])
+        threadIdx_x = T.launch_thread("threadIdx.x", 4)
+        B_1 = T.Buffer((4,), "float16", data=B.data)
+        B_1[threadIdx_x] = B_shared_1_1[threadIdx_x]
+
+    # Reuse shared memory when lowering without target.
+    mod = tvm.IRModule({"main": func})
+    tvm.ir.assert_structural_equal(tvm.lower(mod)["main"], normal_lowering)
+
+    # No shared memory reuse when lowering with target Vulkan.
+    mod = tvm.tir.transform.BindTarget(target)(mod)
+    tvm.ir.assert_structural_equal(tvm.lower(mod)["main"], no_reuse_lowering)
+
+
 if __name__ == "__main__":
     tvm.testing.main()