From b9f7eae7041d1a9b3e434c331c874e8347e89dc4 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 28 Apr 2022 18:01:08 +0900 Subject: [PATCH] fixed bug in LowerWarpMemory index splitting for ldmatrix --- src/tir/transforms/lower_warp_memory.cc | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index e9a8513391c5..436e8db12160 100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -274,7 +274,8 @@ class WarpAccessRewriter : protected StmtExprMutator { Array new_args = op->args; PrimExpr local_index, group; if (op->args[3].get() == buffer_) { - new_args.Set(4, 0); + std::tie(local_index, group) = SplitIndexByGroup(op->args[4]); + new_args.Set(4, local_index); return Call(op->dtype, op->op, new_args); } return GetRef(op); @@ -466,11 +467,13 @@ Pass LowerWarpMemory() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); auto target = f->GetAttr(tvm::attr::kTarget); - ICHECK(target.defined()) << "LowerWarpMemory: Require the target attribute"; + LOG(INFO)<< "Before LowerWarpMemory \n" << f; int warp_size = target.value()->GetAttr("thread_warp_size", 1).value(); WarpMemoryRewriter warp_memory_rewriter(warp_size); auto stmt = warp_memory_rewriter.Rewrite(std::move(n->body)); n->body = UpdatePointerStorageScope(warp_memory_rewriter.new_storage_scopes_)(stmt); + LOG(INFO)<< f; + return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.LowerWarpMemory", {});