Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[flang][FIR] remove fir.complex type and its fir.real element type #111025

Merged
merged 2 commits into from
Oct 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 0 additions & 15 deletions flang/include/flang/Optimizer/CodeGen/TypeConverter.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,25 +89,10 @@ class LLVMTypeConverter : public mlir::LLVMTypeConverter {
// fir.char<k,n> --> llvm.array<n x "ix">
mlir::Type convertCharType(fir::CharacterType charTy) const;

// Use the target specifics to figure out how to map complex to LLVM IR. The
// use of complex values in function signatures is handled before conversion
// to LLVM IR dialect here.
//
// fir.complex<T> | std.complex<T> --> llvm<"{t,t}">
template <typename C>
mlir::Type convertComplexType(C cmplx) const {
auto eleTy = cmplx.getElementType();
return convertType(specifics->complexMemoryType(eleTy));
}

template <typename A> mlir::Type convertPointerLike(A &ty) const {
return mlir::LLVM::LLVMPointerType::get(ty.getContext());
}

// convert a front-end kind value to either a std or LLVM IR dialect type
// fir.real<n> --> llvm.anyfloat where anyfloat is a kind mapping
mlir::Type convertRealType(fir::KindTy kind) const;

// fir.array<c ... :any> --> llvm<"[...[c x any]]">
mlir::Type convertSequenceType(SequenceType seq) const;

Expand Down
4 changes: 2 additions & 2 deletions flang/include/flang/Optimizer/Dialect/FIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2735,8 +2735,8 @@ def fir_ConvertOp : fir_SimpleOneResultOp<"convert", [NoMemoryEffect]> {
}

def FortranTypeAttr : Attr<And<[CPred<"mlir::isa<mlir::TypeAttr>($_self)">,
Or<[CPred<"mlir::isa<fir::CharacterType, fir::ComplexType, "
"fir::IntegerType, fir::LogicalType, fir::RealType, "
Or<[CPred<"mlir::isa<fir::CharacterType, fir::IntegerType,"
"fir::LogicalType, mlir::FloatType, mlir::ComplexType,"
"fir::RecordType>(mlir::cast<mlir::TypeAttr>($_self).getValue())"
>]>]>, "Fortran surface type"> {
let storageType = [{ ::mlir::TypeAttr }];
Expand Down
6 changes: 2 additions & 4 deletions flang/include/flang/Optimizer/Dialect/FIRType.h
Original file line number Diff line number Diff line change
Expand Up @@ -159,10 +159,8 @@ mlir::Type dyn_cast_ptrEleTy(mlir::Type t);
/// `t` is not a memory reference or box type, then returns a null `Type`.
mlir::Type dyn_cast_ptrOrBoxEleTy(mlir::Type t);

/// Is `t` a FIR Real or MLIR Float type?
inline bool isa_real(mlir::Type t) {
return mlir::isa<fir::RealType, mlir::FloatType>(t);
}
/// Is `t` a real type?
inline bool isa_real(mlir::Type t) { return mlir::isa<mlir::FloatType>(t); }

/// Is `t` an integral type?
inline bool isa_integer(mlir::Type t) {
Expand Down
43 changes: 1 addition & 42 deletions flang/include/flang/Optimizer/Dialect/FIRTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -158,26 +158,6 @@ def fir_ClassType : FIR_Type<"Class", "class", [], "BaseBoxType"> {
let assemblyFormat = "`<` $eleTy `>`";
}

def fir_ComplexType : FIR_Type<"Complex", "complex"> {
let summary = "Complex type";

let description = [{
Model of a Fortran COMPLEX intrinsic type, including the KIND type
parameter. COMPLEX is a floating point type with a real and imaginary
member.
}];

let parameters = (ins "KindTy":$fKind);
let hasCustomAssemblyFormat = 1;

let extraClassDeclaration = [{
using KindTy = unsigned;

mlir::Type getElementType() const;
mlir::Type getEleType(const fir::KindMapping &kindMap) const;
}];
}

def fir_FieldType : FIR_Type<"Field", "field"> {
let summary = "A field (in a RecordType) argument's type";

Expand Down Expand Up @@ -313,26 +293,6 @@ def fir_PointerType : FIR_Type<"Pointer", "ptr"> {
}];
}

def fir_RealType : FIR_Type<"Real", "real"> {
let summary = "FIR real type";

let description = [{
Model of a Fortran REAL (and DOUBLE PRECISION) intrinsic type, including the
KIND type parameter.
}];

let parameters = (ins "KindTy":$fKind);
let hasCustomAssemblyFormat = 1;

let extraClassDeclaration = [{
using KindTy = unsigned;
// Get MLIR float type with same semantics.
mlir::Type getFloatType(const fir::KindMapping &kindMap) const;
}];

let genVerifyDecl = 1;
}

def fir_RecordType : FIR_Type<"Record", "type"> {
let summary = "FIR derived type";

Expand Down Expand Up @@ -597,8 +557,7 @@ def AnyIntegerLike : TypeConstraint<Or<[SignlessIntegerLike.predicate,
AnySignedInteger.predicate, fir_IntegerType.predicate]>, "any integer">;
def AnyLogicalLike : TypeConstraint<Or<[BoolLike.predicate,
fir_LogicalType.predicate]>, "any logical">;
def AnyRealLike : TypeConstraint<Or<[FloatLike.predicate,
fir_RealType.predicate]>, "any real">;
def AnyRealLike : TypeConstraint<FloatLike.predicate, "any real">;
def AnyIntegerType : Type<AnyIntegerLike.predicate, "any integer">;

def AnyFirComplexLike : TypeConstraint<CPred<"::fir::isa_complex($_self)">,
Expand Down
2 changes: 1 addition & 1 deletion flang/include/flang/Optimizer/HLFIR/HLFIROpBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def AnyFortranVariable : Type<IsFortranVariablePred, "any HLFIR variable type">;

def AnyFortranValue : TypeConstraint<Or<[AnyLogicalLike.predicate,
AnyIntegerLike.predicate, AnyRealLike.predicate,
fir_ComplexType.predicate, AnyComplex.predicate,
AnyFirComplexLike.predicate,
hlfir_ExprType.predicate]>, "any Fortran value type">;


Expand Down
4 changes: 0 additions & 4 deletions flang/lib/Lower/ConvertExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5710,10 +5710,6 @@ class ArrayExprLowering {
fir::applyPathToType(seqTy.getEleTy(), components.suffixComponents);
if (!eleTy)
fir::emitFatalError(loc, "slicing path is ill-formed");
if (auto realTy = mlir::dyn_cast<fir::RealType>(eleTy))
eleTy = Fortran::lower::convertReal(realTy.getContext(),
realTy.getFKind());

// create the type of the projected array.
arrTy = fir::SequenceType::get(seqTy.getShape(), eleTy);
LLVM_DEBUG(llvm::dbgs()
Expand Down
2 changes: 0 additions & 2 deletions flang/lib/Optimizer/Builder/FIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,6 @@ mlir::Value
fir::FirOpBuilder::createRealConstant(mlir::Location loc, mlir::Type fltTy,
llvm::APFloat::integerPart val) {
auto apf = [&]() -> llvm::APFloat {
if (auto ty = mlir::dyn_cast<fir::RealType>(fltTy))
return llvm::APFloat(kindMap.getFloatSemantics(ty.getFKind()), val);
if (fltTy.isF16())
return llvm::APFloat(llvm::APFloat::IEEEhalf(), val);
if (fltTy.isBF16())
Expand Down
4 changes: 1 addition & 3 deletions flang/lib/Optimizer/CodeGen/CodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -589,9 +589,7 @@ struct CallOpConversion : public fir::FIROpConversion<fir::CallOp> {
} // namespace

static mlir::Type getComplexEleTy(mlir::Type complex) {
if (auto cc = mlir::dyn_cast<mlir::ComplexType>(complex))
return cc.getElementType();
return mlir::cast<fir::ComplexType>(complex).getElementType();
return mlir::cast<mlir::ComplexType>(complex).getElementType();
}

namespace {
Expand Down
15 changes: 10 additions & 5 deletions flang/lib/Optimizer/CodeGen/Target.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,6 @@ llvm::StringRef Attributes::getIntExtensionAttrName() const {
static const llvm::fltSemantics &floatToSemantics(const KindMapping &kindMap,
mlir::Type type) {
assert(isa_real(type));
if (auto ty = mlir::dyn_cast<fir::RealType>(type))
return kindMap.getFloatSemantics(ty.getFKind());
return mlir::cast<mlir::FloatType>(type).getFloatSemantics();
}

Expand Down Expand Up @@ -356,7 +354,7 @@ struct TargetX86_64 : public GenericTarget<TargetX86_64> {
else
current = ArgClass::Integer;
})
.template Case<mlir::FloatType, fir::RealType>([&](mlir::Type floatTy) {
.template Case<mlir::FloatType>([&](mlir::Type floatTy) {
const auto *sem = &floatToSemantics(kindMap, floatTy);
if (sem == &llvm::APFloat::x87DoubleExtended()) {
Lo = ArgClass::X87;
Expand Down Expand Up @@ -540,9 +538,16 @@ struct TargetX86_64 : public GenericTarget<TargetX86_64> {
if (typeList.size() != 1)
return {};
mlir::Type fieldType = typeList[0].second;
if (mlir::isa<mlir::FloatType, mlir::IntegerType, fir::RealType,
fir::CharacterType, fir::LogicalType>(fieldType))
if (mlir::isa<mlir::FloatType, mlir::IntegerType, fir::LogicalType>(
fieldType))
return fieldType;
if (mlir::isa<fir::CharacterType>(fieldType)) {
// Only CHARACTER(1) are expected in BIND(C) contexts, which is the only
// contexts where derived type may be passed in registers.
assert(mlir::cast<fir::CharacterType>(fieldType).getLen() == 1 &&
"fir.type value arg character components must have length 1");
return fieldType;
}
// Complex field that needs to be split, or array.
return {};
}
Expand Down
48 changes: 12 additions & 36 deletions flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -351,11 +351,6 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
if (fnTy.getResults().size() == 1) {
mlir::Type ty = fnTy.getResult(0);
llvm::TypeSwitch<mlir::Type>(ty)
.template Case<fir::ComplexType>([&](fir::ComplexType cmplx) {
wrap = rewriteCallComplexResultType(loc, cmplx, newResTys,
newInTyAndAttrs, newOpers,
savedStackPtr);
})
.template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
wrap = rewriteCallComplexResultType(loc, cmplx, newResTys,
newInTyAndAttrs, newOpers,
Expand Down Expand Up @@ -414,10 +409,6 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
}
}
})
.template Case<fir::ComplexType>([&](fir::ComplexType cmplx) {
rewriteCallComplexInputType(loc, cmplx, oper, newInTyAndAttrs,
newOpers, savedStackPtr);
})
.template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
rewriteCallComplexInputType(loc, cmplx, oper, newInTyAndAttrs,
newOpers, savedStackPtr);
Expand Down Expand Up @@ -538,10 +529,10 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
}
}

// Result type fixup for fir::ComplexType and mlir::ComplexType
template <typename A, typename B>
// Result type fixup for ComplexType.
template <typename Ty>
void lowerComplexSignatureRes(
mlir::Location loc, A cmplx, B &newResTys,
mlir::Location loc, mlir::ComplexType cmplx, Ty &newResTys,
fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs) {
if (noComplexConversion) {
newResTys.push_back(cmplx);
Expand All @@ -557,10 +548,9 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
}
}

// Argument type fixup for fir::ComplexType and mlir::ComplexType
template <typename A>
// Argument type fixup for ComplexType.
void lowerComplexSignatureArg(
mlir::Location loc, A cmplx,
mlir::Location loc, mlir::ComplexType cmplx,
fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs) {
if (noComplexConversion) {
newInTyAndAttrs.push_back(fir::CodeGenSpecifics::getTypeAndAttr(cmplx));
Expand Down Expand Up @@ -602,9 +592,6 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
auto loc = addrOp.getLoc();
for (mlir::Type ty : addrTy.getResults()) {
llvm::TypeSwitch<mlir::Type>(ty)
.Case<fir::ComplexType>([&](fir::ComplexType ty) {
lowerComplexSignatureRes(loc, ty, newResTys, newInTyAndAttrs);
})
.Case<mlir::ComplexType>([&](mlir::ComplexType ty) {
lowerComplexSignatureRes(loc, ty, newResTys, newInTyAndAttrs);
})
Expand All @@ -628,9 +615,6 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
}
}
})
.Case<fir::ComplexType>([&](fir::ComplexType ty) {
lowerComplexSignatureArg(loc, ty, newInTyAndAttrs);
})
.Case<mlir::ComplexType>([&](mlir::ComplexType ty) {
lowerComplexSignatureArg(loc, ty, newInTyAndAttrs);
})
Expand Down Expand Up @@ -766,12 +750,6 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
// Convert return value(s)
for (auto ty : funcTy.getResults())
llvm::TypeSwitch<mlir::Type>(ty)
.Case<fir::ComplexType>([&](fir::ComplexType cmplx) {
if (noComplexConversion)
newResTys.push_back(cmplx);
else
doComplexReturn(func, cmplx, newResTys, newInTyAndAttrs, fixups);
})
.Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
if (noComplexConversion)
newResTys.push_back(cmplx);
Expand Down Expand Up @@ -835,9 +813,6 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
}
}
})
.Case<fir::ComplexType>([&](fir::ComplexType cmplx) {
doComplexArg(func, cmplx, newInTyAndAttrs, fixups);
})
.Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) {
doComplexArg(func, cmplx, newInTyAndAttrs, fixups);
})
Expand Down Expand Up @@ -1090,10 +1065,11 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
/// Convert a complex return value. This can involve converting the return
/// value to a "hidden" first argument or packing the complex into a wide
/// GPR.
template <typename A, typename B, typename C>
void doComplexReturn(mlir::func::FuncOp func, A cmplx, B &newResTys,
template <typename Ty, typename FIXUPS>
void doComplexReturn(mlir::func::FuncOp func, mlir::ComplexType cmplx,
Ty &newResTys,
fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
C &fixups) {
FIXUPS &fixups) {
if (noComplexConversion) {
newResTys.push_back(cmplx);
return;
Expand Down Expand Up @@ -1194,10 +1170,10 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
/// Convert a complex argument value. This can involve storing the value to
/// a temporary memory location or factoring the value into two distinct
/// arguments.
template <typename A, typename B>
void doComplexArg(mlir::func::FuncOp func, A cmplx,
template <typename FIXUPS>
void doComplexArg(mlir::func::FuncOp func, mlir::ComplexType cmplx,
fir::CodeGenSpecifics::Marshalling &newInTyAndAttrs,
B &fixups) {
FIXUPS &fixups) {
if (noComplexConversion) {
newInTyAndAttrs.push_back(fir::CodeGenSpecifics::getTypeAndAttr(cmplx));
return;
Expand Down
11 changes: 0 additions & 11 deletions flang/lib/Optimizer/CodeGen/TypeConverter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,6 @@ LLVMTypeConverter::LLVMTypeConverter(mlir::ModuleOp module, bool applyTBAA,
[&](fir::ClassType classTy) { return convertBoxType(classTy); });
addConversion(
[&](fir::CharacterType charTy) { return convertCharType(charTy); });
addConversion(
[&](fir::ComplexType cmplx) { return convertComplexType(cmplx); });
addConversion([&](fir::FieldType field) {
// Convert to i32 because of LLVM GEP indexing restriction.
return mlir::IntegerType::get(field.getContext(), 32);
Expand Down Expand Up @@ -86,8 +84,6 @@ LLVMTypeConverter::LLVMTypeConverter(mlir::ModuleOp module, bool applyTBAA,
[&](fir::RecordType derived, llvm::SmallVectorImpl<mlir::Type> &results) {
return convertRecordType(derived, results);
});
addConversion(
[&](fir::RealType real) { return convertRealType(real.getFKind()); });
addConversion(
[&](fir::ReferenceType ref) { return convertPointerLike(ref); });
addConversion([&](fir::SequenceType sequence) {
Expand Down Expand Up @@ -277,13 +273,6 @@ mlir::Type LLVMTypeConverter::convertCharType(fir::CharacterType charTy) const {
return mlir::LLVM::LLVMArrayType::get(iTy, charTy.getLen());
}

// convert a front-end kind value to either a std or LLVM IR dialect type
// fir.real<n> --> llvm.anyfloat where anyfloat is a kind mapping
mlir::Type LLVMTypeConverter::convertRealType(fir::KindTy kind) const {
return fir::fromRealTypeID(&getContext(), kindMapping.getRealTypeID(kind),
kind);
}

// fir.array<c ... :any> --> llvm<"[...[c x any]]">
mlir::Type LLVMTypeConverter::convertSequenceType(SequenceType seq) const {
auto baseTy = convertType(seq.getEleTy());
Expand Down
12 changes: 1 addition & 11 deletions flang/lib/Optimizer/Dialect/FIROps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1347,7 +1347,7 @@ bool fir::ConvertOp::isIntegerCompatible(mlir::Type ty) {
}

bool fir::ConvertOp::isFloatCompatible(mlir::Type ty) {
return mlir::isa<mlir::FloatType, fir::RealType>(ty);
return mlir::isa<mlir::FloatType>(ty);
}

bool fir::ConvertOp::isPointerCompatible(mlir::Type ty) {
Expand Down Expand Up @@ -1533,8 +1533,6 @@ llvm::LogicalResult fir::CoordinateOp::verify() {
} else if (auto t = mlir::dyn_cast<fir::RecordType>(eleTy)) {
// FIXME: This is the same as the tuple case.
return mlir::success();
} else if (auto t = mlir::dyn_cast<fir::ComplexType>(eleTy)) {
eleTy = t.getElementType();
} else if (auto t = mlir::dyn_cast<mlir::ComplexType>(eleTy)) {
eleTy = t.getElementType();
} else if (auto t = mlir::dyn_cast<fir::CharacterType>(eleTy)) {
Expand Down Expand Up @@ -4389,14 +4387,6 @@ mlir::Type fir::applyPathToType(mlir::Type eleTy, mlir::ValueRange path) {
return ty.getType(fir::toInt(off));
return mlir::Type{};
})
.Case<fir::ComplexType>([&](fir::ComplexType ty) {
auto x = *i;
if (auto *op = (*i++).getDefiningOp())
if (fir::isa_integer(x.getType()))
return ty.getEleType(fir::getKindMapping(
op->getParentOfType<mlir::ModuleOp>()));
return mlir::Type{};
})
.Case<mlir::ComplexType>([&](mlir::ComplexType ty) {
if (fir::isa_integer((*i++).getType()))
return ty.getElementType();
Expand Down
Loading
Loading