Skip to content

Commit

Permalink
[CIR] Add support for casting pointer-to-data-member values (#1188)
Browse files Browse the repository at this point in the history
This PR adds support for base-to-derived and derived-to-base casts on
pointer-to-data-member values.

Related to #973.
  • Loading branch information
Lancern authored Dec 2, 2024
1 parent 67bbd1e commit eacaabb
Show file tree
Hide file tree
Showing 8 changed files with 323 additions and 20 deletions.
52 changes: 52 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -3279,6 +3279,58 @@ def DerivedClassAddrOp : CIR_Op<"derived_class_addr"> {
let hasVerifier = 0;
}

//===----------------------------------------------------------------------===//
// BaseDataMemberOp & DerivedDataMemberOp
//===----------------------------------------------------------------------===//

def BaseDataMemberOp : CIR_Op<"base_data_member", [Pure]> {
let summary =
"Cast a derived class data member pointer to a base class data member "
"pointer";
let description = [{
The `cir.base_data_member` operation casts a data member pointer of type
`T Derived::*` to a data member pointer of type `T Base::*`, where `Base`
is an accessible non-ambiguous non-virtual base class of `Derived`.

The `offset` parameter gives the offset in bytes of the `Base` base class
subobject within a `Derived` object.
}];

let arguments = (ins CIR_DataMemberType:$src, IndexAttr:$offset);
let results = (outs CIR_DataMemberType:$result);

let assemblyFormat = [{
`(` $src `:` qualified(type($src)) `)`
`[` $offset `]` `->` qualified(type($result)) attr-dict
}];

let hasVerifier = 1;
}

def DerivedDataMemberOp : CIR_Op<"derived_data_member", [Pure]> {
let summary =
"Cast a base class data member pointer to a derived class data member "
"pointer";
let description = [{
The `cir.derived_data_member` operation casts a data member pointer of type
`T Base::*` to a data member pointer of type `T Derived::*`, where `Base`
is an accessible non-ambiguous non-virtual base class of `Derived`.

The `offset` parameter gives the offset in bytes of the `Base` base class
subobject within a `Derived` object.
}];

let arguments = (ins CIR_DataMemberType:$src, IndexAttr:$offset);
let results = (outs CIR_DataMemberType:$result);

let assemblyFormat = [{
`(` $src `:` qualified(type($src)) `)`
`[` $offset `]` `->` qualified(type($result)) attr-dict
}];

let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// FuncOp
//===----------------------------------------------------------------------===//
Expand Down
27 changes: 24 additions & 3 deletions clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1744,9 +1744,30 @@ mlir::Value ScalarExprEmitter::VisitCastExpr(CastExpr *CE) {
case CK_ReinterpretMemberPointer:
llvm_unreachable("NYI");
case CK_BaseToDerivedMemberPointer:
llvm_unreachable("NYI");
case CK_DerivedToBaseMemberPointer:
llvm_unreachable("NYI");
case CK_DerivedToBaseMemberPointer: {
mlir::Value src = Visit(E);

QualType derivedTy =
Kind == CK_DerivedToBaseMemberPointer ? E->getType() : CE->getType();
const CXXRecordDecl *derivedClass = derivedTy->castAs<MemberPointerType>()
->getClass()
->getAsCXXRecordDecl();
CharUnits offset = CGF.CGM.computeNonVirtualBaseClassOffset(
derivedClass, CE->path_begin(), CE->path_end());

if (E->getType()->isMemberFunctionPointerType())
llvm_unreachable("NYI");

mlir::Location loc = CGF.getLoc(E->getExprLoc());
mlir::Type resultTy = CGF.getCIRType(DestTy);
mlir::IntegerAttr offsetAttr = Builder.getIndexAttr(offset.getQuantity());

if (Kind == CK_BaseToDerivedMemberPointer)
return Builder.create<cir::DerivedDataMemberOp>(loc, resultTy, src,
offsetAttr);
return Builder.create<cir::BaseDataMemberOp>(loc, resultTy, src,
offsetAttr);
}
case CK_ARCProduceObject:
llvm_unreachable("NYI");
case CK_ARCConsumeObject:
Expand Down
26 changes: 26 additions & 0 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -799,6 +799,32 @@ LogicalResult cir::DynamicCastOp::verify() {
return mlir::success();
}

//===----------------------------------------------------------------------===//
// BaseDataMemberOp & DerivedDataMemberOp
//===----------------------------------------------------------------------===//

static LogicalResult verifyDataMemberCast(Operation *op, mlir::Value src,
mlir::Type resultTy) {
// Let the operand type be T1 C1::*, let the result type be T2 C2::*.
// Verify that T1 and T2 are the same type.
auto inputMemberTy =
mlir::cast<cir::DataMemberType>(src.getType()).getMemberTy();
auto resultMemberTy = mlir::cast<cir::DataMemberType>(resultTy).getMemberTy();
if (inputMemberTy != resultMemberTy)
return op->emitOpError()
<< "member types of the operand and the result do not match";

return mlir::success();
}

LogicalResult cir::BaseDataMemberOp::verify() {
return verifyDataMemberCast(getOperation(), getSrc(), getType());
}

LogicalResult cir::DerivedDataMemberOp::verify() {
return verifyDataMemberCast(getOperation(), getSrc(), getType());
}

//===----------------------------------------------------------------------===//
// ComplexCreateOp
//===----------------------------------------------------------------------===//
Expand Down
12 changes: 12 additions & 0 deletions clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRCXXABI.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,18 @@ class CIRCXXABI {
lowerGetRuntimeMember(cir::GetRuntimeMemberOp op, mlir::Type loweredResultTy,
mlir::Value loweredAddr, mlir::Value loweredMember,
mlir::OpBuilder &builder) const = 0;

/// Lower the given cir.base_data_member op to a sequence of more "primitive"
/// CIR operations that act on the ABI types.
virtual mlir::Value lowerBaseDataMember(cir::BaseDataMemberOp op,
mlir::Value loweredSrc,
mlir::OpBuilder &builder) const = 0;

/// Lower the given cir.derived_data_member op to a sequence of more
/// "primitive" CIR operations that act on the ABI types.
virtual mlir::Value
lowerDerivedDataMember(cir::DerivedDataMemberOp op, mlir::Value loweredSrc,
mlir::OpBuilder &builder) const = 0;
};

/// Creates an Itanium-family ABI.
Expand Down
46 changes: 46 additions & 0 deletions clang/lib/CIR/Dialect/Transforms/TargetLowering/ItaniumCXXABI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,14 @@ class ItaniumCXXABI : public CIRCXXABI {
lowerGetRuntimeMember(cir::GetRuntimeMemberOp op, mlir::Type loweredResultTy,
mlir::Value loweredAddr, mlir::Value loweredMember,
mlir::OpBuilder &builder) const override;

mlir::Value lowerBaseDataMember(cir::BaseDataMemberOp op,
mlir::Value loweredSrc,
mlir::OpBuilder &builder) const override;

mlir::Value lowerDerivedDataMember(cir::DerivedDataMemberOp op,
mlir::Value loweredSrc,
mlir::OpBuilder &builder) const override;
};

} // namespace
Expand Down Expand Up @@ -129,6 +137,44 @@ mlir::Operation *ItaniumCXXABI::lowerGetRuntimeMember(
memberBytesPtr);
}

static mlir::Value lowerDataMemberCast(mlir::Operation *op,
mlir::Value loweredSrc,
std::int64_t offset,
bool isDerivedToBase,
mlir::OpBuilder &builder) {
if (offset == 0)
return loweredSrc;

auto nullValue = builder.create<cir::ConstantOp>(
op->getLoc(), mlir::IntegerAttr::get(loweredSrc.getType(), -1));
auto isNull = builder.create<cir::CmpOp>(op->getLoc(), cir::CmpOpKind::eq,
loweredSrc, nullValue);

auto offsetValue = builder.create<cir::ConstantOp>(
op->getLoc(), mlir::IntegerAttr::get(loweredSrc.getType(), offset));
auto binOpKind = isDerivedToBase ? cir::BinOpKind::Sub : cir::BinOpKind::Add;
auto adjustedPtr = builder.create<cir::BinOp>(
op->getLoc(), loweredSrc.getType(), binOpKind, loweredSrc, offsetValue);

return builder.create<cir::SelectOp>(op->getLoc(), loweredSrc.getType(),
isNull, nullValue, adjustedPtr);
}

mlir::Value ItaniumCXXABI::lowerBaseDataMember(cir::BaseDataMemberOp op,
mlir::Value loweredSrc,
mlir::OpBuilder &builder) const {
return lowerDataMemberCast(op, loweredSrc, op.getOffset().getSExtValue(),
/*isDerivedToBase=*/true, builder);
}

mlir::Value
ItaniumCXXABI::lowerDerivedDataMember(cir::DerivedDataMemberOp op,
mlir::Value loweredSrc,
mlir::OpBuilder &builder) const {
return lowerDataMemberCast(op, loweredSrc, op.getOffset().getSExtValue(),
/*isDerivedToBase=*/false, builder);
}

CIRCXXABI *CreateItaniumCXXABI(LowerModule &LM) {
switch (LM.getCXXABIKind()) {
// Note that AArch64 uses the generic ItaniumCXXABI class since it doesn't
Expand Down
74 changes: 57 additions & 17 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -914,6 +914,24 @@ mlir::LogicalResult CIRToLLVMDerivedClassAddrOpLowering::matchAndRewrite(
return mlir::success();
}

mlir::LogicalResult CIRToLLVMBaseDataMemberOpLowering::matchAndRewrite(
cir::BaseDataMemberOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
mlir::Value loweredResult =
lowerMod->getCXXABI().lowerBaseDataMember(op, adaptor.getSrc(), rewriter);
rewriter.replaceOp(op, loweredResult);
return mlir::success();
}

mlir::LogicalResult CIRToLLVMDerivedDataMemberOpLowering::matchAndRewrite(
cir::DerivedDataMemberOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
mlir::Value loweredResult = lowerMod->getCXXABI().lowerDerivedDataMember(
op, adaptor.getSrc(), rewriter);
rewriter.replaceOp(op, loweredResult);
return mlir::success();
}

static mlir::Value
getValueForVTableSymbol(mlir::Operation *op,
mlir::ConversionPatternRewriter &rewriter,
Expand Down Expand Up @@ -1518,7 +1536,13 @@ mlir::LogicalResult CIRToLLVMConstantOpLowering::matchAndRewrite(
mlir::ConversionPatternRewriter &rewriter) const {
mlir::Attribute attr = op.getValue();

if (mlir::isa<cir::BoolType>(op.getType())) {
if (mlir::isa<mlir::IntegerType>(op.getType())) {
// Verified cir.const operations cannot actually be of these types, but the
// lowering pass may generate temporary cir.const operations with these
// types. This is OK since MLIR allows unverified operations to be alive
// during a pass as long as they don't live past the end of the pass.
attr = op.getValue();
} else if (mlir::isa<cir::BoolType>(op.getType())) {
int value = (op.getValue() ==
cir::BoolAttr::get(getContext(),
cir::BoolType::get(getContext()), true));
Expand Down Expand Up @@ -2412,11 +2436,12 @@ CIRToLLVMBinOpLowering::getIntOverflowFlag(cir::BinOp op) const {
mlir::LogicalResult CIRToLLVMBinOpLowering::matchAndRewrite(
cir::BinOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
assert((op.getLhs().getType() == op.getRhs().getType()) &&
assert((adaptor.getLhs().getType() == adaptor.getRhs().getType()) &&
"inconsistent operands' types not supported yet");

mlir::Type type = op.getRhs().getType();
assert((mlir::isa<cir::IntType, cir::CIRFPTypeInterface, cir::VectorType>(
type)) &&
assert((mlir::isa<cir::IntType, cir::CIRFPTypeInterface, cir::VectorType,
mlir::IntegerType>(type)) &&
"operand type not supported yet");

auto llvmTy = getTypeConverter()->convertType(op.getType());
Expand All @@ -2427,38 +2452,44 @@ mlir::LogicalResult CIRToLLVMBinOpLowering::matchAndRewrite(

switch (op.getKind()) {
case cir::BinOpKind::Add:
if (mlir::isa<cir::IntType>(type))
if (mlir::isa<cir::IntType, mlir::IntegerType>(type))
rewriter.replaceOpWithNewOp<mlir::LLVM::AddOp>(op, llvmTy, lhs, rhs,
getIntOverflowFlag(op));
else
rewriter.replaceOpWithNewOp<mlir::LLVM::FAddOp>(op, llvmTy, lhs, rhs);
break;
case cir::BinOpKind::Sub:
if (mlir::isa<cir::IntType>(type))
if (mlir::isa<cir::IntType, mlir::IntegerType>(type))
rewriter.replaceOpWithNewOp<mlir::LLVM::SubOp>(op, llvmTy, lhs, rhs,
getIntOverflowFlag(op));
else
rewriter.replaceOpWithNewOp<mlir::LLVM::FSubOp>(op, llvmTy, lhs, rhs);
break;
case cir::BinOpKind::Mul:
if (mlir::isa<cir::IntType>(type))
if (mlir::isa<cir::IntType, mlir::IntegerType>(type))
rewriter.replaceOpWithNewOp<mlir::LLVM::MulOp>(op, llvmTy, lhs, rhs,
getIntOverflowFlag(op));
else
rewriter.replaceOpWithNewOp<mlir::LLVM::FMulOp>(op, llvmTy, lhs, rhs);
break;
case cir::BinOpKind::Div:
if (auto ty = mlir::dyn_cast<cir::IntType>(type)) {
if (ty.isUnsigned())
if (mlir::isa<cir::IntType, mlir::IntegerType>(type)) {
auto isUnsigned = mlir::isa<cir::IntType>(type)
? mlir::cast<cir::IntType>(type).isUnsigned()
: mlir::cast<mlir::IntegerType>(type).isUnsigned();
if (isUnsigned)
rewriter.replaceOpWithNewOp<mlir::LLVM::UDivOp>(op, llvmTy, lhs, rhs);
else
rewriter.replaceOpWithNewOp<mlir::LLVM::SDivOp>(op, llvmTy, lhs, rhs);
} else
rewriter.replaceOpWithNewOp<mlir::LLVM::FDivOp>(op, llvmTy, lhs, rhs);
break;
case cir::BinOpKind::Rem:
if (auto ty = mlir::dyn_cast<cir::IntType>(type)) {
if (ty.isUnsigned())
if (mlir::isa<cir::IntType, mlir::IntegerType>(type)) {
auto isUnsigned = mlir::isa<cir::IntType>(type)
? mlir::cast<cir::IntType>(type).isUnsigned()
: mlir::cast<mlir::IntegerType>(type).isUnsigned();
if (isUnsigned)
rewriter.replaceOpWithNewOp<mlir::LLVM::URemOp>(op, llvmTy, lhs, rhs);
else
rewriter.replaceOpWithNewOp<mlir::LLVM::SRemOp>(op, llvmTy, lhs, rhs);
Expand Down Expand Up @@ -2642,9 +2673,12 @@ mlir::LogicalResult CIRToLLVMCmpOpLowering::matchAndRewrite(
mlir::Value llResult;

// Lower to LLVM comparison op.
if (auto intTy = mlir::dyn_cast<cir::IntType>(type)) {
auto kind =
convertCmpKindToICmpPredicate(cmpOp.getKind(), intTy.isSigned());
// if (auto intTy = mlir::dyn_cast<cir::IntType>(type)) {
if (mlir::isa<cir::IntType, mlir::IntegerType>(type)) {
auto isSigned = mlir::isa<cir::IntType>(type)
? mlir::cast<cir::IntType>(type).isSigned()
: mlir::cast<mlir::IntegerType>(type).isSigned();
auto kind = convertCmpKindToICmpPredicate(cmpOp.getKind(), isSigned);
llResult = rewriter.create<mlir::LLVM::ICmpOp>(
cmpOp.getLoc(), kind, adaptor.getLhs(), adaptor.getRhs());
} else if (auto ptrTy = mlir::dyn_cast<cir::PointerType>(type)) {
Expand Down Expand Up @@ -3847,9 +3881,15 @@ void populateCIRToLLVMConversionPatterns(
patterns.add<CIRToLLVMAllocaOpLowering>(converter, dataLayout,
stringGlobalsMap, argStringGlobalsMap,
argsVarMap, patterns.getContext());
patterns.add<CIRToLLVMConstantOpLowering, CIRToLLVMGlobalOpLowering,
CIRToLLVMGetRuntimeMemberOpLowering>(
converter, patterns.getContext(), lowerModule);
patterns.add<
// clang-format off
CIRToLLVMBaseDataMemberOpLowering,
CIRToLLVMConstantOpLowering,
CIRToLLVMDerivedDataMemberOpLowering,
CIRToLLVMGetRuntimeMemberOpLowering,
CIRToLLVMGlobalOpLowering
// clang-format on
>(converter, patterns.getContext(), lowerModule);
patterns.add<
// clang-format off
CIRToLLVMAbsOpLowering,
Expand Down
30 changes: 30 additions & 0 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,36 @@ class CIRToLLVMDerivedClassAddrOpLowering
mlir::ConversionPatternRewriter &) const override;
};

class CIRToLLVMBaseDataMemberOpLowering
: public mlir::OpConversionPattern<cir::BaseDataMemberOp> {
cir::LowerModule *lowerMod;

public:
CIRToLLVMBaseDataMemberOpLowering(const mlir::TypeConverter &typeConverter,
mlir::MLIRContext *context,
cir::LowerModule *lowerModule)
: OpConversionPattern(typeConverter, context), lowerMod(lowerModule) {}

mlir::LogicalResult
matchAndRewrite(cir::BaseDataMemberOp op, OpAdaptor,
mlir::ConversionPatternRewriter &) const override;
};

class CIRToLLVMDerivedDataMemberOpLowering
: public mlir::OpConversionPattern<cir::DerivedDataMemberOp> {
cir::LowerModule *lowerMod;

public:
CIRToLLVMDerivedDataMemberOpLowering(const mlir::TypeConverter &typeConverter,
mlir::MLIRContext *context,
cir::LowerModule *lowerModule)
: OpConversionPattern(typeConverter, context), lowerMod(lowerModule) {}

mlir::LogicalResult
matchAndRewrite(cir::DerivedDataMemberOp op, OpAdaptor,
mlir::ConversionPatternRewriter &) const override;
};

class CIRToLLVMVTTAddrPointOpLowering
: public mlir::OpConversionPattern<cir::VTTAddrPointOp> {
public:
Expand Down
Loading

0 comments on commit eacaabb

Please sign in to comment.