Skip to content

Commit

Permalink
[flang][cuda] Convert cuf.alloc for box to fir.alloca in device conte…
Browse files Browse the repository at this point in the history
…xt (llvm#102662)

In device context managed memory is not available so it makes no sense
to allocate the descriptor using it. Fall back to fir.alloca as it is
handled well in device code.
cuf.free is just dropped.
  • Loading branch information
clementval authored and bwendling committed Aug 15, 2024
1 parent 1eb859c commit bf2f59c
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 0 deletions.
30 changes: 30 additions & 0 deletions flang/lib/Optimizer/Transforms/CufOpConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,20 @@ struct CufDeallocateOpConversion
}
};

static bool inDeviceContext(mlir::Operation *op) {
if (op->getParentOfType<cuf::KernelOp>())
return true;
if (auto funcOp = op->getParentOfType<mlir::func::FuncOp>()) {
if (auto cudaProcAttr =
funcOp.getOperation()->getAttrOfType<cuf::ProcAttributeAttr>(
cuf::getProcAttrName())) {
return cudaProcAttr.getValue() != cuf::ProcAttribute::Host &&
cudaProcAttr.getValue() != cuf::ProcAttribute::HostDevice;
}
}
return false;
}

struct CufAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
using OpRewritePattern::OpRewritePattern;

Expand All @@ -157,6 +171,16 @@ struct CufAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
if (!boxTy)
return failure();

if (inDeviceContext(op.getOperation())) {
// In device context just replace the cuf.alloc operation with a fir.alloc
// the cuf.free will be removed.
rewriter.replaceOpWithNewOp<fir::AllocaOp>(
op, op.getInType(), op.getUniqName() ? *op.getUniqName() : "",
op.getBindcName() ? *op.getBindcName() : "", op.getTypeparams(),
op.getShape());
return mlir::success();
}

auto mod = op->getParentOfType<mlir::ModuleOp>();
fir::FirOpBuilder builder(rewriter, mod);
mlir::Location loc = op.getLoc();
Expand Down Expand Up @@ -200,6 +224,11 @@ struct CufFreeOpConversion : public mlir::OpRewritePattern<cuf::FreeOp> {
if (!mlir::isa<fir::BaseBoxType>(refTy.getEleTy()))
return failure();

if (inDeviceContext(op.getOperation())) {
rewriter.eraseOp(op);
return mlir::success();
}

auto mod = op->getParentOfType<mlir::ModuleOp>();
fir::FirOpBuilder builder(rewriter, mod);
mlir::Location loc = op.getLoc();
Expand Down Expand Up @@ -248,6 +277,7 @@ class CufOpConversion : public fir::impl::CufOpConversionBase<CufOpConversion> {
[](::cuf::AllocateOp op) { return isBoxGlobal(op); });
target.addDynamicallyLegalOp<cuf::DeallocateOp>(
[](::cuf::DeallocateOp op) { return isBoxGlobal(op); });
target.addLegalDialect<fir::FIROpsDialect>();
patterns.insert<CufAllocOpConversion>(ctx, &*dl, &typeConverter);
patterns.insert<CufAllocateOpConversion, CufDeallocateOpConversion,
CufFreeOpConversion>(ctx);
Expand Down
11 changes: 11 additions & 0 deletions flang/test/Fir/CUDA/cuda-allocate.fir
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,17 @@ func.func @_QPsub3() {
// CHECK: cuf.allocate
// CHECK: cuf.deallocate

func.func @_QPsub4() attributes {cuf.proc_attr = #cuf.cuda_proc<device>} {
%0 = cuf.alloc !fir.box<!fir.heap<!fir.array<?xf32>>> {bindc_name = "a", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub1Ea"} -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
%4:2 = hlfir.declare %0 {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub1Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
cuf.free %4#1 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>> {data_attr = #cuf.cuda<device>}
return
}

// CHECK-LABEL: func.func @_QPsub4()
// CHECK: fir.alloca
// CHECK-NOT: cuf.free

}


0 comments on commit bf2f59c

Please sign in to comment.