diff --git a/flang/lib/Optimizer/Transforms/CufOpConversion.cpp b/flang/lib/Optimizer/Transforms/CufOpConversion.cpp index f059d36315a345..d391ede82c2707 100644 --- a/flang/lib/Optimizer/Transforms/CufOpConversion.cpp +++ b/flang/lib/Optimizer/Transforms/CufOpConversion.cpp @@ -141,6 +141,20 @@ struct CufDeallocateOpConversion } }; +/* Returns true when `op` is in device context: either the first getParentOfType query finds an enclosing parent, or the enclosing op bound to `funcOp` carries the attribute named by cuf::getProcAttrName() with a value that is neither Host nor HostDevice. Returns false otherwise, including when that attribute is absent. NOTE(review): the template arguments of getParentOfType/getAttrOfType appear to have been stripped from this patch text - confirm the queried op and attribute types against upstream. */ static bool inDeviceContext(mlir::Operation *op) { + if (op->getParentOfType()) + return true; + if (auto funcOp = op->getParentOfType()) { + if (auto cudaProcAttr = + funcOp.getOperation()->getAttrOfType( + cuf::getProcAttrName())) { + return cudaProcAttr.getValue() != cuf::ProcAttribute::Host && + cudaProcAttr.getValue() != cuf::ProcAttribute::HostDevice; + } + } + return false; +} + struct CufAllocOpConversion : public mlir::OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -157,6 +171,16 @@ struct CufAllocOpConversion : public mlir::OpRewritePattern { if (!boxTy) return failure(); + if (inDeviceContext(op.getOperation())) { + // In device context, replace the cuf.alloc operation with a fir.alloca; + // cuf.free ops in device context are erased by CufFreeOpConversion. + rewriter.replaceOpWithNewOp( + op, op.getInType(), op.getUniqName() ? *op.getUniqName() : "", + op.getBindcName() ? 
*op.getBindcName() : "", op.getTypeparams(), + op.getShape()); + return mlir::success(); + } + auto mod = op->getParentOfType(); fir::FirOpBuilder builder(rewriter, mod); mlir::Location loc = op.getLoc(); @@ -200,6 +224,11 @@ struct CufFreeOpConversion : public mlir::OpRewritePattern { if (!mlir::isa(refTy.getEleTy())) return failure(); + if (inDeviceContext(op.getOperation())) { + rewriter.eraseOp(op); + return mlir::success(); + } + auto mod = op->getParentOfType(); fir::FirOpBuilder builder(rewriter, mod); mlir::Location loc = op.getLoc(); @@ -248,6 +277,7 @@ class CufOpConversion : public fir::impl::CufOpConversionBase { [](::cuf::AllocateOp op) { return isBoxGlobal(op); }); target.addDynamicallyLegalOp( [](::cuf::DeallocateOp op) { return isBoxGlobal(op); }); + target.addLegalDialect(); patterns.insert(ctx, &*dl, &typeConverter); patterns.insert(ctx); diff --git a/flang/test/Fir/CUDA/cuda-allocate.fir b/flang/test/Fir/CUDA/cuda-allocate.fir index 569e72f57d6d6c..a9bc7a8518e90e 100644 --- a/flang/test/Fir/CUDA/cuda-allocate.fir +++ b/flang/test/Fir/CUDA/cuda-allocate.fir @@ -57,6 +57,17 @@ func.func @_QPsub3() { // CHECK: cuf.allocate // CHECK: cuf.deallocate +func.func @_QPsub4() attributes {cuf.proc_attr = #cuf.cuda_proc} { + %0 = cuf.alloc !fir.box>> {bindc_name = "a", data_attr = #cuf.cuda, uniq_name = "_QFsub1Ea"} -> !fir.ref>>> + %4:2 = hlfir.declare %0 {data_attr = #cuf.cuda, fortran_attrs = #fir.var_attrs, uniq_name = "_QFsub1Ea"} : (!fir.ref>>>) -> (!fir.ref>>>, !fir.ref>>>) + cuf.free %4#1 : !fir.ref>>> {data_attr = #cuf.cuda} + return +} + +// CHECK-LABEL: func.func @_QPsub4() +// CHECK: fir.alloca +// CHECK-NOT: cuf.free + }