Skip to content

Commit

Permalink
fixed bug in LowerWarpMemory index splitting for ldmatrix
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed May 17, 2022
1 parent 00df308 commit b9f7eae
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions src/tir/transforms/lower_warp_memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,8 @@ class WarpAccessRewriter : protected StmtExprMutator {
Array<PrimExpr> new_args = op->args;
PrimExpr local_index, group;
if (op->args[3].get() == buffer_) {
new_args.Set(4, 0);
std::tie(local_index, group) = SplitIndexByGroup(op->args[4]);
new_args.Set(4, local_index);
return Call(op->dtype, op->op, new_args);
}
return GetRef<PrimExpr>(op);
Expand Down Expand Up @@ -466,11 +467,13 @@ Pass LowerWarpMemory() {
// Per-function body of the LowerWarpMemory pass.
//
// Reads the warp size from the function's target attribute, rewrites the
// function body with WarpMemoryRewriter, then applies the storage scopes
// collected by the rewriter through UpdatePointerStorageScope.
//
// \param f   The function to transform; it must carry a target attribute.
// \param m   The enclosing module (not used here).
// \param ctx The pass context (not used here).
// \return The function with its body replaced by the lowered statement.
//
// No IR is logged from here: printing the whole function at INFO level on
// every invocation floods the log for every compiled function.
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
  // Take the writable node first; the body is replaced through it below.
  auto* n = f.CopyOnWrite();
  // The warp size comes from the target, so a missing target is a hard error.
  auto target = f->GetAttr<Target>(tvm::attr::kTarget);
  ICHECK(target.defined()) << "LowerWarpMemory: Require the target attribute";
  // "thread_warp_size" falls back to 1 when the target does not define it.
  int warp_size = target.value()->GetAttr<Integer>("thread_warp_size", 1).value();
  WarpMemoryRewriter warp_memory_rewriter(warp_size);
  auto stmt = warp_memory_rewriter.Rewrite(std::move(n->body));
  // Feed the scopes recorded during the rewrite into the pointer-scope update.
  n->body = UpdatePointerStorageScope(warp_memory_rewriter.new_storage_scopes_)(stmt);
  return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.LowerWarpMemory", {});
Expand Down

0 comments on commit b9f7eae

Please sign in to comment.