Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[flang][NFC] turn fir.call is_bind_c into enum for procedure flags #105691

Merged
merged 2 commits into from
Aug 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions flang/include/flang/Optimizer/Dialect/FIRAttr.td
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,33 @@ def fir_FortranVariableFlagsAttr : fir_Attr<"FortranVariableFlags"> {
"::fir::FortranVariableFlagsAttr::get($_builder.getContext(), $0)";
}


/// Fortran procedure attributes (F2023 15.6.2.1). BIND attribute (18.3.7)
/// is also tracked in the same enum. Recursive (resp. Impure) attribute
/// is implied by the absence of opposite NonRecursive (resp. Pure) attribute.
def FIRfuncNoAttributes : I32BitEnumAttrCaseNone<"none">;
def FIRfuncElemental : I32BitEnumAttrCaseBit<"elemental", 0>;
def FIRfuncPure : I32BitEnumAttrCaseBit<"pure", 1>;
def FIRfuncNonRecursive : I32BitEnumAttrCaseBit<"non_recursive", 2>;
def FIRfuncSimple : I32BitEnumAttrCaseBit<"simple", 3>;
def FIRfuncBind_c : I32BitEnumAttrCaseBit<"bind_c", 4>;

def fir_FortranProcedureFlagsEnum : I32BitEnumAttr<
"FortranProcedureFlagsEnum",
"Fortran procedure attributes",
[FIRfuncNoAttributes, FIRfuncElemental, FIRfuncPure, FIRfuncNonRecursive,
FIRfuncSimple, FIRfuncBind_c]> {
let separator = ", ";
let cppNamespace = "::fir";
let genSpecializedAttr = 0;
let printBitEnumPrimaryGroups = 1;
}

def fir_FortranProcedureFlagsAttr :
EnumAttr<FIROpsDialect, fir_FortranProcedureFlagsEnum, "proc_attrs"> {
let assemblyFormat = "`<` $value `>`";
}

