Skip to content

Commit

Permalink
[flang][NFC] turn fir.call is_bind_c into enum for procedure flags (#…
Browse files Browse the repository at this point in the history
…105691)

First patch to fix a BIND(C) ABI issue
(#102113). I need to keep
track of BIND(C) in more locations (fir.dispatch and func.func
operations), and I need to fix a few passes that are dropping the
attribute on the floor. Since I expect more procedure attributes that
cannot be reflected in mlir::FunctionType will be needed for ABI,
optimizations, or debug info, this NFC patch adds a new enum attribute
to keep track of procedure attributes in the IR.

This patch is not updating lowering to lower more attributes, this will
be done in a separate patch to keep the test changes low here.

Adding the attribute on fir.dispatch and func.func will also be done in
separate patches.
  • Loading branch information
jeanPerier authored Aug 23, 2024
1 parent 2f144ac commit 2051a7b
Show file tree
Hide file tree
Showing 22 changed files with 100 additions and 48 deletions.
27 changes: 27 additions & 0 deletions flang/include/flang/Optimizer/Dialect/FIRAttr.td
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,33 @@ def fir_FortranVariableFlagsAttr : fir_Attr<"FortranVariableFlags"> {
"::fir::FortranVariableFlagsAttr::get($_builder.getContext(), $0)";
}


/// Fortran procedure attributes (F2023 15.6.2.1). BIND attribute (18.3.7)
/// is also tracked in the same enum. Recursive (resp. Impure) attribute
/// is implied by the absence of opposite NonRecursive (resp. Pure) attribute.
def FIRfuncNoAttributes : I32BitEnumAttrCaseNone<"none">;
def FIRfuncElemental : I32BitEnumAttrCaseBit<"elemental", 0>;
def FIRfuncPure : I32BitEnumAttrCaseBit<"pure", 1>;
def FIRfuncNonRecursive : I32BitEnumAttrCaseBit<"non_recursive", 2>;
def FIRfuncSimple : I32BitEnumAttrCaseBit<"simple", 3>;
def FIRfuncBind_c : I32BitEnumAttrCaseBit<"bind_c", 4>;

def fir_FortranProcedureFlagsEnum : I32BitEnumAttr<
"FortranProcedureFlagsEnum",
"Fortran procedure attributes",
[FIRfuncNoAttributes, FIRfuncElemental, FIRfuncPure, FIRfuncNonRecursive,
FIRfuncSimple, FIRfuncBind_c]> {
let separator = ", ";
let cppNamespace = "::fir";
let genSpecializedAttr = 0;
let printBitEnumPrimaryGroups = 1;
}

def fir_FortranProcedureFlagsAttr :
EnumAttr<FIROpsDialect, fir_FortranProcedureFlagsEnum, "proc_attrs"> {
let assemblyFormat = "`<` $value `>`";
}

def fir_BoxFieldAttr : I32EnumAttr<
"BoxFieldAttr", "",
[
Expand Down
4 changes: 2 additions & 2 deletions flang/include/flang/Optimizer/Dialect/FIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2432,9 +2432,9 @@ def fir_CallOp : fir_Op<"call",
let arguments = (ins
OptionalAttr<SymbolRefAttr>:$callee,
Variadic<AnyType>:$args,
OptionalAttr<fir_FortranProcedureFlagsAttr>:$procedure_attrs,
DefaultValuedAttr<Arith_FastMathAttr,
"::mlir::arith::FastMathFlags::none">:$fastmath,
UnitAttr:$is_bind_c
"::mlir::arith::FastMathFlags::none">:$fastmath
);
let results = (outs Variadic<AnyType>);

Expand Down
18 changes: 11 additions & 7 deletions flang/lib/Lower/ConvertCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -372,17 +372,17 @@ std::pair<fir::ExtendedValue, bool> Fortran::lower::genCallOpAndResult(
auto stackSaveSymbol = bldr->getSymbolRefAttr(stackSaveFn.getName());
mlir::Value sp;
fir::CallOp call = bldr->create<fir::CallOp>(
loc, stackSaveFn.getFunctionType().getResults(), stackSaveSymbol,
loc, stackSaveSymbol, stackSaveFn.getFunctionType().getResults(),
mlir::ValueRange{});
if (call.getNumResults() != 0)
sp = call.getResult(0);
stmtCtx.attachCleanup([bldr, loc, sp]() {
auto stackRestoreFn = fir::factory::getLlvmStackRestore(*bldr);
auto stackRestoreSymbol =
bldr->getSymbolRefAttr(stackRestoreFn.getName());
bldr->create<fir::CallOp>(loc,
bldr->create<fir::CallOp>(loc, stackRestoreSymbol,
stackRestoreFn.getFunctionType().getResults(),
stackRestoreSymbol, mlir::ValueRange{sp});
mlir::ValueRange{sp});
});
}
mlir::Value temp =
Expand Down Expand Up @@ -640,11 +640,15 @@ std::pair<fir::ExtendedValue, bool> Fortran::lower::genCallOpAndResult(
if (callNumResults != 0)
callResult = dispatch.getResult(0);
} else {
// Standard procedure call with fir.call.
auto call = builder.create<fir::CallOp>(loc, funcType.getResults(),
funcSymbolAttr, operands);
// TODO: gather other procedure attributes.
fir::FortranProcedureFlagsEnumAttr procAttrs;
if (caller.characterize().IsBindC())
call.setIsBindC(true);
procAttrs = fir::FortranProcedureFlagsEnumAttr::get(
builder.getContext(), fir::FortranProcedureFlagsEnum::bind_c);

// Standard procedure call with fir.call.
auto call = builder.create<fir::CallOp>(
loc, funcType.getResults(), funcSymbolAttr, operands, procAttrs);

callNumResults = call.getNumResults();
if (callNumResults != 0)
Expand Down
4 changes: 2 additions & 2 deletions flang/lib/Lower/ConvertExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6120,7 +6120,7 @@ class ArrayExprLowering {
mlir::SymbolRefAttr funcSymAttr =
builder.getSymbolRefAttr(memcpyFunc.getName());
mlir::FunctionType funcTy = memcpyFunc.getFunctionType();
builder.create<fir::CallOp>(loc, funcTy.getResults(), funcSymAttr, args);
builder.create<fir::CallOp>(loc, funcSymAttr, funcTy.getResults(), args);
}

// Construct code to check for a buffer overrun and realloc the buffer when
Expand All @@ -6146,7 +6146,7 @@ class ArrayExprLowering {
builder.getSymbolRefAttr(reallocFunc.getName());
mlir::FunctionType funcTy = reallocFunc.getFunctionType();
auto newMem = builder.create<fir::CallOp>(
loc, funcTy.getResults(), funcSymAttr,
loc, funcSymAttr, funcTy.getResults(),
llvm::ArrayRef<mlir::Value>{
builder.createConvert(loc, funcTy.getInputs()[0], mem),
builder.createConvert(loc, funcTy.getInputs()[1], byteSz)});
Expand Down
4 changes: 2 additions & 2 deletions flang/lib/Optimizer/Builder/IntrinsicCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -819,8 +819,8 @@ mlir::Value genLibCall(fir::FirOpBuilder &builder, mlir::Location loc,

llvm::SmallVector<mlir::Value, 3> operands{funcPointer};
operands.append(args.begin(), args.end());
libCall = builder.create<fir::CallOp>(loc, libFuncType.getResults(),
nullptr, operands);
libCall = builder.create<fir::CallOp>(loc, mlir::SymbolRefAttr{},
libFuncType.getResults(), operands);
}

LLVM_DEBUG(libCall.dump(); llvm::dbgs() << "\n");
Expand Down
3 changes: 2 additions & 1 deletion flang/lib/Optimizer/Dialect/FIRAttr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,8 @@ void fir::printFirAttribute(FIROpsDialect *dialect, mlir::Attribute attr,
//===----------------------------------------------------------------------===//

void FIROpsDialect::registerAttributes() {
addAttributes<ClosedIntervalAttr, ExactTypeAttr, FortranVariableFlagsAttr,
addAttributes<ClosedIntervalAttr, ExactTypeAttr,
FortranProcedureFlagsEnumAttr, FortranVariableFlagsAttr,
LowerBoundAttr, PointIntervalAttr, RealAttr, ReduceAttr,
SubclassAttr, UpperBoundAttr, LocationKindAttr,
LocationKindArrayAttr>();
Expand Down
23 changes: 20 additions & 3 deletions flang/lib/Optimizer/Dialect/FIROps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1103,6 +1103,14 @@ void fir::CallOp::print(mlir::OpAsmPrinter &p) {
p << getOperand(0);
p << '(' << (*this)->getOperands().drop_front(isDirect ? 0 : 1) << ')';

// Print `proc_attrs<...>`, if present.
fir::FortranProcedureFlagsEnumAttr procAttrs = getProcedureAttrsAttr();
if (procAttrs &&
procAttrs.getValue() != fir::FortranProcedureFlagsEnum::none) {
p << ' ' << fir::FortranProcedureFlagsEnumAttr::getMnemonic();
p.printStrippedAttrOrType(procAttrs);
}

// Print 'fastmath<...>' (if it has non-default value) before
// any other attributes.
mlir::arith::FastMathFlagsAttr fmfAttr = getFastmathAttr();
Expand All @@ -1111,9 +1119,9 @@ void fir::CallOp::print(mlir::OpAsmPrinter &p) {
p.printStrippedAttrOrType(fmfAttr);
}

p.printOptionalAttrDict(
(*this)->getAttrs(),
{fir::CallOp::getCalleeAttrNameStr(), getFastmathAttrName()});
p.printOptionalAttrDict((*this)->getAttrs(),
{fir::CallOp::getCalleeAttrNameStr(),
getFastmathAttrName(), getProcedureAttrsAttrName()});
auto resultTypes{getResultTypes()};
llvm::SmallVector<mlir::Type> argTypes(
llvm::drop_begin(getOperandTypes(), isDirect ? 0 : 1));
Expand All @@ -1138,6 +1146,15 @@ mlir::ParseResult fir::CallOp::parse(mlir::OpAsmParser &parser,
if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::Paren))
return mlir::failure();

// Parse `proc_attrs<...>`, if present.
fir::FortranProcedureFlagsEnumAttr procAttr;
if (mlir::succeeded(parser.parseOptionalKeyword(
fir::FortranProcedureFlagsEnumAttr::getMnemonic())))
if (parser.parseCustomAttributeWithFallback(
procAttr, mlir::Type{}, getProcedureAttrsAttrName(result.name),
attrs))
return mlir::failure();

// Parse 'fastmath<...>', if present.
mlir::arith::FastMathFlagsAttr fmfAttr;
llvm::StringRef fmfAttrName = getFastmathAttrName(result.name);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,10 @@ class CallOpRewriter : public mlir::OpRewritePattern<fir::CallOp> {
newResultTypes.append(callOp.getResultTypes().begin(),
callOp.getResultTypes().end());
fir::CallOp newOp = builder.create<fir::CallOp>(
loc, newResultTypes,
loc,
callOp.getCallee().has_value() ? callOp.getCallee().value()
: mlir::SymbolRefAttr{},
newOperands);
newResultTypes, newOperands);
// Copy all the attributes from the old to new op.
newOp->setAttrs(callOp->getAttrs());
rewriter.replaceOp(callOp, newOp);
Expand Down
5 changes: 4 additions & 1 deletion flang/lib/Optimizer/Transforms/PolymorphicOpConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,10 @@ struct DispatchOpConv : public OpConversionPattern<fir::DispatchOp> {
// Make the call.
llvm::SmallVector<mlir::Value> args{funcPtr};
args.append(dispatch.getArgs().begin(), dispatch.getArgs().end());
rewriter.replaceOpWithNewOp<fir::CallOp>(dispatch, resTypes, nullptr, args);
// FIXME: add procedure_attrs to fir.dispatch and propagate to fir.call.
rewriter.replaceOpWithNewOp<fir::CallOp>(
dispatch, resTypes, nullptr, args,
/*procedure_attrs=*/fir::FortranProcedureFlagsEnumAttr{});
return mlir::success();
}

Expand Down
8 changes: 4 additions & 4 deletions flang/lib/Optimizer/Transforms/StackArrays.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -741,9 +741,9 @@ void AllocMemConversion::insertStackSaveRestore(
builder.setInsertionPoint(oldAlloc);
mlir::Value sp =
builder
.create<fir::CallOp>(oldAlloc.getLoc(),
.create<fir::CallOp>(oldAlloc.getLoc(), stackSaveSym,
stackSaveFn.getFunctionType().getResults(),
stackSaveSym, mlir::ValueRange{})
mlir::ValueRange{})
.getResult(0);

mlir::func::FuncOp stackRestoreFn =
Expand All @@ -753,9 +753,9 @@ void AllocMemConversion::insertStackSaveRestore(

auto createStackRestoreCall = [&](mlir::Operation *user) {
builder.setInsertionPoint(user);
builder.create<fir::CallOp>(user->getLoc(),
builder.create<fir::CallOp>(user->getLoc(), stackRestoreSym,
stackRestoreFn.getFunctionType().getResults(),
stackRestoreSym, mlir::ValueRange{sp});
mlir::ValueRange{sp});
};

for (mlir::Operation *user : oldAlloc->getUsers()) {
Expand Down
4 changes: 2 additions & 2 deletions flang/test/HLFIR/c_ptr_byvalue.f90
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
! CHECK: %[[VAL_113:.*]] = fir.load %[[VAL_112]] : !fir.ref<i64>
! CHECK: %[[VAL_114:.*]] = fir.convert %[[VAL_113]] : (i64) -> !fir.ref<i64>
! CHECK: hlfir.end_associate %[[VAL_110]]#1, %[[VAL_110]]#2 : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, i1
! CHECK: fir.call @get_expected_f(%[[VAL_114]]) fastmath<contract> {is_bind_c} : (!fir.ref<i64>) -> ()
! CHECK: fir.call @get_expected_f(%[[VAL_114]]) proc_attrs<bind_c> fastmath<contract> : (!fir.ref<i64>) -> ()
subroutine test1
use iso_c_binding
interface
Expand All @@ -28,7 +28,7 @@ end subroutine get_expected_f
! CHECK: %[[VAL_99:.*]] = fir.coordinate_of %[[VAL_97]]#0, %[[VAL_98]] : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.field) -> !fir.ref<i64>
! CHECK: %[[VAL_100:.*]] = fir.load %[[VAL_99]] : !fir.ref<i64>
! CHECK: %[[VAL_101:.*]] = fir.convert %[[VAL_100]] : (i64) -> !fir.ref<i64>
! CHECK: fir.call @get_expected_f(%[[VAL_101]]) fastmath<contract> {is_bind_c} : (!fir.ref<i64>) -> ()
! CHECK: fir.call @get_expected_f(%[[VAL_101]]) proc_attrs<bind_c> fastmath<contract> : (!fir.ref<i64>) -> ()
subroutine test2(cptr)
use iso_c_binding
interface
Expand Down
8 changes: 4 additions & 4 deletions flang/test/Lower/CUDA/cuda-device-proc.cuf
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@ end

! CHECK-LABEL: func.func @_QPdevsub() attributes {cuf.proc_attr = #cuf.cuda_proc<global>}
! CHECK: fir.call @__syncthreads()
! CHECK: fir.call @__syncwarp(%{{.*}}) fastmath<contract> {is_bind_c} : (!fir.ref<i32>) -> ()
! CHECK: fir.call @__syncwarp(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (!fir.ref<i32>) -> ()
! CHECK: fir.call @__threadfence()
! CHECK: fir.call @__threadfence_block()
! CHECK: fir.call @__threadfence_system()
! CHECK: %{{.*}} = fir.call @__syncthreads_and(%{{.*}}) fastmath<contract> {is_bind_c} : (!fir.ref<i32>) -> i32
! CHECK: %{{.*}} = fir.call @__syncthreads_count(%{{.*}}) fastmath<contract> {is_bind_c} : (!fir.ref<i32>) -> i32
! CHECK: %{{.*}} = fir.call @__syncthreads_or(%{{.*}}) fastmath<contract> {is_bind_c} : (!fir.ref<i32>) -> i32
! CHECK: %{{.*}} = fir.call @__syncthreads_and(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (!fir.ref<i32>) -> i32
! CHECK: %{{.*}} = fir.call @__syncthreads_count(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (!fir.ref<i32>) -> i32
! CHECK: %{{.*}} = fir.call @__syncthreads_or(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (!fir.ref<i32>) -> i32

! CHECK: func.func private @__syncthreads() attributes {cuf.proc_attr = #cuf.cuda_proc<device>, fir.bindc_name = "__syncthreads"}
! CHECK: func.func private @__syncwarp(!fir.ref<i32> {cuf.data_attr = #cuf.cuda<device>}) attributes {cuf.proc_attr = #cuf.cuda_proc<device>, fir.bindc_name = "__syncwarp"}
Expand Down
2 changes: 1 addition & 1 deletion flang/test/Lower/HLFIR/assumed-rank-calls.f90
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ subroutine bindc_func(x) bind(c)
! CHECK: %[[VAL_1:.*]] = fir.dummy_scope : !fir.dscope
! CHECK: %[[VAL_2:.*]]:2 = hlfir.declare %[[VAL_0]] dummy_scope %[[VAL_1]] {uniq_name = "_QFtest_to_bindcEx"} : (!fir.box<!fir.array<*:f32>>, !fir.dscope) -> (!fir.box<!fir.array<*:f32>>, !fir.box<!fir.array<*:f32>>)
! CHECK: %[[VAL_3:.*]] = fir.rebox_assumed_rank %[[VAL_2]]#0 lbs zeroes : (!fir.box<!fir.array<*:f32>>) -> !fir.box<!fir.array<*:f32>>
! CHECK: fir.call @bindc_func(%[[VAL_3]]) fastmath<contract> {is_bind_c} : (!fir.box<!fir.array<*:f32>>) -> ()
! CHECK: fir.call @bindc_func(%[[VAL_3]]) proc_attrs<bind_c> fastmath<contract> : (!fir.box<!fir.array<*:f32>>) -> ()
! CHECK: return
! CHECK: }

Expand Down
4 changes: 2 additions & 2 deletions flang/test/Lower/HLFIR/assumed-rank-iface.f90
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ subroutine int_scalar_to_assumed_rank_bindc(x)
! CHECK: %[[VAL_1:.*]]:2 = hlfir.declare %[[VAL_0]] dummy_scope %{{[0-9]+}} {uniq_name = "_QFint_scalar_to_assumed_rank_bindcEx"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
! CHECK: %[[VAL_2:.*]] = fir.embox %[[VAL_1]]#0 : (!fir.ref<i32>) -> !fir.box<i32>
! CHECK: %[[VAL_3:.*]] = fir.convert %[[VAL_2]] : (!fir.box<i32>) -> !fir.box<!fir.array<*:i32>>
! CHECK: fir.call @int_assumed_rank_bindc(%[[VAL_3]]) fastmath<contract> {is_bind_c} : (!fir.box<!fir.array<*:i32>>) -> ()
! CHECK: fir.call @int_assumed_rank_bindc(%[[VAL_3]]) proc_attrs<bind_c> fastmath<contract> : (!fir.box<!fir.array<*:i32>>) -> ()

subroutine int_r1_to_assumed_rank(x)
use ifaces, only : int_assumed_rank
Expand Down Expand Up @@ -94,7 +94,7 @@ subroutine int_assumed_shape_to_assumed_rank_bindc(x)
! CHECK: %[[VAL_3:.*]] = fir.shift %[[VAL_2]], %[[VAL_2]] : (index, index) -> !fir.shift<2>
! CHECK: %[[VAL_4:.*]] = fir.rebox %[[VAL_1]]#0(%[[VAL_3]]) : (!fir.box<!fir.array<?x?xi32>>, !fir.shift<2>) -> !fir.box<!fir.array<?x?xi32>>
! CHECK: %[[VAL_5:.*]] = fir.convert %[[VAL_4]] : (!fir.box<!fir.array<?x?xi32>>) -> !fir.box<!fir.array<*:i32>>
! CHECK: fir.call @int_assumed_rank_bindc(%[[VAL_5]]) fastmath<contract> {is_bind_c} : (!fir.box<!fir.array<*:i32>>) -> ()
! CHECK: fir.call @int_assumed_rank_bindc(%[[VAL_5]]) proc_attrs<bind_c> fastmath<contract> : (!fir.box<!fir.array<*:i32>>) -> ()

subroutine int_allocatable_to_assumed_rank(x)
use ifaces, only : int_assumed_rank
Expand Down
2 changes: 1 addition & 1 deletion flang/test/Lower/HLFIR/bindc-value-derived.f90
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ subroutine call_it(x)
! CHECK-SAME: %[[VAL_0:.*]]: !fir.ref<!fir.type<_QMbindc_byvalTt{i:i32}>> {fir.bindc_name = "x"}) {
! CHECK: %[[VAL_1:.*]]:2 = hlfir.declare %[[VAL_0]] dummy_scope %{{[0-9]+}} {uniq_name = "_QMbindc_byvalFcall_itEx"} : (!fir.ref<!fir.type<_QMbindc_byvalTt{i:i32}>>, !fir.dscope) -> (!fir.ref<!fir.type<_QMbindc_byvalTt{i:i32}>>, !fir.ref<!fir.type<_QMbindc_byvalTt{i:i32}>>)
! CHECK: %[[VAL_2:.*]] = fir.load %[[VAL_1]]#1 : !fir.ref<!fir.type<_QMbindc_byvalTt{i:i32}>>
! CHECK: fir.call @test(%[[VAL_2]]) fastmath<contract> {is_bind_c} : (!fir.type<_QMbindc_byvalTt{i:i32}>) -> ()
! CHECK: fir.call @test(%[[VAL_2]]) proc_attrs<bind_c> fastmath<contract> : (!fir.type<_QMbindc_byvalTt{i:i32}>) -> ()
! CHECK: return
! CHECK: }
end module
2 changes: 1 addition & 1 deletion flang/test/Lower/HLFIR/block_bindc_pocs.f90
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ end subroutine test_proc
end interface
end module m
!CHECK-DAG: %[[S0:.*]] = fir.call @llvm.stacksave.p0() fastmath<contract> : () -> !fir.ref<i8>
!CHECK-DAG: fir.call @test_proc() fastmath<contract> {is_bind_c} : () -> ()
!CHECK-DAG: fir.call @test_proc() proc_attrs<bind_c> fastmath<contract> : () -> ()
!CHECK-DAG: fir.call @llvm.stackrestore.p0(%[[S0]]) fastmath<contract> : (!fir.ref<i8>) -> ()
!CHECK-DAG: func.func private @test_proc() attributes {fir.bindc_name = "test_proc"}
subroutine test
Expand Down
Loading

0 comments on commit 2051a7b

Please sign in to comment.