Skip to content

Commit

Permalink
[CIR] Introduce cir.try_call operation
Browse files Browse the repository at this point in the history
This will be used for any calls happening inside try regions.

More refactoring. For now it's incremental work, still some mileage to cover
before I can introduce a testcase. The current implementation mimics cir.call,
pieces are going to change in following commits.
  • Loading branch information
bcardosolopes committed Jan 13, 2024
1 parent b845800 commit b33de0c
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 30 deletions.
52 changes: 51 additions & 1 deletion clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1909,7 +1909,7 @@ def FuncOp : CIR_Op<"func", [
}

//===----------------------------------------------------------------------===//
// CallOp
// CallOp and TryCallOp
//===----------------------------------------------------------------------===//

class CIR_CallOp<string mnemonic> :
Expand Down Expand Up @@ -1999,6 +1999,56 @@ def CallOp : CIR_CallOp<"call"> {
}]>];
}

def TryCallOp : CIR_CallOp<"try_call"> {
let summary = "try call operation";
let description = [{
Works very similar to `cir.call` but passes down an exception object
in case anything is thrown by the callee. Upon the callee throwing,
`cir.try_call` goes to current `cir.scope`'s `abort` label, otherwise
execution follows to the `continue` label.

To walk the operands for this operation, use `getNumArgOperands()`,
`getArgOperand()`, `getArgOperands()`, `arg_operand_begin()` and
`arg_operand_begin()`. Avoid using `getNumOperands()`, `getOperand()`,
`operand_begin()`, etc, direclty - might be misleading given the
exception object address is also part of the raw operation's operands.
``

Example:

```mlir
%r = cir.try_call @division(%1, %2), ^continue_A, ^abort, %0
```
}];

let arguments = (ins OptionalAttr<FlatSymbolRefAttr>:$callee,
Variadic<CIR_AnyType>:$operands,
OptionalAttr<ASTCallExprInterface>:$ast);
let results = (outs Variadic<CIR_AnyType>);

let builders = [
OpBuilder<(ins "FuncOp":$callee, CArg<"ValueRange", "{}">:$operands), [{
$_state.addOperands(operands);
$_state.addAttribute("callee", SymbolRefAttr::get(callee));
if (!callee.getFunctionType().isVoid())
$_state.addTypes(callee.getFunctionType().getReturnType());
}]>,
OpBuilder<(ins "Value":$ind_target,
"FuncType":$fn_type,
CArg<"ValueRange", "{}">:$operands), [{
$_state.addOperands(ValueRange{ind_target});
$_state.addOperands(operands);
if (!fn_type.isVoid())
$_state.addTypes(fn_type.getReturnType());
}]>,
OpBuilder<(ins "SymbolRefAttr":$callee, "mlir::Type":$resType,
CArg<"ValueRange", "{}">:$operands), [{
$_state.addOperands(operands);
$_state.addAttribute("callee", callee);
$_state.addTypes(resType);
}]>];
}

//===----------------------------------------------------------------------===//
// AwaitOp
//===----------------------------------------------------------------------===//
Expand Down
114 changes: 85 additions & 29 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1977,56 +1977,57 @@ unsigned cir::CallOp::getNumArgOperands() {
return this->getOperation()->getNumOperands();
}

