Skip to content

Commit

Permalink
clean up LowerWarpmemory
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed May 17, 2022
1 parent 178c3dc commit 754c83e
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 42 deletions.
1 change: 0 additions & 1 deletion src/target/source/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
#include <tvm/runtime/registry.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/index_map.h>
#include <tvm/arith/iter_affine_map.h>

#include <algorithm>
#include <cmath>
Expand Down
59 changes: 19 additions & 40 deletions src/tir/transforms/lower_warp_memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,10 @@ class WarpStoreCoeffFinder : private StmtExprVisitor {
if (op->op.same_as(builtin::ptx_ldmatrix()) && op->args[3].as<VarNode>() == buffer_) {
UpdatePattern(op->args[4]);
} else if (op->op.same_as(builtin::mma_fill()) && op->args[1].as<VarNode>() == buffer_) {
auto* ptr = op->args[0].as<IntImmNode>();
CHECK(ptr);
warp_coeff_ = ptr->value;;
auto* local_size = op->args[0].as<IntImmNode>();
ICHECK(local_size) << "Integer expected for the first argument of mma_fill";
warp_coeff_ = local_size->value;
;
}

StmtExprVisitor::VisitExpr_(op);
Expand Down Expand Up @@ -257,54 +258,32 @@ class WarpAccessRewriter : protected StmtExprMutator {
}

protected:
/*!
 * \brief Rewrite the index arguments that follow occurrences of buffer_ in an intrinsic call.
 *
 * Each entry `i` of `indices` names an argument position that may hold the
 * buffer variable. When `op->args[i]` is `buffer_`, the argument right after
 * it (`op->args[i + 1]`) is replaced by the first component of
 * SplitIndexByGroup applied to that argument. All other arguments are kept.
 *
 * \param op The call whose arguments are inspected.
 * \param indices Argument positions to compare against buffer_; position
 *        `i + 1` must also be a valid argument position.
 * \return The original call node when no listed position refers to buffer_,
 *         otherwise a new Call with the same dtype and op and the updated
 *         arguments.
 */
PrimExpr RewriteIndicesAt(const CallNode* op, const std::vector<int>& indices) {
  Array<PrimExpr> new_args = op->args;
  bool changed = false;
  for (const int i : indices) {
    if (op->args[i].get() != buffer_) continue;
    // Only the first component of the split is written back into the call.
    new_args.Set(i + 1, SplitIndexByGroup(op->args[i + 1]).first);
    changed = true;
  }
  // Hand back the existing node when nothing was rewritten, so no new Call is
  // allocated and callers comparing with same_as() still see an unchanged node.
  if (!changed) return GetRef<PrimExpr>(op);
  return Call(op->dtype, op->op, new_args);
}

PrimExpr VisitExpr_(const CallNode* op) override {
if (op->op.same_as(builtin::ptx_mma())) {
Array<PrimExpr> new_args = op->args;
PrimExpr local_index, group;
bool changed = false;
int A_warp_arg_ind = 6;
for (int i = A_warp_arg_ind; i < A_warp_arg_ind + 6; i += 2) {
if (op->args[i].get() == buffer_) {
std::tie(local_index, group) = SplitIndexByGroup(op->args[i + 1]);
new_args.Set(i + 1, local_index);
changed = true;
}
}
if (!changed) return GetRef<PrimExpr>(op);
return Call(op->dtype, op->op, new_args);
return RewriteIndicesAt(op, {6, 8, 10});
}

if (op->op.same_as(builtin::ptx_ldmatrix())) {
Array<PrimExpr> new_args = op->args;
PrimExpr local_index, group;
if (op->args[3].get() == buffer_) {
std::tie(local_index, group) = SplitIndexByGroup(op->args[4]);
new_args.Set(4, local_index);
return Call(op->dtype, op->op, new_args);
}
return GetRef<PrimExpr>(op);
return RewriteIndicesAt(op, {3});
}

if (op->op.same_as(builtin::mma_store())) {
Array<PrimExpr> new_args = op->args;
PrimExpr local_offset, group;
if (op->args[3].get() == buffer_) {
std::tie(local_offset, group) = SplitIndexByGroup(op->args[4]);
new_args.Set(4, local_offset);
return Call(op->dtype, op->op, new_args);
}
return GetRef<PrimExpr>(op);
return RewriteIndicesAt(op, {3});
}

if (op->op.same_as(builtin::mma_fill())) {
Array<PrimExpr> new_args = op->args;
PrimExpr local_offset, group;
if (op->args[1].get() == buffer_) {
std::tie(local_offset, group) = SplitIndexByGroup(op->args[2]);
new_args.Set(2, local_offset);
return Call(op->dtype, op->op, new_args);
}
return GetRef<PrimExpr>(op);
return RewriteIndicesAt(op, {1});
}

return StmtExprMutator::VisitExpr_(op);
Expand Down

0 comments on commit 754c83e

Please sign in to comment.