From 2aba985daaa70234823ea8f1161da938477d3e02 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Thu, 9 Feb 2023 20:17:25 -0800 Subject: [PATCH] [OPTIMIZER] Improved layout simplifications heuristics (#1168) --- include/triton/Dialect/Triton/IR/TritonOps.td | 7 ++- lib/Analysis/AxisInfo.cpp | 36 +++++++++++- .../TritonToTritonGPUPass.cpp | 3 +- lib/Dialect/Triton/IR/Ops.cpp | 7 ++- lib/Dialect/Triton/Transforms/Combine.cpp | 5 +- lib/Dialect/TritonGPU/Transforms/Combine.cpp | 43 ++++++++++++--- lib/Dialect/TritonGPU/Transforms/Pipeline.cpp | 5 +- python/src/triton.cc | 23 ++++++-- python/triton/compiler.py | 14 ++++- python/triton/language/core.py | 6 +- python/triton/language/random.py | 1 + python/triton/language/semantic.py | 55 ++++++++++++------- test/TritonGPU/combine.mlir | 2 +- 13 files changed, 160 insertions(+), 47 deletions(-) diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index 7303681158bb..9902ee447cdb 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -161,10 +161,13 @@ def TT_StoreOp : TT_Op<"store", "($_op.getOperands().size() <= 2) || std::equal_to<>()">]> { let summary = "store"; - let arguments = (ins TT_PtrLike:$ptr, TT_Type:$value, Optional:$mask); + let arguments = (ins TT_PtrLike:$ptr, TT_Type:$value, Optional:$mask, + DefaultValuedAttr:$cache, + DefaultValuedAttr:$evict); let builders = [ - OpBuilder<(ins "Value":$ptr, "Value":$value)>, + OpBuilder<(ins "Value":$ptr, "Value":$value, "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict)>, ]; // let assemblyFormat = "operands attr-dict `:` type($value)"; diff --git a/lib/Analysis/AxisInfo.cpp b/lib/Analysis/AxisInfo.cpp index 675b5e324c7d..0b7142b04ded 100644 --- a/lib/Analysis/AxisInfo.cpp +++ b/lib/Analysis/AxisInfo.cpp @@ -85,6 +85,23 @@ AxisInfo AxisInfo::getPessimisticValueState(Value value) { } } } + } else if (Operation *op = value.getDefiningOp()) { + DimVectorT knownContiguity(rank, 1); + DimVectorT knownDivisibility(rank, 1); + DimVectorT knownConstancy(rank, 1); + if (Attribute attr = op->getAttr("tt.divisibility")) { + auto vals = attr.cast().getValues(); + knownDivisibility = DimVectorT(vals.begin(), vals.end()); + } + if (Attribute attr = op->getAttr("tt.contiguity")) { + auto vals = attr.cast().getValues(); + knownContiguity = DimVectorT(vals.begin(), vals.end()); + } + if (Attribute attr = op->getAttr("tt.constancy")) { + auto vals = attr.cast().getValues(); + knownConstancy = DimVectorT(vals.begin(), vals.end()); + } + return AxisInfo(knownContiguity, knownDivisibility, knownConstancy); } return AxisInfo(/*knownContiguity=*/DimVectorT(rank, contiHint), @@ -818,7 +835,24 @@ ChangeResult AxisInfoAnalysis::visitOperation( if (curr.getRank() == 0) { return markAllPessimisticFixpoint(op->getResults()); } - + // override with hint + auto newContiguity = curr.getContiguity(); + auto newDivisibility = curr.getDivisibility(); + auto newConstancy = curr.getConstancy(); + if (Attribute attr = op->getAttr("tt.contiguity")) { + auto vals = attr.cast().getValues(); + newContiguity = AxisInfo::DimVectorT(vals.begin(), vals.end()); + } + if (Attribute attr = op->getAttr("tt.divisibility")) { + auto vals = attr.cast().getValues(); + newDivisibility = AxisInfo::DimVectorT(vals.begin(), vals.end()); + } + if (Attribute attr = op->getAttr("tt.constancy")) { + auto vals = attr.cast().getValues(); + newConstancy = AxisInfo::DimVectorT(vals.begin(), vals.end()); + } + curr = mlir::AxisInfo(newContiguity, newDivisibility, newConstancy, + curr.getConstantValue()); // join all lattice elements ChangeResult result = ChangeResult::NoChange; for (Value value : op->getResults()) { diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp index 9f4ae5950ee5..fe42202c342b 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp @@ -345,7 +345,8 @@ struct TritonStorePattern : public OpConversionPattern { matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp( - op, adaptor.ptr(), adaptor.value(), adaptor.mask()); + op, adaptor.ptr(), adaptor.value(), adaptor.mask(), adaptor.cache(), + adaptor.evict()); return success(); } }; diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index f03fa5cb26ae..3aadbfa0c0a0 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -149,8 +149,11 @@ bool FpToFpOp::areCastCompatible(::mlir::TypeRange inputs, //-- StoreOp -- void StoreOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, - ::mlir::Value ptr, ::mlir::Value value) { - StoreOp::build(builder, state, ptr, value, mlir::Value()); + ::mlir::Value ptr, ::mlir::Value value, + ::mlir::triton::CacheModifier cache, + ::mlir::triton::EvictionPolicy evict) { + return StoreOp::build(builder, state, ptr, value, mlir::Value(), cache, + evict); } //-- LoadOp -- diff --git a/lib/Dialect/Triton/Transforms/Combine.cpp b/lib/Dialect/Triton/Transforms/Combine.cpp index 878bab56fe01..2261472170f1 100644 --- a/lib/Dialect/Triton/Transforms/Combine.cpp +++ b/lib/Dialect/Triton/Transforms/Combine.cpp @@ -169,8 +169,9 @@ struct CanonicalizeMaskedStorePattern if (splatMask.getSplatValue().getValue() == true) { // mask = splat(1) - rewriter.replaceOpWithNewOp(storeOp, storeOp.ptr(), - storeOp.value()); + rewriter.replaceOpWithNewOp( + storeOp, storeOp.ptr(), storeOp.value(), storeOp.cache(), + storeOp.evict()); } else { // mask = splat(0) rewriter.eraseOp(storeOp); diff --git a/lib/Dialect/TritonGPU/Transforms/Combine.cpp b/lib/Dialect/TritonGPU/Transforms/Combine.cpp index 92d74eb5a0d6..924eea08c05c 100644 --- a/lib/Dialect/TritonGPU/Transforms/Combine.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Combine.cpp @@ -154,6 +154,12 @@ class SimplifyConversion : public mlir::RewritePattern { // block argument if (!arg) return mlir::failure(); + // cvt(view) -> view + if (auto view = dyn_cast(arg)) { + rewriter.replaceOpWithNewOp( + op, op->getResult(0).getType(), view.getResult()); + return mlir::success(); + } // cvt(alloc_tensor(x), type2) -> alloc_tensor(x, type2) auto alloc_tensor = dyn_cast(arg); if (alloc_tensor) { @@ -278,6 +284,9 @@ LogicalResult invertEncoding(Attribute targetEncoding, Operation *op, return failure(); ret = sliceEncoding.getParent(); } + if (auto view = dyn_cast(op)) { + return failure(); + } return success(); } @@ -287,16 +296,23 @@ inline bool expensiveLoadOrStore(Operation *op, Attribute &targetEncoding) { if (isSingleValue(op->getOperand(0))) return false; auto ptr = op->getOperand(0); + // Case 2: We assume that `evict_last` loads/stores have high hit rate + if (auto load = dyn_cast(op)) + if (load.evict() == triton::EvictionPolicy::EVICT_LAST) + return false; + if (auto store = dyn_cast(op)) + if (store.evict() == triton::EvictionPolicy::EVICT_LAST) + return false; if (auto tensorTy = ptr.getType().dyn_cast()) { auto encoding = tensorTy.getEncoding(); - // Case 2: Different type conversion is expensive (e.g., mma <-> block) + // Case 3: Different type conversion is expensive (e.g., mma <-> block) if (encoding.getTypeID() != targetEncoding.getTypeID()) return true; auto sizePerThread = triton::gpu::getSizePerThread(encoding); auto targetSizePerThread = triton::gpu::getSizePerThread(targetEncoding); auto order = triton::gpu::getOrder(encoding); auto targetOrder = triton::gpu::getOrder(targetEncoding); - // Case 3: The targeEncoding may expose more vectorization opportunities + // Case 4: The targeEncoding may expose more vectorization opportunities return sizePerThread[order[0]] >= targetSizePerThread[targetOrder[0]]; } return false; @@ -365,6 +381,9 @@ LogicalResult simulateBackwardRematerialization( if (isa(*opArgI)) continue; + if (auto view = dyn_cast(opArgI)) + continue; + // We add one expensive conversion for the current operand numCvts += 1; queue.emplace_back(opArgI, newEncoding); @@ -383,9 +402,9 @@ Operation *cloneWithInferType(mlir::PatternRewriter &rewriter, Operation *op, BlockAndValueMapping &mapping) { Operation *newOp = rewriter.clone(*op, mapping); auto origType = op->getResult(0).getType().cast(); + auto argType = newOp->getOperand(0).getType().cast(); auto newType = RankedTensorType::get( - origType.getShape(), origType.getElementType(), - newOp->getOperand(0).getType().cast().getEncoding()); + origType.getShape(), origType.getElementType(), argType.getEncoding()); newOp->getResult(0).setType(newType); auto typeInfer = dyn_cast(newOp); if (typeInfer) { @@ -425,6 +444,11 @@ void pushConversionForward(triton::gpu::ConvertLayoutOp cvt, } } rewriter.setInsertionPoint(op); + if (op->getNumResults() == 0) { + Operation *newOp = rewriter.clone(*op, mapping); + rewriter.eraseOp(op); + return; + } auto *newOp = cloneWithInferType(rewriter, op, mapping); auto newType = newOp->getResult(0).getType().cast(); auto newCvtType = RankedTensorType::get( @@ -564,17 +588,22 @@ class FoldConvertAndReduce : public mlir::RewritePattern { !isa(op) && !isa(op); }; mlir::getForwardSlice(cvt.getResult(), &cvtSlices, filter); - if (cvtSlices.empty()) + if (cvtSlices.empty()) { return failure(); + } llvm::MapVector toConvert; for (Operation *op : cvtSlices) { // don't rematerialize anything expensive - if (expensiveToRemat(op, srcEncoding)) + if (expensiveToRemat(op, dstEncoding)) { return failure(); + } // don't rematerialize non-element-wise - if (!op->hasTrait()) + if (!op->hasTrait() && + !op->hasTrait() && + !isa(op)) { return failure(); + } // don't rematerialize if it adds an extra conversion that can't // be removed for (Value arg : op->getOperands()) { diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp index fbbb1fa321e4..c015a69c6d5d 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp @@ -169,7 +169,10 @@ LogicalResult LoopPipeliner::initialize() { if (auto loadOp = dyn_cast(&op)) { auto ptr = loadOp.ptr(); unsigned vec = axisInfoAnalysis.getPtrContiguity(ptr); - auto ty = getElementTypeOrSelf(ptr.getType()) + auto tensorTy = ptr.getType().dyn_cast(); + if (!tensorTy) + continue; + auto ty = tensorTy.getElementType() .cast() .getPointeeType(); unsigned width = vec * ty.getIntOrFloatBitWidth(); diff --git a/python/src/triton.cc b/python/src/triton.cc index bdf0ca9c5607..c40b117a5595 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -190,7 +190,14 @@ void init_triton_ir(py::module &&m) { if (mlir::Operation *definingOp = self.getDefiningOp()) definingOp->setAttr(name, attr); else { - /* issue a warning */ + auto arg = self.cast(); + int id = arg.getArgNumber(); + std::string attrName = name + "_arg" + std::to_string(id); + mlir::Block *owner = arg.getOwner(); + if (owner->isEntryBlock() && + !mlir::isa(owner->getParentOp())) { + owner->getParentOp()->setAttr(attrName, attr); + } } }) .def("get_context", &mlir::Value::getContext) @@ -1082,10 +1089,12 @@ void init_triton_ir(py::module &&m) { loc, ptrs, cacheModifier, evictionPolicy, isVolatile); }) .def("create_store", - [](mlir::OpBuilder &self, mlir::Value &ptrs, - mlir::Value &value) -> void { + [](mlir::OpBuilder &self, mlir::Value &ptrs, mlir::Value &value, + mlir::triton::CacheModifier cacheModifier, + mlir::triton::EvictionPolicy evictionPolicy) -> void { auto loc = self.getUnknownLoc(); - self.create(loc, ptrs, value); + self.create(loc, ptrs, value, cacheModifier, + evictionPolicy); }) .def("create_masked_load", [](mlir::OpBuilder &self, mlir::Value &ptrs, mlir::Value &mask, @@ -1100,9 +1109,11 @@ void init_triton_ir(py::module &&m) { }) .def("create_masked_store", [](mlir::OpBuilder &self, mlir::Value &ptrs, mlir::Value &val, - mlir::Value &mask) -> void { + mlir::Value &mask, mlir::triton::CacheModifier cacheModifier, + mlir::triton::EvictionPolicy evictionPolicy) -> void { auto loc = self.getUnknownLoc(); - self.create(loc, ptrs, val, mask); + self.create(loc, ptrs, val, mask, + cacheModifier, evictionPolicy); }) .def("create_view", [](mlir::OpBuilder &self, mlir::Value &arg, diff --git a/python/triton/compiler.py b/python/triton/compiler.py index 8c999247ea92..884983eb853b 100644 --- a/python/triton/compiler.py +++ b/python/triton/compiler.py @@ -179,6 +179,18 @@ def visit_compound_statement(self, stmts): break return stmts and isinstance(stmt, ast.Return) + def contains_return_op(self, node): + if isinstance(node, ast.Return): + return True + elif isinstance(node, ast.If): + pred = lambda s: self.contains_return_op(s) + ret = any(pred(s) for s in node.body) + if node.orelse: + ret = ret or any(pred(s) for s in node.orelse) + return ret + else: + return False + def visit_Module(self, node): ast.NodeVisitor.generic_visit(self, node) @@ -475,7 +487,7 @@ def visit_If(self, node): cond = self.visit(node.test) if isinstance(cond, triton.language.tensor): cond = cond.to(triton.language.int1, _builder=self.builder) - if self.scf_stack: + if self.scf_stack or not self.contains_return_op(node): self.visit_if_scf(cond, node) else: self.visit_if_top_level(cond, node) diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 7f791c5f643b..4b360c8ca6ee 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -873,7 +873,7 @@ def load(pointer, mask=None, other=None, cache_modifier="", eviction_policy="", @builtin -def store(pointer, value, mask=None, _builder=None): +def store(pointer, value, mask=None, cache_modifier="", eviction_policy="", _builder=None): """ Stores :code:`value` tensor of elements in memory, element-wise, at the memory locations specified by :code:`pointer`. @@ -890,7 +890,9 @@ def store(pointer, value, mask=None, _builder=None): value = _to_tensor(value, _builder) if _constexpr_to_value(mask) is not None: mask = _to_tensor(mask, _builder) - return semantic.store(pointer, value, mask, _builder) + cache_modifier = _constexpr_to_value(cache_modifier) + eviction_policy = _constexpr_to_value(eviction_policy) + return semantic.store(pointer, value, mask, cache_modifier, eviction_policy, _builder) # ----------------------- diff --git a/python/triton/language/random.py b/python/triton/language/random.py index c7063a0e300c..ba9f85227f7f 100644 --- a/python/triton/language/random.py +++ b/python/triton/language/random.py @@ -18,6 +18,7 @@ def philox_impl(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAUL Run `n_rounds` rounds of Philox for state (c0, c1, c2, c3) and key (k0, k1). """ for _ in tl.static_range(n_rounds): + # for _ in range(n_rounds): # update random state A = PHILOX_ROUND_A B = PHILOX_ROUND_B diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 3430f4364729..35d7f6d71e15 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -747,6 +747,30 @@ def cast(input: tl.tensor, # ===----------------------------------------------------------------------===// +def str_to_cache_modifier(cache_modifier): + cache = ir.CACHE_MODIFIER.NONE # default + if cache_modifier: + if cache_modifier == ".ca": + cache = ir.CACHE_MODIFIER.CA + elif cache_modifier == ".cg": + cache = ir.CACHE_MODIFIER.CG + else: + raise ValueError(f"Cache modifier {cache_modifier} not supported") + return cache + + +def str_to_eviction_policy(eviction_policy): + eviction = ir.EVICTION_POLICY.NORMAL # default + if eviction_policy: + if eviction_policy == "evict_last": + eviction = ir.EVICTION_POLICY.EVICT_LAST + elif eviction_policy == "evict_first": + eviction = ir.EVICTION_POLICY.EVICT_FIRST + else: + raise ValueError(f"Eviction policy {eviction_policy} not supported") + return eviction + + def load(ptr: tl.tensor, mask: Optional[tl.tensor], other: Optional[tl.tensor], @@ -775,24 +799,6 @@ def load(ptr: tl.tensor, other = cast(other, elt_ty, builder) # cache modifier - cache = ir.CACHE_MODIFIER.NONE # default - if cache_modifier: - if cache_modifier == ".ca": - cache = ir.CACHE_MODIFIER.CA - elif cache_modifier == ".cg": - cache = ir.CACHE_MODIFIER.CG - else: - raise ValueError(f"Cache modifier {cache_modifier} not supported") - - # eviction policy - eviction = ir.EVICTION_POLICY.NORMAL # default - if eviction_policy: - if eviction_policy == "evict_last": - eviction = ir.EVICTION_POLICY.EVICT_LAST - elif eviction_policy == "evict_first": - eviction = ir.EVICTION_POLICY.EVICT_FIRST - else: - raise ValueError(f"Eviction policy {eviction_policy} not supported") if ptr.type.is_block(): shape = ptr.type.get_block_shapes() @@ -800,6 +806,9 @@ def load(ptr: tl.tensor, else: dst_ty = elt_ty + cache = str_to_cache_modifier(cache_modifier) + eviction = str_to_eviction_policy(eviction_policy) + if not mask: if other: raise ValueError("`other` cannot be provided without `mask`") @@ -816,6 +825,8 @@ def load(ptr: tl.tensor, def store(ptr: tl.tensor, val: tl.tensor, mask: Optional[tl.tensor], + cache_modifier: str, + eviction_policy: str, builder: ir.builder) -> tl.tensor: if not ptr.type.scalar.is_ptr(): raise ValueError("Pointer argument of store instruction is " + ptr.type.__repr__()) @@ -830,14 +841,16 @@ def store(ptr: tl.tensor, elt_ty = tl.int8 ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space) ptr = cast(ptr, ptr_ty, builder) - + # attributes + cache = str_to_cache_modifier(cache_modifier) + eviction = str_to_eviction_policy(eviction_policy) # cast to target data-type val = cast(val, elt_ty, builder) if not mask: - return tl.tensor(builder.create_store(ptr.handle, val.handle), tl.void) + return tl.tensor(builder.create_store(ptr.handle, val.handle, cache, eviction), tl.void) if not mask.type.scalar.is_bool(): raise ValueError("Mask must have boolean scalar type") - return tl.tensor(builder.create_masked_store(ptr.handle, val.handle, mask.handle), tl.void) + return tl.tensor(builder.create_masked_store(ptr.handle, val.handle, mask.handle, cache, eviction), tl.void) ######### # atomic diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir index 0eaef05d6f48..2c009ffa48d1 100644 --- a/test/TritonGPU/combine.mlir +++ b/test/TritonGPU/combine.mlir @@ -2,7 +2,7 @@ #layout0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> #layout1 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -#layout2 = #triton_gpu.mma<{version = 2, warpsPerCTA = [4, 1]}> +#layout2 = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1]}> // CHECK: [[target_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> // CHECK: [[row_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>