Skip to content

Commit

Permalink
[CIR][IR] Implement loop's conditional operation (#391)
Browse files Browse the repository at this point in the history
Like SCF's `scf.condition`, the `cir.condition` simplifies codegen of
loop conditions by removing the need of a contitional branch. It takes a
single boolean operand which, if true, executes the body region,
otherwise exits the loop. This also simplifies lowering and the dialect
it self.

A new constraint is now enforced on `cir.loops`: the condition region
must terminate with a `cir.condition` operation.

A few tests were removed as they became redundant, and others where
simplified.

The merge-cleanups pass no longer simplifies compile-time constant
conditions, as the condition body terminator is no longer allowed to be
terminated with a `cir.yield`. To circumvent this, a proper folder
should be implemented to fold constant conditions, but this was left as
future work.

Co-authored-by: Bruno Cardoso Lopes <bcardosolopes@users.noreply.github.com>
  • Loading branch information
sitio-couto and bcardosolopes authored Jan 10, 2024
1 parent f7c5363 commit 8ce2153
Show file tree
Hide file tree
Showing 16 changed files with 230 additions and 671 deletions.
19 changes: 19 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,25 @@ def TernaryOp : CIR_Op<"ternary",
}];
}

//===----------------------------------------------------------------------===//
// ConditionOp
//===----------------------------------------------------------------------===//

def ConditionOp : CIR_Op<"condition", [
Terminator,
DeclareOpInterfaceMethods<RegionBranchTerminatorOpInterface,
["getSuccessorRegions"]>
]> {
let summary = "Loop continuation condition.";
let description = [{
The `cir.condition` termintes loop's conditional regions. It takes a single
`cir.bool` operand. if the operand is true, the loop continues, otherwise
it terminates.
}];
let arguments = (ins CIR_BoolType:$condition);
let assemblyFormat = " `(` $condition `)` attr-dict ";
}

//===----------------------------------------------------------------------===//
// YieldOp
//===----------------------------------------------------------------------===//
Expand Down
5 changes: 5 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,11 @@ class CIRGenBuilderTy : public CIRBaseBuilderTy {
return create<mlir::cir::CopyOp>(dst.getLoc(), dst, src);
}

/// Create a loop condition.
mlir::cir::ConditionOp createCondition(mlir::Value condition) {
return create<mlir::cir::ConditionOp>(condition.getLoc(), condition);
}

