Skip to content

Commit

Permalink
[CIR] TryCallOp: add blocks, arguments, proper interface impl and tes…
Browse files Browse the repository at this point in the history
…tcase

- Add cir.try_call parsing.
- Add block destinations and hookup exception info type.
- Properly implement interface methods.

Printer is still missing, but coming next.
  • Loading branch information
bcardosolopes committed Jan 17, 2024
1 parent f4be195 commit b015a09
Show file tree
Hide file tree
Showing 4 changed files with 233 additions and 42 deletions.
24 changes: 18 additions & 6 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1923,13 +1923,14 @@ def FuncOp : CIR_Op<"func", [
}

//===----------------------------------------------------------------------===//
// CallOp and TryCallOp
// CallOp
//===----------------------------------------------------------------------===//

class CIR_CallOp<string mnemonic> :
class CIR_CallOp<string mnemonic, list<Trait> extra_traits = []> :
Op<CIR_Dialect, mnemonic,
[DeclareOpInterfaceMethods<CIRCallOpInterface>,
DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
!listconcat(extra_traits,
[DeclareOpInterfaceMethods<CIRCallOpInterface>,
DeclareOpInterfaceMethods<SymbolUserOpInterface>])> {
let extraClassDeclaration = [{
/// Get the argument operands to the called function.
OperandRange getArgOperands() {
Expand Down Expand Up @@ -2013,7 +2014,13 @@ def CallOp : CIR_CallOp<"call"> {
}]>];
}

def TryCallOp : CIR_CallOp<"try_call"> {
//===----------------------------------------------------------------------===//
// TryCallOp
//===----------------------------------------------------------------------===//

def TryCallOp : CIR_CallOp<"try_call",
[AttrSizedOperandSegments, DeclareOpInterfaceMethods<BranchOpInterface>,
Terminator]> {
let summary = "try call operation";
let description = [{
Works very similar to `cir.call` but passes down an exception object
Expand All @@ -2036,8 +2043,13 @@ def TryCallOp : CIR_CallOp<"try_call"> {
}];

let arguments = (ins OptionalAttr<FlatSymbolRefAttr>:$callee,
Variadic<CIR_AnyType>:$operands,
ExceptionInfoPtr:$exceptionInfo,
Variadic<CIR_AnyType>:$destContOps,
Variadic<CIR_AnyType>:$destAbortOps,
Variadic<CIR_AnyType>:$callOps,
OptionalAttr<ASTCallExprInterface>:$ast);
let successors = (successor AnySuccessor:$destContinue,
AnySuccessor:$destAbort);
let results = (outs Variadic<CIR_AnyType>);

let builders = [
Expand Down
61 changes: 38 additions & 23 deletions clang/include/clang/CIR/Dialect/IR/CIRTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,26 @@ def CIR_FuncType : CIR_Type<"Func", "func"> {
}];
}

//===----------------------------------------------------------------------===//
// Exception info type
//
// By introducing an exception info type, exception related operations can be
// more descriptive.
//
// This basically wraps a uint8_t* and a uint32_t
//
//===----------------------------------------------------------------------===//

def CIR_ExceptionInfo : CIR_Type<"ExceptionInfo", "eh.info"> {
let summary = "CIR exception info";
let description = [{
Represents the content necessary for a `cir.call` to pass back an exception
object pointer + some extra selector information. This type is required for
some exception related operations, like `cir.catch`, `cir.eh.selector_slot`
and `cir.eh.slot`.
}];
}

//===----------------------------------------------------------------------===//
// Void type
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -254,37 +274,32 @@ def VoidPtr : Type<
"mlir::cir::VoidType::get($_builder.getContext()))"> {
}

// Pointer to exception info
def ExceptionInfoPtr : Type<
And<[
CPred<"$_self.isa<::mlir::cir::PointerType>()">,
CPred<"$_self.cast<::mlir::cir::PointerType>()"
".getPointee().isa<::mlir::cir::ExceptionInfoType>()">,
]>, "void*">,
BuildableType<
"mlir::cir::PointerType::get($_builder.getContext(),"
"mlir::cir::ExceptionInfo::get($_builder.getContext()))"> {
}

//===----------------------------------------------------------------------===//
// Global type constraints
// StructType (defined in cpp files)
//===----------------------------------------------------------------------===//

def CIR_StructType : Type<CPred<"$_self.isa<::mlir::cir::StructType>()">,
"CIR struct type">;

def CIR_AnyType : AnyTypeOf<[
CIR_IntType, CIR_PointerType, CIR_BoolType, CIR_ArrayType, CIR_VectorType,
CIR_FuncType, CIR_VoidType, CIR_StructType, AnyFloat,
]>;


//===----------------------------------------------------------------------===//
// Exception info type
//
// By introducing an exception info type, exception related operations can be
// more descriptive.
//
// This basically wraps a uint8_t* and a uint32_t
//
// Global type constraints
//===----------------------------------------------------------------------===//

def CIR_ExceptionInfo : CIR_Type<"ExceptionInfo", "eh.info"> {
let summary = "CIR exception info";
let description = [{
Represents the content necessary for a `cir.call` to pass back an exception
object pointer + some extra selector information. This type is required for
some exception related operations, like `cir.catch`, `cir.eh.selector_slot`
and `cir.eh.slot`.
}];
}
def CIR_AnyType : AnyTypeOf<[
CIR_IntType, CIR_PointerType, CIR_BoolType, CIR_ArrayType, CIR_VectorType,
CIR_FuncType, CIR_VoidType, CIR_StructType, CIR_ExceptionInfo, AnyFloat,
]>;

#endif // MLIR_CIR_DIALECT_CIR_TYPES
163 changes: 150 additions & 13 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1949,19 +1949,24 @@ verifyCallCommInSymbolUses(Operation *op, SymbolTableCollection &symbolTable) {
if (!fn)
return op->emitOpError() << "'" << fnAttr.getValue()
<< "' does not reference a valid function";
auto callIf = dyn_cast<mlir::cir::CIRCallOpInterface>(op);
assert(callIf && "expected CIR call interface to be always available");

// Verify that the operand and result types match the callee. Note that
// argument-checking is disabled for functions without a prototype.
auto fnType = fn.getFunctionType();
if (!fn.getNoProto()) {
if (!fnType.isVarArg() && op->getNumOperands() != fnType.getNumInputs())
unsigned numCallOperands = callIf.getNumArgOperands();
unsigned numFnOpOperands = fnType.getNumInputs();

if (!fnType.isVarArg() && numCallOperands != numFnOpOperands)
return op->emitOpError("incorrect number of operands for callee");

if (fnType.isVarArg() && op->getNumOperands() < fnType.getNumInputs())
if (fnType.isVarArg() && numCallOperands < numFnOpOperands)
return op->emitOpError("too few operands for callee");

for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
if (op->getOperand(i).getType() != fnType.getInput(i))
for (unsigned i = 0, e = numFnOpOperands; i != e; ++i)
if (callIf.getArgOperand(i).getType() != fnType.getInput(i))
return op->emitOpError("operand type mismatch: expected operand type ")
<< fnType.getInput(i) << ", but provided "
<< op->getOperand(i).getType() << " for operand number " << i;
Expand All @@ -1986,8 +1991,13 @@ verifyCallCommInSymbolUses(Operation *op, SymbolTableCollection &symbolTable) {
return success();
}

static ::mlir::ParseResult parseCallCommon(::mlir::OpAsmParser &parser,
::mlir::OperationState &result) {
static ::mlir::ParseResult parseCallCommon(
::mlir::OpAsmParser &parser, ::mlir::OperationState &result,
llvm::function_ref<::mlir::ParseResult(::mlir::OpAsmParser &,
::mlir::OperationState &, int32_t)>
customOpHandler = [](::mlir::OpAsmParser &parser,
::mlir::OperationState &result,
int32_t numCallArgs) { return mlir::success(); }) {
mlir::FlatSymbolRefAttr calleeAttr;
llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4> ops;
llvm::SMLoc opsLoc;
Expand Down Expand Up @@ -2024,13 +2034,18 @@ static ::mlir::ParseResult parseCallCommon(::mlir::OpAsmParser &parser,
operandsTypes = opsFnTy.getInputs();
allResultTypes = opsFnTy.getResults();
result.addTypes(allResultTypes);

if (customOpHandler(parser, result, operandsTypes.size()).failed())
return ::mlir::failure();

if (parser.resolveOperands(ops, operandsTypes, opsLoc, result.operands))
return ::mlir::failure();
return ::mlir::success();
}

void printCallCommon(Operation *op, mlir::FlatSymbolRefAttr flatSym,
::mlir::OpAsmPrinter &state) {
void printCallCommon(
Operation *op, mlir::FlatSymbolRefAttr flatSym, ::mlir::OpAsmPrinter &state,
llvm::function_ref<void()> customOpHandler = []() {}) {
state << ' ';
auto ops = op->getOperands();

Expand Down Expand Up @@ -2074,7 +2089,12 @@ mlir::Operation::operand_iterator cir::TryCallOp::arg_operand_begin() {
auto arg_begin = operand_begin();
if (!getCallee())
arg_begin++;
return arg_begin;
// First operand is the exception pointer, skip it.
//
// FIXME(cir): for this and all the other calculations in the other methods:
// we currently have no basic block arguments on cir.try_call, but if it gets
// to that, this needs further adjustment.
return arg_begin++;
}
mlir::Operation::operand_iterator cir::TryCallOp::arg_operand_end() {
return operand_end();
Expand All @@ -2084,13 +2104,17 @@ mlir::Operation::operand_iterator cir::TryCallOp::arg_operand_end() {
Value cir::TryCallOp::getArgOperand(unsigned i) {
if (!getCallee())
i++;
return getOperand(i);
// First operand is the exception pointer, skip it.
return getOperand(i + 1);
}
/// Return the number of operands, , accounts for indirect call.
unsigned cir::TryCallOp::getNumArgOperands() {
unsigned numOperands = this->getOperation()->getNumOperands();
if (!getCallee())
return this->getOperation()->getNumOperands() - 1;
return this->getOperation()->getNumOperands();
numOperands--;
// First operand is the exception pointer, skip it.
numOperands--;
return numOperands;
}

LogicalResult
Expand All @@ -2100,13 +2124,126 @@ cir::TryCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {

::mlir::ParseResult TryCallOp::parse(::mlir::OpAsmParser &parser,
::mlir::OperationState &result) {
return parseCallCommon(parser, result);
return parseCallCommon(
parser, result,
[](::mlir::OpAsmParser &parser, ::mlir::OperationState &result,
int32_t numCallArgs) -> ::mlir::ParseResult {
::mlir::OpAsmParser::UnresolvedOperand exceptionRawOperands[1];
::llvm::ArrayRef<::mlir::OpAsmParser::UnresolvedOperand>
exceptionOperands(exceptionRawOperands);
::llvm::SMLoc exceptionOperandsLoc;
(void)exceptionOperandsLoc;

::mlir::Block *destContinueSuccessor = nullptr;
::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4>
destOperandsContinue;
::llvm::SMLoc destOperandsContinueLoc;
(void)destOperandsContinueLoc;
::llvm::SmallVector<::mlir::Type, 1> destOperandsContinueTypes;
::mlir::Block *destAbortSuccessor = nullptr;
::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4>
destOperandsAbort;
::llvm::SMLoc destOperandsAbortLoc;
(void)destOperandsAbortLoc;
::llvm::SmallVector<::mlir::Type, 1> destOperandsAbortTypes;

// So far we have 4: exception ptr, variadic continue, variadic abort
// and variadic call args.
enum {
Segment_Exception_Idx,
Segment_Continue_Idx,
Segment_Abort_Idx,
Segment_CallArgs_Idx,
};
::llvm::SmallVector<int32_t, 4> operandSegmentSizes = {0, 0, 0, 0};

if (parser.parseComma())
return ::mlir::failure();

// Handle continue destination and potential bb operands.
if (parser.parseSuccessor(destContinueSuccessor))
return ::mlir::failure();
if (::mlir::succeeded(parser.parseOptionalLParen())) {

destOperandsContinueLoc = parser.getCurrentLocation();
if (parser.parseOperandList(destOperandsContinue))
return ::mlir::failure();
if (parser.parseColon())
return ::mlir::failure();

if (parser.parseTypeList(destOperandsContinueTypes))
return ::mlir::failure();
if (parser.parseRParen())
return ::mlir::failure();
}
if (parser.parseComma())
return ::mlir::failure();

// Handle abort destination and potential bb operands.
if (parser.parseSuccessor(destAbortSuccessor))
return ::mlir::failure();
if (::mlir::succeeded(parser.parseOptionalLParen())) {
destOperandsAbortLoc = parser.getCurrentLocation();
if (parser.parseOperandList(destOperandsAbort))
return ::mlir::failure();
if (parser.parseColon())
return ::mlir::failure();

if (parser.parseTypeList(destOperandsAbortTypes))
return ::mlir::failure();
if (parser.parseRParen())
return ::mlir::failure();
}

if (parser.parseComma())
return ::mlir::failure();
exceptionOperandsLoc = parser.getCurrentLocation();
if (parser.parseOperand(exceptionRawOperands[0]))
return ::mlir::failure();

auto exceptionPtrTy = cir::PointerType::get(
parser.getBuilder().getContext(),
parser.getBuilder().getType<::mlir::cir::ExceptionInfoType>());
if (parser.resolveOperands(exceptionOperands, exceptionPtrTy,
exceptionOperandsLoc, result.operands))
return ::mlir::failure();

// Add information to the builders.
result.addSuccessors(destContinueSuccessor);
result.addSuccessors(destAbortSuccessor);

if (parser.resolveOperands(destOperandsContinue,
destOperandsContinueTypes,
destOperandsContinueLoc, result.operands))
return ::mlir::failure();
if (parser.resolveOperands(destOperandsAbort, destOperandsAbortTypes,
destOperandsAbortLoc, result.operands))
return ::mlir::failure();

// Required to always be there.
operandSegmentSizes[Segment_Exception_Idx] = 1;
operandSegmentSizes[Segment_Continue_Idx] =
destOperandsContinueTypes.size();
operandSegmentSizes[Segment_Abort_Idx] = destOperandsAbortTypes.size();
operandSegmentSizes[Segment_CallArgs_Idx] = numCallArgs;
result.addAttribute(
"operandSegmentSizes",
parser.getBuilder().getDenseI32ArrayAttr(operandSegmentSizes));

return ::mlir::success();
});
}

void TryCallOp::print(::mlir::OpAsmPrinter &state) {
printCallCommon(*this, getCalleeAttr(), state);
}

mlir::SuccessorOperands TryCallOp::getSuccessorOperands(unsigned index) {
assert(index < getNumSuccessors() && "invalid successor index");
return SuccessorOperands(index == 0 ? getDestContOpsMutable()
: getDestAbortOpsMutable());
}

//===----------------------------------------------------------------------===//
// UnaryOp
//===----------------------------------------------------------------------===//
Expand Down
27 changes: 27 additions & 0 deletions clang/test/CIR/IR/exceptions.cir
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// RUN: cir-opt %s | FileCheck %s

!s32i = !cir.int<s, 32>

module {
cir.func @div(%x : !s32i, %y : !s32i) -> !s32i {
%3 = cir.const(#cir.int<0> : !s32i) : !s32i
cir.return %3 : !s32i
}

cir.func @foo(%x : !s32i, %y : !s32i) {
cir.scope {
%10 = cir.scope {
%0 = cir.alloca !cir.eh.info, cir.ptr <!cir.eh.info>, ["exception_info"] {alignment = 16 : i64}
%d = cir.try_call @div(%x, %y) : (!s32i, !s32i) -> !s32i, ^continue_A, ^abort, %0
// CHECK: cir.try_call @div(%1, %arg0, %arg1) {operandSegmentSizes = array<i32: 1, 0, 0, 2>} : (!cir.ptr<!cir.eh.info>, !s32i, !s32i) -> !s32i
^continue_A:
cir.br ^abort
^abort:
%1 = cir.load %0 : cir.ptr <!cir.eh.info>, !cir.eh.info
cir.yield %1 : !cir.eh.info
} : !cir.eh.info
cir.yield
}
cir.return
}
}

0 comments on commit b015a09

Please sign in to comment.