diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td index 6cbbf606ebad..4df2dbd5aafc 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIROps.td +++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td @@ -1909,7 +1909,7 @@ def FuncOp : CIR_Op<"func", [ } //===----------------------------------------------------------------------===// -// CallOp +// CallOp and TryCallOp //===----------------------------------------------------------------------===// class CIR_CallOp : @@ -1999,6 +1999,56 @@ def CallOp : CIR_CallOp<"call"> { }]>]; } +def TryCallOp : CIR_CallOp<"try_call"> { + let summary = "try call operation"; + let description = [{ + Works very similar to `cir.call` but passes down an exception object + in case anything is thrown by the callee. Upon the callee throwing, + `cir.try_call` goes to current `cir.scope`'s `abort` label, otherwise + execution follows to the `continue` label. + + To walk the operands for this operation, use `getNumArgOperands()`, + `getArgOperand()`, `getArgOperands()`, `arg_operand_begin()` and + `arg_operand_begin()`. Avoid using `getNumOperands()`, `getOperand()`, + `operand_begin()`, etc, direclty - might be misleading given the + exception object address is also part of the raw operation's operands. + `` + + Example: + + ```mlir + %r = cir.try_call @division(%1, %2), ^continue_A, ^abort, %0 + ``` + }]; + + let arguments = (ins OptionalAttr:$callee, + Variadic:$operands, + OptionalAttr:$ast); + let results = (outs Variadic); + + let builders = [ + OpBuilder<(ins "FuncOp":$callee, CArg<"ValueRange", "{}">:$operands), [{ + $_state.addOperands(operands); + $_state.addAttribute("callee", SymbolRefAttr::get(callee)); + if (!callee.getFunctionType().isVoid()) + $_state.addTypes(callee.getFunctionType().getReturnType()); + }]>, + OpBuilder<(ins "Value":$ind_target, + "FuncType":$fn_type, + CArg<"ValueRange", "{}">:$operands), [{ + $_state.addOperands(ValueRange{ind_target}); + $_state.addOperands(operands); + if (!fn_type.isVoid()) + $_state.addTypes(fn_type.getReturnType()); + }]>, + OpBuilder<(ins "SymbolRefAttr":$callee, "mlir::Type":$resType, + CArg<"ValueRange", "{}">:$operands), [{ + $_state.addOperands(operands); + $_state.addAttribute("callee", callee); + $_state.addTypes(resType); + }]>]; +} + //===----------------------------------------------------------------------===// // AwaitOp //===----------------------------------------------------------------------===// diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp index 08837b4e5aa1..7b71b7d61cda 100644 --- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp @@ -1977,56 +1977,57 @@ unsigned cir::CallOp::getNumArgOperands() { return this->getOperation()->getNumOperands(); } -LogicalResult -cir::CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { +static LogicalResult +verifyCallCommInSymbolUses(Operation *op, SymbolTableCollection &symbolTable) { // Callee attribute only need on indirect calls. - auto fnAttr = (*this)->getAttrOfType("callee"); + auto fnAttr = op->getAttrOfType("callee"); if (!fnAttr) return success(); FuncOp fn = - symbolTable.lookupNearestSymbolFrom(*this, fnAttr); + symbolTable.lookupNearestSymbolFrom(op, fnAttr); if (!fn) - return emitOpError() << "'" << fnAttr.getValue() - << "' does not reference a valid function"; + return op->emitOpError() << "'" << fnAttr.getValue() + << "' does not reference a valid function"; // Verify that the operand and result types match the callee. Note that // argument-checking is disabled for functions without a prototype. auto fnType = fn.getFunctionType(); if (!fn.getNoProto()) { - if (!fnType.isVarArg() && getNumOperands() != fnType.getNumInputs()) - return emitOpError("incorrect number of operands for callee"); + if (!fnType.isVarArg() && op->getNumOperands() != fnType.getNumInputs()) + return op->emitOpError("incorrect number of operands for callee"); - if (fnType.isVarArg() && getNumOperands() < fnType.getNumInputs()) - return emitOpError("too few operands for callee"); + if (fnType.isVarArg() && op->getNumOperands() < fnType.getNumInputs()) + return op->emitOpError("too few operands for callee"); for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) - if (getOperand(i).getType() != fnType.getInput(i)) - return emitOpError("operand type mismatch: expected operand type ") + if (op->getOperand(i).getType() != fnType.getInput(i)) + return op->emitOpError("operand type mismatch: expected operand type ") << fnType.getInput(i) << ", but provided " - << getOperand(i).getType() << " for operand number " << i; + << op->getOperand(i).getType() << " for operand number " << i; } // Void function must not return any results. - if (fnType.isVoid() && getNumResults() != 0) - return emitOpError("callee returns void but call has results"); + if (fnType.isVoid() && op->getNumResults() != 0) + return op->emitOpError("callee returns void but call has results"); // Non-void function calls must return exactly one result. - if (!fnType.isVoid() && getNumResults() != 1) - return emitOpError("incorrect number of results for callee"); + if (!fnType.isVoid() && op->getNumResults() != 1) + return op->emitOpError("incorrect number of results for callee"); // Parent function and return value types must match. - if (!fnType.isVoid() && getResultTypes().front() != fnType.getReturnType()) { - return emitOpError("result type mismatch: expected ") + if (!fnType.isVoid() && + op->getResultTypes().front() != fnType.getReturnType()) { + return op->emitOpError("result type mismatch: expected ") << fnType.getReturnType() << ", but provided " - << getResult(0).getType(); + << op->getResult(0).getType(); } return success(); } -::mlir::ParseResult CallOp::parse(::mlir::OpAsmParser &parser, - ::mlir::OperationState &result) { +static ::mlir::ParseResult parseCallCommon(::mlir::OpAsmParser &parser, + ::mlir::OperationState &result) { mlir::FlatSymbolRefAttr calleeAttr; llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4> ops; llvm::SMLoc opsLoc; @@ -2068,12 +2069,13 @@ ::mlir::ParseResult CallOp::parse(::mlir::OpAsmParser &parser, return ::mlir::success(); } -void CallOp::print(::mlir::OpAsmPrinter &state) { +void printCallCommon(Operation *op, mlir::FlatSymbolRefAttr flatSym, + ::mlir::OpAsmPrinter &state) { state << ' '; - auto ops = getOperands(); + auto ops = op->getOperands(); - if (getCallee()) { // Direct calls - state.printAttributeWithoutType(getCalleeAttr()); + if (flatSym) { // Direct calls + state.printAttributeWithoutType(flatSym); } else { // Indirect calls state << ops.front(); ops = ops.drop_front(); @@ -2084,11 +2086,65 @@ void CallOp::print(::mlir::OpAsmPrinter &state) { llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs; elidedAttrs.push_back("callee"); elidedAttrs.push_back("ast"); - state.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); + state.printOptionalAttrDict(op->getAttrs(), elidedAttrs); state << ' ' << ":"; state << ' '; - state.printFunctionalType(getOperands().getTypes(), - getOperation()->getResultTypes()); + state.printFunctionalType(op->getOperands().getTypes(), op->getResultTypes()); +} + +LogicalResult +cir::CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + return verifyCallCommInSymbolUses(*this, symbolTable); +} + +::mlir::ParseResult CallOp::parse(::mlir::OpAsmParser &parser, + ::mlir::OperationState &result) { + return parseCallCommon(parser, result); +} + +void CallOp::print(::mlir::OpAsmPrinter &state) { + printCallCommon(*this, getCalleeAttr(), state); +} + +//===----------------------------------------------------------------------===// +// TryCallOp +//===----------------------------------------------------------------------===// + +mlir::Operation::operand_iterator cir::TryCallOp::arg_operand_begin() { + auto arg_begin = operand_begin(); + if (!getCallee()) + arg_begin++; + return arg_begin; +} +mlir::Operation::operand_iterator cir::TryCallOp::arg_operand_end() { + return operand_end(); +} + +/// Return the operand at index 'i', accounts for indirect call. +Value cir::TryCallOp::getArgOperand(unsigned i) { + if (!getCallee()) + i++; + return getOperand(i); +} +/// Return the number of operands, , accounts for indirect call. +unsigned cir::TryCallOp::getNumArgOperands() { + if (!getCallee()) + return this->getOperation()->getNumOperands() - 1; + return this->getOperation()->getNumOperands(); +} + +LogicalResult +cir::TryCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + return verifyCallCommInSymbolUses(*this, symbolTable); +} + +::mlir::ParseResult TryCallOp::parse(::mlir::OpAsmParser &parser, + ::mlir::OperationState &result) { + return parseCallCommon(parser, result); +} + +void TryCallOp::print(::mlir::OpAsmPrinter &state) { + printCallCommon(*this, getCalleeAttr(), state); } //===----------------------------------------------------------------------===//