Skip to content

Commit

Permalink
[flang][cuda] Copying device globals in the gpu module (#113955)
Browse files Browse the repository at this point in the history
  • Loading branch information
Renaud-K authored Oct 28, 2024
1 parent e873b41 commit 0eb5c9d
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 0 deletions.
27 changes: 27 additions & 0 deletions flang/lib/Optimizer/Transforms/CUFDeviceGlobal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Optimizer/Transforms/CUFCommon.h"
#include "flang/Runtime/CUDA/common.h"
#include "flang/Runtime/allocatable.h"
#include "mlir/IR/SymbolTable.h"
Expand Down Expand Up @@ -58,6 +59,32 @@ class CUFDeviceGlobal : public fir::impl::CUFDeviceGlobalBase<CUFDeviceGlobal> {
prepareImplicitDeviceGlobals(funcOp, symTable);
return mlir::WalkResult::advance();
});

// Copying the device global variable into the gpu module
mlir::SymbolTable parentSymTable(mod);
auto gpuMod =
parentSymTable.lookup<mlir::gpu::GPUModuleOp>(cudaDeviceModuleName);
if (gpuMod) {
mlir::SymbolTable gpuSymTable(gpuMod);
for (auto globalOp : mod.getOps<fir::GlobalOp>()) {
auto attr = globalOp.getDataAttrAttr();
if (!attr)
continue;
switch (attr.getValue()) {
case cuf::DataAttribute::Device:
case cuf::DataAttribute::Constant:
case cuf::DataAttribute::Managed: {
auto globalName{globalOp.getSymbol().getValue()};
if (gpuSymTable.lookup<fir::GlobalOp>(globalName)) {
break;
}
gpuSymTable.insert(globalOp->clone());
} break;
default:
break;
}
}
}
}
};
} // namespace
13 changes: 13 additions & 0 deletions flang/test/Fir/CUDA/cuda-device-global.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@

// RUN: fir-opt --split-input-file --cuf-device-global %s | FileCheck %s


module attributes {fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", gpu.container_module} {
fir.global @_QMmtestsEn(dense<[3, 4, 5, 6, 7]> : tensor<5xi32>) {data_attr = #cuf.cuda<device>} : !fir.array<5xi32>

gpu.module @cuda_device_mod [#nvvm.target] {
}
}

// CHECK: gpu.module @cuda_device_mod [#nvvm.target]
// CHECK-NEXT: fir.global @_QMmtestsEn(dense<[3, 4, 5, 6, 7]> : tensor<5xi32>) {data_attr = #cuf.cuda<device>} : !fir.array<5xi32>

0 comments on commit 0eb5c9d

Please sign in to comment.