LogicalResult
cir::CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
static LogicalResult
verifyCallCommInSymbolUses(Operation *op, SymbolTableCollection &symbolTable) {
// Callee attribute only need on indirect calls.
auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
auto fnAttr = op->getAttrOfType<FlatSymbolRefAttr>("callee");
if (!fnAttr)
return success();

FuncOp fn =
symbolTable.lookupNearestSymbolFrom<mlir::cir::FuncOp>(*this, fnAttr);
symbolTable.lookupNearestSymbolFrom<mlir::cir::FuncOp>(op, fnAttr);
if (!fn)
return emitOpError() << "'" << fnAttr.getValue()
<< "' does not reference a valid function";
return op->emitOpError() << "'" << fnAttr.getValue()
<< "' does not reference a valid function";

// Verify that the operand and result types match the callee. Note that
// argument-checking is disabled for functions without a prototype.
auto fnType = fn.getFunctionType();
if (!fn.getNoProto()) {
if (!fnType.isVarArg() && getNumOperands() != fnType.getNumInputs())
return emitOpError("incorrect number of operands for callee");
if (!fnType.isVarArg() && op->getNumOperands() != fnType.getNumInputs())
return op->emitOpError("incorrect number of operands for callee");

if (fnType.isVarArg() && getNumOperands() < fnType.getNumInputs())
return emitOpError("too few operands for callee");
if (fnType.isVarArg() && op->getNumOperands() < fnType.getNumInputs())
return op->emitOpError("too few operands for callee");

for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
if (getOperand(i).getType() != fnType.getInput(i))
return emitOpError("operand type mismatch: expected operand type ")
if (op->getOperand(i).getType() != fnType.getInput(i))
return op->emitOpError("operand type mismatch: expected operand type ")
<< fnType.getInput(i) << ", but provided "
<< getOperand(i).getType() << " for operand number " << i;
<< op->getOperand(i).getType() << " for operand number " << i;
}

// Void function must not return any results.
if (fnType.isVoid() && getNumResults() != 0)
return emitOpError("callee returns void but call has results");
if (fnType.isVoid() && op->getNumResults() != 0)
return op->emitOpError("callee returns void but call has results");

// Non-void function calls must return exactly one result.
if (!fnType.isVoid() && getNumResults() != 1)
return emitOpError("incorrect number of results for callee");
if (!fnType.isVoid() && op->getNumResults() != 1)
return op->emitOpError("incorrect number of results for callee");

// Parent function and return value types must match.
if (!fnType.isVoid() && getResultTypes().front() != fnType.getReturnType()) {
return emitOpError("result type mismatch: expected ")
if (!fnType.isVoid() &&
op->getResultTypes().front() != fnType.getReturnType()) {
return op->emitOpError("result type mismatch: expected ")
<< fnType.getReturnType() << ", but provided "
<< getResult(0).getType();
<< op->getResult(0).getType();
}

return success();
}

::mlir::ParseResult CallOp::parse(::mlir::OpAsmParser &parser,
::mlir::OperationState &result) {
static ::mlir::ParseResult parseCallCommon(::mlir::OpAsmParser &parser,
::mlir::OperationState &result) {
mlir::FlatSymbolRefAttr calleeAttr;
llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4> ops;
llvm::SMLoc opsLoc;
Expand Down Expand Up @@ -2068,12 +2069,13 @@ ::mlir::ParseResult CallOp::parse(::mlir::OpAsmParser &parser,
return ::mlir::success();
}

void CallOp::print(::mlir::OpAsmPrinter &state) {
void printCallCommon(Operation *op, mlir::FlatSymbolRefAttr flatSym,
::mlir::OpAsmPrinter &state) {
state << ' ';
auto ops = getOperands();
auto ops = op->getOperands();

if (getCallee()) { // Direct calls
state.printAttributeWithoutType(getCalleeAttr());
if (flatSym) { // Direct calls
state.printAttributeWithoutType(flatSym);
} else { // Indirect calls
state << ops.front();
ops = ops.drop_front();
Expand All @@ -2084,11 +2086,65 @@ void CallOp::print(::mlir::OpAsmPrinter &state) {
llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs;
elidedAttrs.push_back("callee");
elidedAttrs.push_back("ast");
state.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
state.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
state << ' ' << ":";
state << ' ';
state.printFunctionalType(getOperands().getTypes(),
getOperation()->getResultTypes());
state.printFunctionalType(op->getOperands().getTypes(), op->getResultTypes());
}

LogicalResult
cir::CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
return verifyCallCommInSymbolUses(*this, symbolTable);
}

::mlir::ParseResult CallOp::parse(::mlir::OpAsmParser &parser,
::mlir::OperationState &result) {
return parseCallCommon(parser, result);
}

void CallOp::print(::mlir::OpAsmPrinter &state) {
printCallCommon(*this, getCalleeAttr(), state);
}

//===----------------------------------------------------------------------===//
// TryCallOp
//===----------------------------------------------------------------------===//

mlir::Operation::operand_iterator cir::TryCallOp::arg_operand_begin() {
auto arg_begin = operand_begin();
if (!getCallee())
arg_begin++;
return arg_begin;
}
mlir::Operation::operand_iterator cir::TryCallOp::arg_operand_end() {
return operand_end();
}

/// Return the operand at index 'i', accounts for indirect call.
Value cir::TryCallOp::getArgOperand(unsigned i) {
if (!getCallee())
i++;
return getOperand(i);
}
/// Return the number of operands, , accounts for indirect call.
unsigned cir::TryCallOp::getNumArgOperands() {
if (!getCallee())
return this->getOperation()->getNumOperands() - 1;
return this->getOperation()->getNumOperands();
}

LogicalResult
cir::TryCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
return verifyCallCommInSymbolUses(*this, symbolTable);
}

::mlir::ParseResult TryCallOp::parse(::mlir::OpAsmParser &parser,
::mlir::OperationState &result) {
return parseCallCommon(parser, result);
}

void TryCallOp::print(::mlir::OpAsmPrinter &state) {
printCallCommon(*this, getCalleeAttr(), state);
}

//===----------------------------------------------------------------------===//
Expand Down

0 comments on commit b33de0c

Please sign in to comment.