mlir::cir::MemCpyOp createMemCpy(mlir::Location loc, mlir::Value dst,
mlir::Value src, mlir::Value len) {
return create<mlir::cir::MemCpyOp>(loc, dst, src, len);
Expand Down
32 changes: 4 additions & 28 deletions clang/lib/CIR/CodeGen/CIRGenStmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -650,26 +650,6 @@ CIRGenFunction::buildDefaultStmt(const DefaultStmt &S, mlir::Type condType,
return buildCaseDefaultCascade(&S, condType, caseAttrs, os);
}

static mlir::LogicalResult buildLoopCondYield(mlir::OpBuilder &builder,
mlir::Location loc,
mlir::Value cond) {
mlir::Block *trueBB = nullptr, *falseBB = nullptr;
{
mlir::OpBuilder::InsertionGuard guard(builder);
trueBB = builder.createBlock(builder.getBlock()->getParent());
builder.create<mlir::cir::YieldOp>(loc, YieldOpKind::Continue);
}
{
mlir::OpBuilder::InsertionGuard guard(builder);
falseBB = builder.createBlock(builder.getBlock()->getParent());
builder.create<mlir::cir::YieldOp>(loc);
}

assert((trueBB && falseBB) && "expected both blocks to exist");
builder.create<mlir::cir::BrCondOp>(loc, cond, trueBB, falseBB);
return mlir::success();
}

mlir::LogicalResult
CIRGenFunction::buildCXXForRangeStmt(const CXXForRangeStmt &S,
ArrayRef<const Attr *> ForAttrs) {
Expand Down Expand Up @@ -703,8 +683,7 @@ CIRGenFunction::buildCXXForRangeStmt(const CXXForRangeStmt &S,
assert(!UnimplementedFeature::createProfileWeightsForLoop());
assert(!UnimplementedFeature::emitCondLikelihoodViaExpectIntrinsic());
mlir::Value condVal = evaluateExprAsBool(S.getCond());
if (buildLoopCondYield(b, loc, condVal).failed())
loopRes = mlir::failure();
builder.createCondition(condVal);
},
/*bodyBuilder=*/
[&](mlir::OpBuilder &b, mlir::Location loc) {
Expand Down Expand Up @@ -786,8 +765,7 @@ mlir::LogicalResult CIRGenFunction::buildForStmt(const ForStmt &S) {
loc, boolTy,
mlir::cir::BoolAttr::get(b.getContext(), boolTy, true));
}
if (buildLoopCondYield(b, loc, condVal).failed())
loopRes = mlir::failure();
builder.createCondition(condVal);
},
/*bodyBuilder=*/
[&](mlir::OpBuilder &b, mlir::Location loc) {
Expand Down Expand Up @@ -850,8 +828,7 @@ mlir::LogicalResult CIRGenFunction::buildDoStmt(const DoStmt &S) {
// expression compares unequal to 0. The condition must be a
// scalar type.
mlir::Value condVal = evaluateExprAsBool(S.getCond());
if (buildLoopCondYield(b, loc, condVal).failed())
loopRes = mlir::failure();
builder.createCondition(condVal);
},
/*bodyBuilder=*/
[&](mlir::OpBuilder &b, mlir::Location loc) {
Expand Down Expand Up @@ -910,8 +887,7 @@ mlir::LogicalResult CIRGenFunction::buildWhileStmt(const WhileStmt &S) {
// expression compares unequal to 0. The condition must be a
// scalar type.
condVal = evaluateExprAsBool(S.getCond());
if (buildLoopCondYield(b, loc, condVal).failed())
loopRes = mlir::failure();
builder.createCondition(condVal);
},
/*bodyBuilder=*/
[&](mlir::OpBuilder &b, mlir::Location loc) {
Expand Down
47 changes: 28 additions & 19 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,30 @@ void AllocaOp::build(::mlir::OpBuilder &odsBuilder,
odsState.addTypes(addr);
}

//===----------------------------------------------------------------------===//
// ConditionOp
//===-----------------------------------------------------------------------===//

//===----------------------------------
// BranchOpTerminatorInterface Methods

void ConditionOp::getSuccessorRegions(
ArrayRef<Attribute> operands, SmallVectorImpl<RegionSuccessor> &regions) {
auto loopOp = cast<LoopOp>(getOperation()->getParentOp());

// TODO(cir): The condition value may be folded to a constant, narrowing
// down its list of possible successors.
// Condition may branch to the body or to the parent op.
regions.emplace_back(&loopOp.getBody(), loopOp.getBody().getArguments());
regions.emplace_back(loopOp->getResults());
}

MutableOperandRange
ConditionOp::getMutableSuccessorOperands(RegionBranchPoint point) {
// No values are yielded to the successor region.
return MutableOperandRange(getOperation(), 0, 0);
}

//===----------------------------------------------------------------------===//
// ConstantOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1303,26 +1327,11 @@ void LoopOp::getSuccessorRegions(mlir::RegionBranchPoint point,
llvm::SmallVector<Region *> LoopOp::getLoopRegions() { return {&getBody()}; }

LogicalResult LoopOp::verify() {
// Cond regions should only terminate with plain 'cir.yield' or
// 'cir.yield continue'.
auto terminateError = [&]() {
return emitOpError() << "cond region must be terminated with "
"'cir.yield' or 'cir.yield continue'";
};
if (getCond().empty())
return emitOpError() << "cond region must not be empty";

auto &blocks = getCond().getBlocks();
for (Block &block : blocks) {
if (block.empty())
continue;
auto &op = block.back();
if (isa<BrCondOp>(op))
continue;
if (!isa<YieldOp>(op))
terminateError();
auto y = cast<YieldOp>(op);
if (!(y.isPlain() || y.isContinue()))
terminateError();
}
if (!llvm::isa<ConditionOp>(getCond().back().getTerminator()))
return emitOpError() << "cond region terminate with 'cir.condition'";

return success();
}
Expand Down
45 changes: 0 additions & 45 deletions clang/lib/CIR/Dialect/Transforms/MergeCleanups.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,50 +54,6 @@ struct RemoveRedudantBranches : public OpRewritePattern<BrOp> {
}
};

/// Merges basic blocks of trivial conditional branches. This is useful when a
/// the condition of conditional branch is a constant and the destinations of
/// the conditional branch both have only one predecessor.
///
/// From:
/// ^bb0:
/// %0 = cir.const(#true) : !cir.bool
/// cir.brcond %0 ^bb1, ^bb2
/// ^bb1: // pred: ^bb0
/// cir.yield continue
/// ^bb2: // pred: ^bb0
/// cir.yield
///
/// To:
/// ^bb0:
/// cir.yield continue
///
struct MergeTrivialConditionalBranches : public OpRewritePattern<BrCondOp> {
using OpRewritePattern<BrCondOp>::OpRewritePattern;

LogicalResult match(BrCondOp op) const final {
return success(isa<ConstantOp>(op.getCond().getDefiningOp()) &&
op.getDestFalse()->hasOneUse() &&
op.getDestTrue()->hasOneUse());
}

/// Replace conditional branch with unconditional branch.
void rewrite(BrCondOp op, PatternRewriter &rewriter) const final {
auto constOp = llvm::cast<ConstantOp>(op.getCond().getDefiningOp());
bool cond = constOp.getValue().cast<cir::BoolAttr>().getValue();
auto *destTrue = op.getDestTrue(), *destFalse = op.getDestFalse();
Block *block = op.getOperation()->getBlock();

rewriter.eraseOp(op);
if (cond) {
rewriter.mergeBlocks(destTrue, block);
rewriter.eraseBlock(destFalse);
} else {
rewriter.mergeBlocks(destFalse, block);
rewriter.eraseBlock(destTrue);
}
}
};

struct RemoveEmptyScope : public OpRewritePattern<ScopeOp> {
using OpRewritePattern<ScopeOp>::OpRewritePattern;

Expand Down Expand Up @@ -146,7 +102,6 @@ void populateMergeCleanupPatterns(RewritePatternSet &patterns) {
// clang-format off
patterns.add<
RemoveRedudantBranches,
MergeTrivialConditionalBranches,
RemoveEmptyScope,
RemoveEmptySwitch
>(patterns.getContext());
Expand Down
43 changes: 12 additions & 31 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -403,25 +403,14 @@ class CIRLoopOpLowering : public mlir::OpConversionPattern<mlir::cir::LoopOp> {
using mlir::OpConversionPattern<mlir::cir::LoopOp>::OpConversionPattern;
using LoopKind = mlir::cir::LoopOpKind;

mlir::LogicalResult
fetchCondRegionYields(mlir::Region &condRegion,
mlir::cir::YieldOp &yieldToBody,
mlir::cir::YieldOp &yieldToCont) const {
for (auto &bb : condRegion) {
if (auto yieldOp = dyn_cast<mlir::cir::YieldOp>(bb.getTerminator())) {
if (!yieldOp.getKind().has_value())
yieldToCont = yieldOp;
else if (yieldOp.getKind() == mlir::cir::YieldOpKind::Continue)
yieldToBody = yieldOp;
else
return mlir::failure();
}
}

// Succeed only if both yields are found.
if (!yieldToBody)
return mlir::failure();
return mlir::success();
inline void
lowerConditionOp(mlir::cir::ConditionOp op, mlir::Block *body,
mlir::Block *exit,
mlir::ConversionPatternRewriter &rewriter) const {
mlir::OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(op);
rewriter.replaceOpWithNewOp<mlir::cir::BrCondOp>(op, op.getCondition(),
body, exit);
}

mlir::LogicalResult
Expand All @@ -435,9 +424,6 @@ class CIRLoopOpLowering : public mlir::OpConversionPattern<mlir::cir::LoopOp> {
// Fetch required info from the condition region.
auto &condRegion = loopOp.getCond();
auto &condFrontBlock = condRegion.front();
mlir::cir::YieldOp yieldToBody, yieldToCont;
if (fetchCondRegionYields(condRegion, yieldToBody, yieldToCont).failed())
return loopOp.emitError("failed to fetch yields in cond region");

// Fetch required info from the body region.
auto &bodyRegion = loopOp.getBody();
Expand Down Expand Up @@ -469,15 +455,10 @@ class CIRLoopOpLowering : public mlir::OpConversionPattern<mlir::cir::LoopOp> {
auto &entry = (kind != LoopKind::DoWhile ? condFrontBlock : bodyFrontBlock);
rewriter.create<mlir::cir::BrOp>(loopOp.getLoc(), &entry);

// Set loop exit point to continue block.
if (yieldToCont) {
rewriter.setInsertionPoint(yieldToCont);
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(yieldToCont, continueBlock);
}

// Branch from condition to body.
rewriter.setInsertionPoint(yieldToBody);
rewriter.replaceOpWithNewOp<mlir::cir::BrOp>(yieldToBody, &bodyFrontBlock);
// Branch from condition region to body or exit.
auto conditionOp =
cast<mlir::cir::ConditionOp>(condFrontBlock.getTerminator());
lowerConditionOp(conditionOp, &bodyFrontBlock, continueBlock, rewriter);

// Branch from body to condition or to step on for-loop cases.
rewriter.setInsertionPoint(bodyYield);
Expand Down
Loading

0 comments on commit 8ce2153

Please sign in to comment.