[OPTIMIZER] Improved layout simplification heuristics (#1168)
ptillet committed Feb 10, 2023
1 parent c61c8a1 commit 2aba985
Showing 13 changed files with 160 additions and 47 deletions.
7 changes: 5 additions & 2 deletions include/triton/Dialect/Triton/IR/TritonOps.td
@@ -161,10 +161,13 @@ def TT_StoreOp : TT_Op<"store",
                        "($_op.getOperands().size() <= 2) || std::equal_to<>()">]> {
   let summary = "store";

-  let arguments = (ins TT_PtrLike:$ptr, TT_Type:$value, Optional<TT_BoolLike>:$mask);
+  let arguments = (ins TT_PtrLike:$ptr, TT_Type:$value, Optional<TT_BoolLike>:$mask,
+                   DefaultValuedAttr<TT_CacheModifierAttr, "triton::CacheModifier::NONE">:$cache,
+                   DefaultValuedAttr<TT_EvictionPolicyAttr, "triton::EvictionPolicy::NORMAL">:$evict);

   let builders = [
-      OpBuilder<(ins "Value":$ptr, "Value":$value)>,
+      OpBuilder<(ins "Value":$ptr, "Value":$value, "triton::CacheModifier":$cache,
+                     "triton::EvictionPolicy":$evict)>,
   ];

   // let assemblyFormat = "operands attr-dict `:` type($value)";
36 changes: 35 additions & 1 deletion lib/Analysis/AxisInfo.cpp
@@ -85,6 +85,23 @@ AxisInfo AxisInfo::getPessimisticValueState(Value value) {
         }
       }
     }
+  } else if (Operation *op = value.getDefiningOp()) {
+    DimVectorT knownContiguity(rank, 1);
+    DimVectorT knownDivisibility(rank, 1);
+    DimVectorT knownConstancy(rank, 1);
+    if (Attribute attr = op->getAttr("tt.divisibility")) {
+      auto vals = attr.cast<DenseElementsAttr>().getValues<int>();
+      knownDivisibility = DimVectorT(vals.begin(), vals.end());
+    }
+    if (Attribute attr = op->getAttr("tt.contiguity")) {
+      auto vals = attr.cast<DenseElementsAttr>().getValues<int>();
+      knownContiguity = DimVectorT(vals.begin(), vals.end());
+    }
+    if (Attribute attr = op->getAttr("tt.constancy")) {
+      auto vals = attr.cast<DenseElementsAttr>().getValues<int>();
+      knownConstancy = DimVectorT(vals.begin(), vals.end());
+    }
+    return AxisInfo(knownContiguity, knownDivisibility, knownConstancy);
   }

   return AxisInfo(/*knownContiguity=*/DimVectorT(rank, contiHint),
@@ -818,7 +835,24 @@ ChangeResult AxisInfoAnalysis::visitOperation(
   if (curr.getRank() == 0) {
     return markAllPessimisticFixpoint(op->getResults());
   }
-
+  // override with hint
+  auto newContiguity = curr.getContiguity();
+  auto newDivisibility = curr.getDivisibility();
+  auto newConstancy = curr.getConstancy();
+  if (Attribute attr = op->getAttr("tt.contiguity")) {
+    auto vals = attr.cast<DenseElementsAttr>().getValues<int>();
+    newContiguity = AxisInfo::DimVectorT(vals.begin(), vals.end());
+  }
+  if (Attribute attr = op->getAttr("tt.divisibility")) {
+    auto vals = attr.cast<DenseElementsAttr>().getValues<int>();
+    newDivisibility = AxisInfo::DimVectorT(vals.begin(), vals.end());
+  }
+  if (Attribute attr = op->getAttr("tt.constancy")) {
+    auto vals = attr.cast<DenseElementsAttr>().getValues<int>();
+    newConstancy = AxisInfo::DimVectorT(vals.begin(), vals.end());
+  }
+  curr = mlir::AxisInfo(newContiguity, newDivisibility, newConstancy,
+                        curr.getConstantValue());
   // join all lattice elements
   ChangeResult result = ChangeResult::NoChange;
   for (Value value : op->getResults()) {
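At the Python level, per-dimension hints like these are what `tl.multiple_of` and `tl.max_contiguous` attach to a value (as `tt.divisibility` and `tt.contiguity` attributes) for AxisInfo to pick up. A minimal sketch of a kernel supplying such hints — illustrative only, not taken from this commit, and assuming the buffer length is a multiple of BLOCK:

import triton
import triton.language as tl

@triton.jit
def copy_block(src, dst, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    # hint: each BLOCK-long chunk of offsets is contiguous and starts at a
    # multiple of BLOCK, so the analysis can justify wide vectorized accesses
    offs = tl.max_contiguous(tl.multiple_of(offs, BLOCK), BLOCK)
    tl.store(dst + offs, tl.load(src + offs))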
3 changes: 2 additions & 1 deletion lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp
@@ -345,7 +345,8 @@ struct TritonStorePattern : public OpConversionPattern<triton::StoreOp> {
   matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     rewriter.replaceOpWithNewOp<triton::StoreOp>(
-        op, adaptor.ptr(), adaptor.value(), adaptor.mask());
+        op, adaptor.ptr(), adaptor.value(), adaptor.mask(), adaptor.cache(),
+        adaptor.evict());
     return success();
   }
 };
7 changes: 5 additions & 2 deletions lib/Dialect/Triton/IR/Ops.cpp
@@ -149,8 +149,11 @@ bool FpToFpOp::areCastCompatible(::mlir::TypeRange inputs,

 //-- StoreOp --
 void StoreOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
-                    ::mlir::Value ptr, ::mlir::Value value) {
-  StoreOp::build(builder, state, ptr, value, mlir::Value());
+                    ::mlir::Value ptr, ::mlir::Value value,
+                    ::mlir::triton::CacheModifier cache,
+                    ::mlir::triton::EvictionPolicy evict) {
+  return StoreOp::build(builder, state, ptr, value, mlir::Value(), cache,
+                        evict);
 }

 //-- LoadOp --
5 changes: 3 additions & 2 deletions lib/Dialect/Triton/Transforms/Combine.cpp
@@ -169,8 +169,9 @@ struct CanonicalizeMaskedStorePattern

     if (splatMask.getSplatValue<IntegerAttr>().getValue() == true) {
       // mask = splat(1)
-      rewriter.replaceOpWithNewOp<triton::StoreOp>(storeOp, storeOp.ptr(),
-                                                   storeOp.value());
+      rewriter.replaceOpWithNewOp<triton::StoreOp>(
+          storeOp, storeOp.ptr(), storeOp.value(), storeOp.cache(),
+          storeOp.evict());
     } else {
       // mask = splat(0)
       rewriter.eraseOp(storeOp);
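Read as pseudocode, the pattern folds a store whose mask is a constant splat: a splat of true drops the mask (now also forwarding the new cache/evict attributes), and a splat of false erases the store. A self-contained toy model in Python — the class and booleans are hypothetical stand-ins for the MLIR machinery, not its API:

import dataclasses
from typing import Optional

@dataclasses.dataclass
class Store:
    ptr: object
    value: object
    mask: Optional[bool] = None  # None / True / False model no-mask / splat(1) / splat(0)
    cache: str = "NONE"
    evict: str = "NORMAL"

def canonicalize_masked_store(store: Store) -> Optional[Store]:
    """Toy model of CanonicalizeMaskedStorePattern."""
    if store.mask is True:    # mask = splat(1): drop the mask, keep the hints
        return Store(store.ptr, store.value, None, store.cache, store.evict)
    if store.mask is False:   # mask = splat(0): the store writes nothing
        return None           # erased
    return store              # pattern does not apply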
43 changes: 36 additions & 7 deletions lib/Dialect/TritonGPU/Transforms/Combine.cpp
@@ -154,6 +154,12 @@ class SimplifyConversion : public mlir::RewritePattern {
     // block argument
     if (!arg)
       return mlir::failure();
+    // cvt(view) -> view
+    if (auto view = dyn_cast<triton::ViewOp>(arg)) {
+      rewriter.replaceOpWithNewOp<triton::ViewOp>(
+          op, op->getResult(0).getType(), view.getResult());
+      return mlir::success();
+    }
     // cvt(alloc_tensor(x), type2) -> alloc_tensor(x, type2)
     auto alloc_tensor = dyn_cast<triton::gpu::AllocTensorOp>(arg);
     if (alloc_tensor) {
@@ -278,6 +284,9 @@ LogicalResult invertEncoding(Attribute targetEncoding, Operation *op,
       return failure();
     ret = sliceEncoding.getParent();
   }
+  if (auto view = dyn_cast<triton::ViewOp>(op)) {
+    return failure();
+  }
   return success();
 }

@@ -287,16 +296,23 @@ inline bool expensiveLoadOrStore(Operation *op, Attribute &targetEncoding) {
   if (isSingleValue(op->getOperand(0)))
     return false;
   auto ptr = op->getOperand(0);
+  // Case 2: We assume that `evict_last` loads/stores have high hit rate
+  if (auto load = dyn_cast<triton::LoadOp>(op))
+    if (load.evict() == triton::EvictionPolicy::EVICT_LAST)
+      return false;
+  if (auto store = dyn_cast<triton::StoreOp>(op))
+    if (store.evict() == triton::EvictionPolicy::EVICT_LAST)
+      return false;
   if (auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>()) {
     auto encoding = tensorTy.getEncoding();
-    // Case 2: Different type conversion is expensive (e.g., mma <-> block)
+    // Case 3: Different type conversion is expensive (e.g., mma <-> block)
     if (encoding.getTypeID() != targetEncoding.getTypeID())
       return true;
     auto sizePerThread = triton::gpu::getSizePerThread(encoding);
     auto targetSizePerThread = triton::gpu::getSizePerThread(targetEncoding);
     auto order = triton::gpu::getOrder(encoding);
     auto targetOrder = triton::gpu::getOrder(targetEncoding);
-    // Case 3: The targetEncoding may expose more vectorization opportunities
+    // Case 4: The targetEncoding may expose more vectorization opportunities
     return sizePerThread[order[0]] >= targetSizePerThread[targetOrder[0]];
   }
   return false;
@@ -365,6 +381,9 @@ LogicalResult simulateBackwardRematerialization(
       if (isa<triton::gpu::ConvertLayoutOp, arith::ConstantOp,
               triton::MakeRangeOp, triton::SplatOp>(*opArgI))
         continue;
+      if (auto view = dyn_cast<triton::ViewOp>(opArgI))
+        continue;
+
       // We add one expensive conversion for the current operand
       numCvts += 1;
       queue.emplace_back(opArgI, newEncoding);
@@ -383,9 +402,9 @@ Operation *cloneWithInferType(mlir::PatternRewriter &rewriter, Operation *op,
                               BlockAndValueMapping &mapping) {
   Operation *newOp = rewriter.clone(*op, mapping);
   auto origType = op->getResult(0).getType().cast<RankedTensorType>();
+  auto argType = newOp->getOperand(0).getType().cast<RankedTensorType>();
   auto newType = RankedTensorType::get(
-      origType.getShape(), origType.getElementType(),
-      newOp->getOperand(0).getType().cast<RankedTensorType>().getEncoding());
+      origType.getShape(), origType.getElementType(), argType.getEncoding());
   newOp->getResult(0).setType(newType);
   auto typeInfer = dyn_cast<InferTypeOpInterface>(newOp);
   if (typeInfer) {
@@ -425,6 +444,11 @@ void pushConversionForward(triton::gpu::ConvertLayoutOp cvt,
     }
   }
   rewriter.setInsertionPoint(op);
+  if (op->getNumResults() == 0) {
+    Operation *newOp = rewriter.clone(*op, mapping);
+    rewriter.eraseOp(op);
+    return;
+  }
   auto *newOp = cloneWithInferType(rewriter, op, mapping);
   auto newType = newOp->getResult(0).getType().cast<RankedTensorType>();
   auto newCvtType = RankedTensorType::get(
@@ -564,17 +588,22 @@ class FoldConvertAndReduce : public mlir::RewritePattern {
             !isa<triton::gpu::ConvertLayoutOp>(op) && !isa<scf::YieldOp>(op);
     };
     mlir::getForwardSlice(cvt.getResult(), &cvtSlices, filter);
-    if (cvtSlices.empty())
+    if (cvtSlices.empty()) {
       return failure();
+    }

     llvm::MapVector<Value, Attribute> toConvert;
     for (Operation *op : cvtSlices) {
       // don't rematerialize anything expensive
-      if (expensiveToRemat(op, srcEncoding))
+      if (expensiveToRemat(op, dstEncoding)) {
         return failure();
+      }
       // don't rematerialize non-element-wise
-      if (!op->hasTrait<mlir::OpTrait::Elementwise>())
+      if (!op->hasTrait<mlir::OpTrait::SameOperandsAndResultEncoding>() &&
+          !op->hasTrait<mlir::OpTrait::Elementwise>() &&
+          !isa<triton::StoreOp>(op)) {
         return failure();
+      }
       // don't rematerialize if it adds an extra conversion that can't
       // be removed
       for (Value arg : op->getOperands()) {
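Among the heuristics above, the new `evict_last` case in `expensiveLoadOrStore` treats loads and stores that the programmer expects to hit in L2 as cheap to rematerialize across layout conversions. A hedged sketch of a kernel that would take this path (illustrative, not from this commit):

import triton
import triton.language as tl

@triton.jit
def scale_rows(x_ptr, scale_ptr, out_ptr, N, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < N
    # the scale vector is reused by many programs, so we expect L2 hits;
    # evict_last also marks the load as cheap for the heuristic above
    s = tl.load(scale_ptr + offs, mask=mask, eviction_policy="evict_last")
    x = tl.load(x_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x * s, mask=mask)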
5 changes: 4 additions & 1 deletion lib/Dialect/TritonGPU/Transforms/Pipeline.cpp
@@ -169,7 +169,10 @@ LogicalResult LoopPipeliner::initialize() {
     if (auto loadOp = dyn_cast<triton::LoadOp>(&op)) {
       auto ptr = loadOp.ptr();
       unsigned vec = axisInfoAnalysis.getPtrContiguity(ptr);
-      auto ty = getElementTypeOrSelf(ptr.getType())
+      auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>();
+      if (!tensorTy)
+        continue;
+      auto ty = tensorTy.getElementType()
                     .cast<triton::PointerType>()
                     .getPointeeType();
       unsigned width = vec * ty.getIntOrFloatBitWidth();
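The pipeliner fix is narrow: loads through scalar (non-tensor) pointers are now skipped while sizing pipeline stages instead of being force-cast to a tensor type. A sketch of the kind of loop that exercises the guard (illustrative, not from the commit):

import triton
import triton.language as tl

@triton.jit
def scale_each_row(x_ptr, alpha_ptr, out_ptr, M, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    for row in range(0, M):
        alpha = tl.load(alpha_ptr + row)         # scalar pointer: skipped
        x = tl.load(x_ptr + row * BLOCK + offs)  # tensor pointer: pipelined
        tl.store(out_ptr + row * BLOCK + offs, alpha * x)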
23 changes: 17 additions & 6 deletions python/src/triton.cc
@@ -190,7 +190,14 @@ void init_triton_ir(py::module &&m) {
            if (mlir::Operation *definingOp = self.getDefiningOp())
              definingOp->setAttr(name, attr);
            else {
-             /* issue a warning */
+             auto arg = self.cast<mlir::BlockArgument>();
+             int id = arg.getArgNumber();
+             std::string attrName = name + "_arg" + std::to_string(id);
+             mlir::Block *owner = arg.getOwner();
+             if (owner->isEntryBlock() &&
+                 !mlir::isa<mlir::FuncOp>(owner->getParentOp())) {
+               owner->getParentOp()->setAttr(attrName, attr);
+             }
            }
          })
      .def("get_context", &mlir::Value::getContext)
@@ -1082,10 +1089,12 @@ void init_triton_ir(py::module &&m) {
                  loc, ptrs, cacheModifier, evictionPolicy, isVolatile);
            })
       .def("create_store",
-           [](mlir::OpBuilder &self, mlir::Value &ptrs,
-              mlir::Value &value) -> void {
+           [](mlir::OpBuilder &self, mlir::Value &ptrs, mlir::Value &value,
+              mlir::triton::CacheModifier cacheModifier,
+              mlir::triton::EvictionPolicy evictionPolicy) -> void {
             auto loc = self.getUnknownLoc();
-            self.create<mlir::triton::StoreOp>(loc, ptrs, value);
+            self.create<mlir::triton::StoreOp>(loc, ptrs, value, cacheModifier,
+                                               evictionPolicy);
           })
       .def("create_masked_load",
           [](mlir::OpBuilder &self, mlir::Value &ptrs, mlir::Value &mask,
@@ -1100,9 +1109,11 @@ void init_triton_ir(py::module &&m) {
           })
       .def("create_masked_store",
           [](mlir::OpBuilder &self, mlir::Value &ptrs, mlir::Value &val,
-             mlir::Value &mask) -> void {
+             mlir::Value &mask, mlir::triton::CacheModifier cacheModifier,
+             mlir::triton::EvictionPolicy evictionPolicy) -> void {
            auto loc = self.getUnknownLoc();
-           self.create<mlir::triton::StoreOp>(loc, ptrs, val, mask);
+           self.create<mlir::triton::StoreOp>(loc, ptrs, val, mask,
+                                              cacheModifier, evictionPolicy);
           })
       .def("create_view",
           [](mlir::OpBuilder &self, mlir::Value &arg,
14 changes: 13 additions & 1 deletion python/triton/compiler.py
@@ -179,6 +179,18 @@ def visit_compound_statement(self, stmts):
                 break
         return stmts and isinstance(stmt, ast.Return)

+    def contains_return_op(self, node):
+        if isinstance(node, ast.Return):
+            return True
+        elif isinstance(node, ast.If):
+            pred = lambda s: self.contains_return_op(s)
+            ret = any(pred(s) for s in node.body)
+            if node.orelse:
+                ret = ret or any(pred(s) for s in node.orelse)
+            return ret
+        else:
+            return False
+
     def visit_Module(self, node):
         ast.NodeVisitor.generic_visit(self, node)

@@ -475,7 +487,7 @@ def visit_If(self, node):
         cond = self.visit(node.test)
         if isinstance(cond, triton.language.tensor):
             cond = cond.to(triton.language.int1, _builder=self.builder)
-        if self.scf_stack:
+        if self.scf_stack or not self.contains_return_op(node):
             self.visit_if_scf(cond, node)
         else:
             self.visit_if_top_level(cond, node)
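In effect, `visit_If` now only falls back to top-level (non-SCF) codegen when the `if` actually contains a `return`; return-free conditionals keep the structured `scf.if` path. A hedged sketch of the early-exit pattern that still needs the top-level route:

import triton
import triton.language as tl

@triton.jit
def guarded_fill(ptr, N, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    if pid * BLOCK >= N:   # contains a return, so contains_return_op is True
        return
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    tl.store(ptr + offs, 0.0, mask=offs < N)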
6 changes: 4 additions & 2 deletions python/triton/language/core.py
@@ -873,7 +873,7 @@ def load(pointer, mask=None, other=None, cache_modifier="", eviction_policy="",


 @builtin
-def store(pointer, value, mask=None, _builder=None):
+def store(pointer, value, mask=None, cache_modifier="", eviction_policy="", _builder=None):
     """
     Stores :code:`value` tensor of elements in memory, element-wise, at the memory locations specified by :code:`pointer`.
@@ -890,7 +890,9 @@ def store(pointer, value, mask=None, _builder=None):
     value = _to_tensor(value, _builder)
     if _constexpr_to_value(mask) is not None:
         mask = _to_tensor(mask, _builder)
-    return semantic.store(pointer, value, mask, _builder)
+    cache_modifier = _constexpr_to_value(cache_modifier)
+    eviction_policy = _constexpr_to_value(eviction_policy)
+    return semantic.store(pointer, value, mask, cache_modifier, eviction_policy, _builder)


 # -----------------------
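With the new signature, `tl.store` accepts the same kinds of hints `tl.load` already had. A brief usage sketch — the accepted spellings are validated in `triton.language.semantic`, which this page does not show, but `evict_first` mirrors the load-side policy names:

import triton
import triton.language as tl

@triton.jit
def stream_copy(src, dst, N, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < N
    x = tl.load(src + offs, mask=mask)
    # streaming output: hint that these cache lines can be evicted first
    tl.store(dst + offs, x, mask=mask, eviction_policy="evict_first")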
1 change: 1 addition & 0 deletions python/triton/language/random.py
@@ -18,6 +18,7 @@ def philox_impl(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
     Run `n_rounds` rounds of Philox for state (c0, c1, c2, c3) and key (k0, k1).
     """
     for _ in tl.static_range(n_rounds):
+        # for _ in range(n_rounds):
         # update random state
         A = PHILOX_ROUND_A
         B = PHILOX_ROUND_B