Skip to content

Commit

Permalink
Move isExpensiveCat() to TritonGPU/IR/Dialect to avoid cyclic BUILD…
Browse files Browse the repository at this point in the history
… dependency.
  • Loading branch information
chsigg committed Jun 27, 2023
1 parent 79acf6e commit 644d92b
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 16 deletions.
2 changes: 2 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/Dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ bool isSharedEncoding(Value value);
} // namespace gpu
} // namespace triton

bool isExpensiveCat(triton::CatOp cat, Attribute &targetEncoding);

} // namespace mlir

#endif // TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_
2 changes: 0 additions & 2 deletions include/triton/Dialect/TritonGPU/Transforms/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ LogicalResult invertEncoding(Attribute targetEncoding, Operation *op,

bool isExpensiveLoadOrStore(Operation *op, Attribute &targetEncoding);

bool isExpensiveCat(triton::CatOp cat, Attribute &targetEncoding);

bool isExpensiveToRemat(Operation *op, Attribute &targetEncoding);

// skipInit is True when we only consider the operands of the initOp but
Expand Down
14 changes: 13 additions & 1 deletion lib/Dialect/TritonGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.cpp.inc"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
Expand Down Expand Up @@ -367,6 +366,19 @@ bool isSharedEncoding(Value value) {
} // namespace gpu
} // namespace triton

bool isExpensiveCat(triton::CatOp cat, Attribute &targetEncoding) {
// If the new elements per thread is less than the old one, we will need to do
// convert encoding that goes through shared memory anyway. So we consider it
// as expensive.
auto tensorTy = cat.getResult().getType().cast<RankedTensorType>();
auto totalElemsPerThread = triton::gpu::getTotalElemsPerThread(tensorTy);
auto shape = tensorTy.getShape();
auto elemTy = tensorTy.getElementType();
auto newTotalElemsPerThread =
triton::gpu::getTotalElemsPerThread(targetEncoding, shape, elemTy);
return newTotalElemsPerThread < totalElemsPerThread;
}

} // namespace mlir

static LogicalResult parseIntAttrValue(AsmParser &parser, Attribute attr,
Expand Down
13 changes: 0 additions & 13 deletions lib/Dialect/TritonGPU/Transforms/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,19 +109,6 @@ bool isExpensiveLoadOrStore(Operation *op, Attribute &targetEncoding) {
return true;
}

bool isExpensiveCat(triton::CatOp cat, Attribute &targetEncoding) {
// If the new elements per thread is less than the old one, we will need to do
// convert encoding that goes through shared memory anyway. So we consider it
// as expensive.
auto tensorTy = cat.getResult().getType().cast<RankedTensorType>();
auto totalElemsPerThread = triton::gpu::getTotalElemsPerThread(tensorTy);
auto shape = tensorTy.getShape();
auto elemTy = tensorTy.getElementType();
auto newTotalElemsPerThread =
triton::gpu::getTotalElemsPerThread(targetEncoding, shape, elemTy);
return newTotalElemsPerThread < totalElemsPerThread;
}

bool isExpensiveToRemat(Operation *op, Attribute &targetEncoding) {
if (!op)
return true;
Expand Down

0 comments on commit 644d92b

Please sign in to comment.