Skip to content

Commit

Permalink
[TRANSFORM] Triton Dialect
Browse files Browse the repository at this point in the history
  • Loading branch information
Jokeren authored Jul 14, 2023
1 parent d0ed596 commit 6aa244d
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 103 deletions.
38 changes: 20 additions & 18 deletions lib/Dialect/TritonGPU/Transforms/RewriteTensorPointer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,7 @@ struct RewritedInfo {
tensorShape[i]);
Value i64Range = builder.create<arith::ExtSIOp>(loc, indexRowType, range);

auto argEncoding = layout.dyn_cast<triton::gpu::BlockedEncodingAttr>();
assert(argEncoding);
auto argEncoding = layout.cast<triton::gpu::BlockedEncodingAttr>();

// Expand dimensions
Value expandedResult =
Expand Down Expand Up @@ -236,9 +235,8 @@ struct RewritedInfo {

Value generateMask(OpBuilder &builder, const Location &loc,
const std::optional<ArrayRef<int32_t>> &boundaryCheck) {
if (!boundaryCheck.has_value() || boundaryCheck.value().size() == 0) {
if (!boundaryCheck.has_value() || boundaryCheck.value().empty())
return {};
}

// Generate mask per dimension
auto maskTensorType =
Expand Down Expand Up @@ -415,11 +413,14 @@ class RewriteTensorPointerPass
Operation *rewriteLoadStoreOp(OpBuilder &builder, Operation *op,
std::stack<Operation *> &eraser,
const DenseSet<Value> &valueToRemove) {

assert(isa<triton::LoadOp>(op) || isa<triton::StoreOp>(op));
if (!valueToRemove.count(op->getResult(0))) {
return nullptr;
}
if (isa<triton::LoadOp>(op)) {
if (!valueToRemove.count(op->getResult(0)))
return nullptr;
} else if (isa<triton::StoreOp>(op)) {
if (!valueToRemove.count(op->getOperand(0)))
return nullptr;
} else
llvm_unreachable("Unsupported operation");

// We only have to rewrite load/stores with tensor pointers
auto ptr = op->getOperand(0);
Expand All @@ -439,21 +440,18 @@ class RewriteTensorPointerPass
assert(!loadOp.getMask() && !loadOp.getOther());
boundaryCheck = loadOp.getBoundaryCheck();
if (auto valueType =
loadOp.getResult().getType().dyn_cast<RankedTensorType>()) {
dyn_cast<RankedTensorType>(loadOp.getResult().getType()))
info.setEncoding(valueType.getEncoding());
}
} else if (auto storeOp = dyn_cast<triton::StoreOp>(op)) {
assert(!storeOp.getMask());
boundaryCheck = storeOp.getBoundaryCheck();
if (auto valueType =
storeOp.getValue().getType().dyn_cast<RankedTensorType>()) {
dyn_cast<RankedTensorType>(storeOp.getValue().getType()))
info.setEncoding(valueType.getEncoding());
}
}

// Generate new `ptr`, `mask` and `other`
auto newPtr = info.generatePtr(builder, op->getLoc());

auto newMask = info.generateMask(builder, op->getLoc(), boundaryCheck);
Value newOther;
if (auto loadOp = dyn_cast<triton::LoadOp>(op))
Expand Down Expand Up @@ -532,9 +530,8 @@ class RewriteTensorPointerPass
for (Operation &opInFor : *op.getBody()) {
Operation *newOp = builder.clone(opInFor, mapping);
for (unsigned i = 0; i < opInFor.getNumResults(); ++i) {
if (valueToRemove.count(opInFor.getResult(i))) {
if (valueToRemove.count(opInFor.getResult(i)))
valueToRemove.insert(newOp->getResult(i));
}
mapping.map(opInFor.getResult(i), newOp->getResult(i));
}
}
Expand Down Expand Up @@ -643,6 +640,7 @@ class RewriteTensorPointerPass
}

void runOnOperation() override {
// XXX(Keren): revisit this condition
if (computeCapability >= 90) {
return;
}
Expand All @@ -660,8 +658,12 @@ class RewriteTensorPointerPass
auto src = op->getOperand(0);
if (triton::isTensorPointerType(src.getType())) {
auto makeTensorPtrOp = getMakeTensorPtrOp(src);
if (shouldRemove(makeTensorPtrOp, computeCapability))
valueToRemove.insert(op->getResult(0));
if (shouldRemove(makeTensorPtrOp, computeCapability)) {
if (isa<triton::StoreOp>(op))
valueToRemove.insert(src);
else
valueToRemove.insert(op->getResult(0));
}
}
}
if (auto forOp = dyn_cast<scf::ForOp>(op)) {
Expand Down
1 change: 0 additions & 1 deletion test/Triton/combine.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
// RUN: triton-opt %s -split-input-file -canonicalize -triton-combine
// RUN: triton-opt %s -split-input-file -canonicalize -triton-combine | FileCheck %s

// CHECK-LABEL: @test_combine_dot_add_pattern
Expand Down
83 changes: 0 additions & 83 deletions test/Triton/rewrite-tensor-pointer.mlir

This file was deleted.

2 changes: 1 addition & 1 deletion test/TritonGPU/rewrite-tensor-pointer.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
%c1_i64 = arith.constant 1 : i64
%c128_i32 = arith.constant 128 : i32
%c8_i32 = arith.constant 8 : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
%0 = tt.get_program_id x : i32
%1 = arith.addi %arg5, %c127_i32 : i32
%2 = arith.divsi %1, %c128_i32 : i32
%3 = arith.addi %arg4, %c127_i32 : i32
Expand Down

0 comments on commit 6aa244d

Please sign in to comment.