def fir_BoxFieldAttr : I32EnumAttr<
"BoxFieldAttr", "",
[
Expand Down
4 changes: 2 additions & 2 deletions flang/include/flang/Optimizer/Dialect/FIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2432,9 +2432,9 @@ def fir_CallOp : fir_Op<"call",
let arguments = (ins
OptionalAttr<SymbolRefAttr>:$callee,
Variadic<AnyType>:$args,
OptionalAttr<fir_FortranProcedureFlagsAttr>:$procedure_attrs,
DefaultValuedAttr<Arith_FastMathAttr,
"::mlir::arith::FastMathFlags::none">:$fastmath,
UnitAttr:$is_bind_c
"::mlir::arith::FastMathFlags::none">:$fastmath
);
let results = (outs Variadic<AnyType>);

Expand Down
18 changes: 11 additions & 7 deletions flang/lib/Lower/ConvertCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -372,17 +372,17 @@ std::pair<fir::ExtendedValue, bool> Fortran::lower::genCallOpAndResult(
auto stackSaveSymbol = bldr->getSymbolRefAttr(stackSaveFn.getName());
mlir::Value sp;
fir::CallOp call = bldr->create<fir::CallOp>(
loc, stackSaveFn.getFunctionType().getResults(), stackSaveSymbol,
loc, stackSaveSymbol, stackSaveFn.getFunctionType().getResults(),
mlir::ValueRange{});
if (call.getNumResults() != 0)
sp = call.getResult(0);
stmtCtx.attachCleanup([bldr, loc, sp]() {
auto stackRestoreFn = fir::factory::getLlvmStackRestore(*bldr);
auto stackRestoreSymbol =
bldr->getSymbolRefAttr(stackRestoreFn.getName());
bldr->create<fir::CallOp>(loc,
bldr->create<fir::CallOp>(loc, stackRestoreSymbol,
stackRestoreFn.getFunctionType().getResults(),
stackRestoreSymbol, mlir::ValueRange{sp});
mlir::ValueRange{sp});
});
}
mlir::Value temp =
Expand Down Expand Up @@ -640,11 +640,15 @@ std::pair<fir::ExtendedValue, bool> Fortran::lower::genCallOpAndResult(
if (callNumResults != 0)
callResult = dispatch.getResult(0);
} else {
// Standard procedure call with fir.call.
auto call = builder.create<fir::CallOp>(loc, funcType.getResults(),
funcSymbolAttr, operands);
// TODO: gather other procedure attributes.
fir::FortranProcedureFlagsEnumAttr procAttrs;
if (caller.characterize().IsBindC())
call.setIsBindC(true);
procAttrs = fir::FortranProcedureFlagsEnumAttr::get(
builder.getContext(), fir::FortranProcedureFlagsEnum::bind_c);

// Standard procedure call with fir.call.
auto call = builder.create<fir::CallOp>(
loc, funcType.getResults(), funcSymbolAttr, operands, procAttrs);

callNumResults = call.getNumResults();
if (callNumResults != 0)
Expand Down
4 changes: 2 additions & 2 deletions flang/lib/Lower/ConvertExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6120,7 +6120,7 @@ class ArrayExprLowering {
mlir::SymbolRefAttr funcSymAttr =
builder.getSymbolRefAttr(memcpyFunc.getName());
mlir::FunctionType funcTy = memcpyFunc.getFunctionType();
builder.create<fir::CallOp>(loc, funcTy.getResults(), funcSymAttr, args);
builder.create<fir::CallOp>(loc, funcSymAttr, funcTy.getResults(), args);
}

// Construct code to check for a buffer overrun and realloc the buffer when
Expand All @@ -6146,7 +6146,7 @@ class ArrayExprLowering {
builder.getSymbolRefAttr(reallocFunc.getName());
mlir::FunctionType funcTy = reallocFunc.getFunctionType();
auto newMem = builder.create<fir::CallOp>(
loc, funcTy.getResults(), funcSymAttr,
loc, funcSymAttr, funcTy.getResults(),
llvm::ArrayRef<mlir::Value>{
builder.createConvert(loc, funcTy.getInputs()[0], mem),
builder.createConvert(loc, funcTy.getInputs()[1], byteSz)});
Expand Down
4 changes: 2 additions & 2 deletions flang/lib/Optimizer/Builder/IntrinsicCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -819,8 +819,8 @@ mlir::Value genLibCall(fir::FirOpBuilder &builder, mlir::Location loc,

llvm::SmallVector<mlir::Value, 3> operands{funcPointer};
operands.append(args.begin(), args.end());
libCall = builder.create<fir::CallOp>(loc, libFuncType.getResults(),
nullptr, operands);
libCall = builder.create<fir::CallOp>(loc, mlir::SymbolRefAttr{},
libFuncType.getResults(), operands);
}

LLVM_DEBUG(libCall.dump(); llvm::dbgs() << "\n");
Expand Down
3 changes: 2 additions & 1 deletion flang/lib/Optimizer/Dialect/FIRAttr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,8 @@ void fir::printFirAttribute(FIROpsDialect *dialect, mlir::Attribute attr,
//===----------------------------------------------------------------------===//

void FIROpsDialect::registerAttributes() {
addAttributes<ClosedIntervalAttr, ExactTypeAttr, FortranVariableFlagsAttr,
addAttributes<ClosedIntervalAttr, ExactTypeAttr,
FortranProcedureFlagsEnumAttr, FortranVariableFlagsAttr,
LowerBoundAttr, PointIntervalAttr, RealAttr, ReduceAttr,
SubclassAttr, UpperBoundAttr, LocationKindAttr,
LocationKindArrayAttr>();
Expand Down
23 changes: 20 additions & 3 deletions flang/lib/Optimizer/Dialect/FIROps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1103,6 +1103,14 @@ void fir::CallOp::print(mlir::OpAsmPrinter &p) {
p << getOperand(0);
p << '(' << (*this)->getOperands().drop_front(isDirect ? 0 : 1) << ')';

// Print `proc_attrs<...>`, if present.
fir::FortranProcedureFlagsEnumAttr procAttrs = getProcedureAttrsAttr();
if (procAttrs &&
procAttrs.getValue() != fir::FortranProcedureFlagsEnum::none) {
p << ' ' << fir::FortranProcedureFlagsEnumAttr::getMnemonic();
p.printStrippedAttrOrType(procAttrs);
}

// Print 'fastmath<...>' (if it has non-default value) before
// any other attributes.
mlir::arith::FastMathFlagsAttr fmfAttr = getFastmathAttr();
Expand All @@ -1111,9 +1119,9 @@ void fir::CallOp::print(mlir::OpAsmPrinter &p) {
p.printStrippedAttrOrType(fmfAttr);
}

p.printOptionalAttrDict(
(*this)->getAttrs(),
{fir::CallOp::getCalleeAttrNameStr(), getFastmathAttrName()});
p.printOptionalAttrDict((*this)->getAttrs(),
{fir::CallOp::getCalleeAttrNameStr(),
getFastmathAttrName(), getProcedureAttrsAttrName()});
auto resultTypes{getResultTypes()};
llvm::SmallVector<mlir::Type> argTypes(
llvm::drop_begin(getOperandTypes(), isDirect ? 0 : 1));
Expand All @@ -1138,6 +1146,15 @@ mlir::ParseResult fir::CallOp::parse(mlir::OpAsmParser &parser,
if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::Paren))
return mlir::failure();

// Parse `proc_attrs<...>`, if present.
fir::FortranProcedureFlagsEnumAttr procAttr;
if (mlir::succeeded(parser.parseOptionalKeyword(
fir::FortranProcedureFlagsEnumAttr::getMnemonic())))
if (parser.parseCustomAttributeWithFallback(
procAttr, mlir::Type{}, getProcedureAttrsAttrName(result.name),
attrs))
return mlir::failure();

// Parse 'fastmath<...>', if present.
mlir::arith::FastMathFlagsAttr fmfAttr;
llvm::StringRef fmfAttrName = getFastmathAttrName(result.name);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,10 @@ class CallOpRewriter : public mlir::OpRewritePattern<fir::CallOp> {
newResultTypes.append(callOp.getResultTypes().begin(),
callOp.getResultTypes().end());
fir::CallOp newOp = builder.create<fir::CallOp>(
loc, newResultTypes,
loc,
callOp.getCallee().has_value() ? callOp.getCallee().value()
: mlir::SymbolRefAttr{},
newOperands);
newResultTypes, newOperands);
// Copy all the attributes from the old to new op.
newOp->setAttrs(callOp->getAttrs());
rewriter.replaceOp(callOp, newOp);
Expand Down
5 changes: 4 additions & 1 deletion flang/lib/Optimizer/Transforms/PolymorphicOpConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,10 @@ struct DispatchOpConv : public OpConversionPattern<fir::DispatchOp> {
// Make the call.
llvm::SmallVector<mlir::Value> args{funcPtr};
args.append(dispatch.getArgs().begin(), dispatch.getArgs().end());
rewriter.replaceOpWithNewOp<fir::CallOp>(dispatch, resTypes, nullptr, args);
// FIXME: add procedure_attrs to fir.dispatch and propagate to fir.call.
rewriter.replaceOpWithNewOp<fir::CallOp>(
dispatch, resTypes, nullptr, args,
/*procedure_attrs=*/fir::FortranProcedureFlagsEnumAttr{});
return mlir::success();
}

Expand Down
8 changes: 4 additions & 4 deletions flang/lib/Optimizer/Transforms/StackArrays.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -738,9 +738,9 @@ void AllocMemConversion::insertStackSaveRestore(
builder.setInsertionPoint(oldAlloc);
mlir::Value sp =
builder
.create<fir::CallOp>(oldAlloc.getLoc(),
.create<fir::CallOp>(oldAlloc.getLoc(), stackSaveSym,
stackSaveFn.getFunctionType().getResults(),
stackSaveSym, mlir::ValueRange{})
mlir::ValueRange{})
.getResult(0);

mlir::func::FuncOp stackRestoreFn =
Expand All @@ -750,9 +750,9 @@ void AllocMemConversion::insertStackSaveRestore(

auto createStackRestoreCall = [&](mlir::Operation *user) {
builder.setInsertionPoint(user);
builder.create<fir::CallOp>(user->getLoc(),
builder.create<fir::CallOp>(user->getLoc(), stackRestoreSym,
stackRestoreFn.getFunctionType().getResults(),
stackRestoreSym, mlir::ValueRange{sp});
mlir::ValueRange{sp});
};

for (mlir::Operation *user : oldAlloc->getUsers()) {
Expand Down
4 changes: 2 additions & 2 deletions flang/test/HLFIR/c_ptr_byvalue.f90
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
! CHECK: %[[VAL_113:.*]] = fir.load %[[VAL_112]] : !fir.ref<i64>
! CHECK: %[[VAL_114:.*]] = fir.convert %[[VAL_113]] : (i64) -> !fir.ref<i64>
! CHECK: hlfir.end_associate %[[VAL_110]]#1, %[[VAL_110]]#2 : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, i1
! CHECK: fir.call @get_expected_f(%[[VAL_114]]) fastmath<contract> {is_bind_c} : (!fir.ref<i64>) -> ()
! CHECK: fir.call @get_expected_f(%[[VAL_114]]) proc_attrs<bind_c> fastmath<contract> : (!fir.ref<i64>) -> ()
subroutine test1
use iso_c_binding
interface
Expand All @@ -28,7 +28,7 @@ end subroutine get_expected_f
! CHECK: %[[VAL_99:.*]] = fir.coordinate_of %[[VAL_97]]#0, %[[VAL_98]] : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.field) -> !fir.ref<i64>
! CHECK: %[[VAL_100:.*]] = fir.load %[[VAL_99]] : !fir.ref<i64>
! CHECK: %[[VAL_101:.*]] = fir.convert %[[VAL_100]] : (i64) -> !fir.ref<i64>
! CHECK: fir.call @get_expected_f(%[[VAL_101]]) fastmath<contract> {is_bind_c} : (!fir.ref<i64>) -> ()
! CHECK: fir.call @get_expected_f(%[[VAL_101]]) proc_attrs<bind_c> fastmath<contract> : (!fir.ref<i64>) -> ()
subroutine test2(cptr)
use iso_c_binding
interface
Expand Down
8 changes: 4 additions & 4 deletions flang/test/Lower/CUDA/cuda-device-proc.cuf
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@ end

! CHECK-LABEL: func.func @_QPdevsub() attributes {cuf.proc_attr = #cuf.cuda_proc<global>}
! CHECK: fir.call @__syncthreads()
! CHECK: fir.call @__syncwarp(%{{.*}}) fastmath<contract> {is_bind_c} : (!fir.ref<i32>) -> ()
! CHECK: fir.call @__syncwarp(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (!fir.ref<i32>) -> ()
! CHECK: fir.call @__threadfence()
! CHECK: fir.call @__threadfence_block()
! CHECK: fir.call @__threadfence_system()
! CHECK: %{{.*}} = fir.call @__syncthreads_and(%{{.*}}) fastmath<contract> {is_bind_c} : (!fir.ref<i32>) -> i32
! CHECK: %{{.*}} = fir.call @__syncthreads_count(%{{.*}}) fastmath<contract> {is_bind_c} : (!fir.ref<i32>) -> i32
! CHECK: %{{.*}} = fir.call @__syncthreads_or(%{{.*}}) fastmath<contract> {is_bind_c} : (!fir.ref<i32>) -> i32
! CHECK: %{{.*}} = fir.call @__syncthreads_and(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (!fir.ref<i32>) -> i32
! CHECK: %{{.*}} = fir.call @__syncthreads_count(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (!fir.ref<i32>) -> i32
! CHECK: %{{.*}} = fir.call @__syncthreads_or(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (!fir.ref<i32>) -> i32

! CHECK: func.func private @__syncthreads() attributes {cuf.proc_attr = #cuf.cuda_proc<device>, fir.bindc_name = "__syncthreads"}
! CHECK: func.func private @__syncwarp(!fir.ref<i32> {cuf.data_attr = #cuf.cuda<device>}) attributes {cuf.proc_attr = #cuf.cuda_proc<device>, fir.bindc_name = "__syncwarp"}
Expand Down
2 changes: 1 addition & 1 deletion flang/test/Lower/HLFIR/assumed-rank-calls.f90
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ subroutine bindc_func(x) bind(c)
! CHECK: %[[VAL_1:.*]] = fir.dummy_scope : !fir.dscope
! CHECK: %[[VAL_2:.*]]:2 = hlfir.declare %[[VAL_0]] dummy_scope %[[VAL_1]] {uniq_name = "_QFtest_to_bindcEx"} : (!fir.box<!fir.array<*:f32>>, !fir.dscope) -> (!fir.box<!fir.array<*:f32>>, !fir.box<!fir.array<*:f32>>)
! CHECK: %[[VAL_3:.*]] = fir.rebox_assumed_rank %[[VAL_2]]#0 lbs zeroes : (!fir.box<!fir.array<*:f32>>) -> !fir.box<!fir.array<*:f32>>
! CHECK: fir.call @bindc_func(%[[VAL_3]]) fastmath<contract> {is_bind_c} : (!fir.box<!fir.array<*:f32>>) -> ()
! CHECK: fir.call @bindc_func(%[[VAL_3]]) proc_attrs<bind_c> fastmath<contract> : (!fir.box<!fir.array<*:f32>>) -> ()
! CHECK: return
! CHECK: }

Expand Down
4 changes: 2 additions & 2 deletions flang/test/Lower/HLFIR/assumed-rank-iface.f90
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ subroutine int_scalar_to_assumed_rank_bindc(x)
! CHECK: %[[VAL_1:.*]]:2 = hlfir.declare %[[VAL_0]] dummy_scope %{{[0-9]+}} {uniq_name = "_QFint_scalar_to_assumed_rank_bindcEx"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
! CHECK: %[[VAL_2:.*]] = fir.embox %[[VAL_1]]#0 : (!fir.ref<i32>) -> !fir.box<i32>
! CHECK: %[[VAL_3:.*]] = fir.convert %[[VAL_2]] : (!fir.box<i32>) -> !fir.box<!fir.array<*:i32>>
! CHECK: fir.call @int_assumed_rank_bindc(%[[VAL_3]]) fastmath<contract> {is_bind_c} : (!fir.box<!fir.array<*:i32>>) -> ()
! CHECK: fir.call @int_assumed_rank_bindc(%[[VAL_3]]) proc_attrs<bind_c> fastmath<contract> : (!fir.box<!fir.array<*:i32>>) -> ()

subroutine int_r1_to_assumed_rank(x)
use ifaces, only : int_assumed_rank
Expand Down Expand Up @@ -94,7 +94,7 @@ subroutine int_assumed_shape_to_assumed_rank_bindc(x)
! CHECK: %[[VAL_3:.*]] = fir.shift %[[VAL_2]], %[[VAL_2]] : (index, index) -> !fir.shift<2>
! CHECK: %[[VAL_4:.*]] = fir.rebox %[[VAL_1]]#0(%[[VAL_3]]) : (!fir.box<!fir.array<?x?xi32>>, !fir.shift<2>) -> !fir.box<!fir.array<?x?xi32>>
! CHECK: %[[VAL_5:.*]] = fir.convert %[[VAL_4]] : (!fir.box<!fir.array<?x?xi32>>) -> !fir.box<!fir.array<*:i32>>
! CHECK: fir.call @int_assumed_rank_bindc(%[[VAL_5]]) fastmath<contract> {is_bind_c} : (!fir.box<!fir.array<*:i32>>) -> ()
! CHECK: fir.call @int_assumed_rank_bindc(%[[VAL_5]]) proc_attrs<bind_c> fastmath<contract> : (!fir.box<!fir.array<*:i32>>) -> ()

subroutine int_allocatable_to_assumed_rank(x)
use ifaces, only : int_assumed_rank
Expand Down
2 changes: 1 addition & 1 deletion flang/test/Lower/HLFIR/bindc-value-derived.f90
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ subroutine call_it(x)
! CHECK-SAME: %[[VAL_0:.*]]: !fir.ref<!fir.type<_QMbindc_byvalTt{i:i32}>> {fir.bindc_name = "x"}) {
! CHECK: %[[VAL_1:.*]]:2 = hlfir.declare %[[VAL_0]] dummy_scope %{{[0-9]+}} {uniq_name = "_QMbindc_byvalFcall_itEx"} : (!fir.ref<!fir.type<_QMbindc_byvalTt{i:i32}>>, !fir.dscope) -> (!fir.ref<!fir.type<_QMbindc_byvalTt{i:i32}>>, !fir.ref<!fir.type<_QMbindc_byvalTt{i:i32}>>)
! CHECK: %[[VAL_2:.*]] = fir.load %[[VAL_1]]#1 : !fir.ref<!fir.type<_QMbindc_byvalTt{i:i32}>>
! CHECK: fir.call @test(%[[VAL_2]]) fastmath<contract> {is_bind_c} : (!fir.type<_QMbindc_byvalTt{i:i32}>) -> ()
! CHECK: fir.call @test(%[[VAL_2]]) proc_attrs<bind_c> fastmath<contract> : (!fir.type<_QMbindc_byvalTt{i:i32}>) -> ()
! CHECK: return
! CHECK: }
end module
2 changes: 1 addition & 1 deletion flang/test/Lower/HLFIR/block_bindc_pocs.f90
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ end subroutine test_proc
end interface
end module m
!CHECK-DAG: %[[S0:.*]] = fir.call @llvm.stacksave.p0() fastmath<contract> : () -> !fir.ref<i8>
!CHECK-DAG: fir.call @test_proc() fastmath<contract> {is_bind_c} : () -> ()
!CHECK-DAG: fir.call @test_proc() proc_attrs<bind_c> fastmath<contract> : () -> ()
!CHECK-DAG: fir.call @llvm.stackrestore.p0(%[[S0]]) fastmath<contract> : (!fir.ref<i8>) -> ()
!CHECK-DAG: func.func private @test_proc() attributes {fir.bindc_name = "test_proc"}
subroutine test
Expand Down
Loading
Loading