enable XeVM test using gpu-runner
akroviakov committed Jan 20, 2025
1 parent 5540c6f commit fede4a8
Showing 20 changed files with 360 additions and 29 deletions.
15 changes: 15 additions & 0 deletions lib/gc/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
@@ -11,6 +11,7 @@
#include "gc/Dialect/LLVMIR/XeVMDialect.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
@@ -398,6 +399,20 @@ struct ConvertXeVMToLLVMPass
target.addLegalDialect<LLVM::LLVMDialect>();
target.addIllegalDialect<XeVMDialect>();
RewritePatternSet patterns(&getContext());
auto mod = llvm::dyn_cast<ModuleOp>(getOperation());
    auto trySetSGAttr = [&](Operation *op) {
      if (!op->hasAttr("intel_reqd_sub_group_size")) {
        op->setAttr("intel_reqd_sub_group_size",
                    IntegerAttr::get(IntegerType::get(&getContext(), 32), 16));
        return mlir::WalkResult::interrupt();
      }
      return mlir::WalkResult::advance(); // both paths must yield a WalkResult
    };
mod.walk([&](mlir::gpu::GPUModuleOp gpuModule) {
gpuModule.walk(
[&](mlir::gpu::GPUFuncOp funcOp) { trySetSGAttr(funcOp); });
gpuModule.walk([&](mlir::func::FuncOp funcOp) { trySetSGAttr(funcOp); });
});

populateXeVMToLLVMConversionPatterns(patterns);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
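The new walk above tags every kernel entry point with Intel's required sub-group size before lowering. A minimal standalone sketch of the same tagging logic, assuming only upstream MLIR APIs (tagSubGroupSize is an illustrative name, not part of this commit):

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/BuiltinOps.h"

// Give every kernel-like function inside each gpu.module the default
// Intel sub-group size of 16 unless one was already set explicitly.
static void tagSubGroupSize(mlir::ModuleOp mod) {
  mlir::MLIRContext *ctx = mod.getContext();
  auto trySetSGAttr = [&](mlir::Operation *op) {
    if (!op->hasAttr("intel_reqd_sub_group_size"))
      op->setAttr("intel_reqd_sub_group_size",
                  mlir::IntegerAttr::get(mlir::IntegerType::get(ctx, 32), 16));
  };
  mod.walk([&](mlir::gpu::GPUModuleOp gpuModule) {
    gpuModule.walk([&](mlir::gpu::GPUFuncOp f) { trySetSGAttr(f); });
    gpuModule.walk([&](mlir::func::FuncOp f) { trySetSGAttr(f); });
  });
}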
71 changes: 57 additions & 14 deletions lib/gc/Transforms/GPU/OCL/GpuToGpuOcl.cpp
@@ -246,10 +246,13 @@ struct ConvertLaunch final : ConvertOpPattern<gpu::LaunchFuncOp> {

int i = 0;
for (auto arg : kernelArgs) {
-          if (auto type = gpuLaunch.getKernelOperand(i++).getType();
if (auto type = gpuLaunch.getKernelOperand(i).getType();
isa<MemRefType>(type)) {
MemRefDescriptor desc(arg);
args.emplace_back(desc.alignedPtr(rewriter, loc));
} else if (auto type = gpuLaunch.getKernelOperand(i).getType();
isa<LLVM::LLVMPointerType>(type)) {
args.emplace_back(arg);
} else {
// Store the arg on the stack and pass the pointer
auto ptr = rewriter.create<LLVM::AllocaOp>(
@@ -258,6 +261,7 @@ struct ConvertLaunch final : ConvertOpPattern<gpu::LaunchFuncOp> {
rewriter.create<LLVM::StoreOp>(loc, arg, ptr);
args.emplace_back(ptr);
}
i++;
}

const auto gpuOclLaunch =
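The loop above implements the launch-argument marshalling rule: memrefs contribute their aligned data pointer, raw LLVM pointers are forwarded unchanged, and scalars are spilled to a one-element stack slot so the runtime always receives a pointer. A hedged sketch of that rule in isolation, assuming upstream MLIR's LLVM-dialect builders (marshalLaunchArg is an illustrative name):

// Requires mlir/Conversion/LLVMCommon/MemRefBuilder.h and Pattern.h.
static mlir::Value marshalLaunchArg(mlir::ConversionPatternRewriter &rewriter,
                                    mlir::Location loc, mlir::Type origType,
                                    mlir::Value arg) {
  using namespace mlir;
  if (isa<MemRefType>(origType)) {
    MemRefDescriptor desc(arg); // unpack the lowered memref descriptor
    return desc.alignedPtr(rewriter, loc);
  }
  if (isa<LLVM::LLVMPointerType>(origType))
    return arg; // already a pointer, pass through
  // Scalar: store to a stack slot and pass the slot's address.
  auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
  auto one = rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(),
                                               rewriter.getI64IntegerAttr(1));
  auto ptr = rewriter.create<LLVM::AllocaOp>(loc, ptrType, arg.getType(), one);
  rewriter.create<LLVM::StoreOp>(loc, arg, ptr);
  return ptr;
}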
@@ -352,32 +356,67 @@
.getResult();
}

-  // Create a new kernel and save the pointer to the global variable
-  // ...name_Ptr.
-  bool createKernel(
-      gpu::LaunchFuncOp &gpuLaunch, OpAdaptor &adaptor,
-      ConversionPatternRewriter &rewriter, const Location &loc, ModuleOp &mod,
-      StringRef funcName,
-      const std::function<SmallString<128> &(const char *chars)> &str) const {
-    auto kernelModName = gpuLaunch.getKernelModuleName();
StringAttr getBinaryAttrIMEX(ConversionPatternRewriter &rewriter,
gpu::LaunchFuncOp &gpuLaunch,
StringAttr kernelModName) const {
StringAttr binaryAttr;
auto kernelMod = SymbolTable::lookupNearestSymbolFrom<gpu::GPUModuleOp>(
gpuLaunch, kernelModName);
if (!kernelMod) {
gpuLaunch.emitOpError() << "Module " << kernelModName << " not found!";
-      return false;
return {};
}
-    const auto binaryAttr = kernelMod->getAttrOfType<StringAttr>("gpu.binary");
binaryAttr = kernelMod->getAttrOfType<StringAttr>("gpu.binary");
if (!binaryAttr) {
kernelMod.emitOpError() << "missing 'gpu.binary' attribute";
-      return false;
return {};
}
rewriter.eraseOp(kernelMod);
return binaryAttr;
}

StringAttr getBinaryAttrUpstream(ConversionPatternRewriter &rewriter,
gpu::LaunchFuncOp &gpuLaunch,
StringAttr kernelModName) const {
StringAttr binaryAttr;
auto gpuBin = SymbolTable::lookupNearestSymbolFrom<gpu::BinaryOp>(
gpuLaunch, kernelModName);
if (!gpuBin) {
gpuLaunch.emitOpError() << "Binary " << kernelModName << " not found!";
return {};
}
if (gpuBin.getObjects().size() != 1) {
gpuLaunch.emitOpError() << "Many targets present in " << kernelModName
<< ", please use xevm only.";
return {};
}
binaryAttr = cast<gpu::ObjectAttr>(gpuBin.getObjects()[0]).getObject();
if (!binaryAttr) {
gpuBin.emitOpError() << "missing binary object.";
return {};
}
return binaryAttr;
}
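The two helpers mirror how each flow stores the compiled blob: IMEX leaves a gpu.binary StringAttr on the gpu.module itself, while the upstream GpuModuleToBinary pass emits a gpu.binary op carrying one gpu.object per target. A short sketch of unpacking the upstream form, assuming gpuBin was found via the symbol table as above:

for (mlir::Attribute attr : gpuBin.getObjects()) {
  auto object = mlir::cast<mlir::gpu::ObjectAttr>(attr);
  mlir::StringAttr blob = object.getObject(); // the serialized kernel binary
  // object.getTarget() names the target that produced this blob; the pass
  // above insists on exactly one (xevm) object.
}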

// Create a new kernel and save the pointer to the global variable
// ...name_Ptr.
bool createKernel(
gpu::LaunchFuncOp &gpuLaunch, OpAdaptor &adaptor,
ConversionPatternRewriter &rewriter, const Location &loc, ModuleOp &mod,
StringRef funcName,
const std::function<SmallString<128> &(const char *chars)> &str) const {
auto kernelModName = gpuLaunch.getKernelModuleName();
#ifdef GC_USE_IMEX
    auto binaryAttr = getBinaryAttrIMEX(rewriter, gpuLaunch, kernelModName);
#else
    auto binaryAttr = getBinaryAttrUpstream(rewriter, gpuLaunch, kernelModName);
#endif
    if (!binaryAttr) // both helpers return an empty attribute on error
      return false;

rewriter.setInsertionPointToStart(mod.getBody());
// The kernel pointer is stored here
rewriter.create<LLVM::GlobalOp>(loc, helper.ptrType, /*isConstant=*/false,
LLVM::Linkage::Internal, str("Ptr"),
rewriter.getZeroAttr(helper.ptrType));
-    rewriter.eraseOp(kernelMod);

auto function = rewriter.create<LLVM::LLVMFuncOp>(
loc, funcName,
@@ -415,7 +454,7 @@ struct ConvertLaunch final : ConvertOpPattern<gpu::LaunchFuncOp> {
for (auto arg : gpuLaunch.getKernelOperands()) {
auto type = arg.getType();
size_t size;
-      if (isa<MemRefType>(type)) {
if (isa<MemRefType>(type) || isa<LLVM::LLVMPointerType>(type)) {
size = 0; // A special case for pointers
} else if (type.isIndex()) {
size = helper.idxType.getIntOrFloatBitWidth() / 8;
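The size computed here feeds the runtime's argument table, with 0 acting as a sentinel for "passed as a pointer". A compact restatement of the rule (illustrative only, not part of this commit):

static size_t kernelArgSize(mlir::Type type, mlir::Type idxType) {
  // Pointer-like arguments (memrefs, raw LLVM pointers) are recorded as 0.
  if (mlir::isa<mlir::MemRefType, mlir::LLVM::LLVMPointerType>(type))
    return 0;
  // Index values use the configured index width; other scalars their own.
  if (type.isIndex())
    return idxType.getIntOrFloatBitWidth() / 8;
  return type.getIntOrFloatBitWidth() / 8;
}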
@@ -452,6 +491,8 @@ struct ConvertLaunch final : ConvertOpPattern<gpu::LaunchFuncOp> {
assert(getConstantIntValue(cast.getOperand(0)));
value = helper.idxConstant(
rewriter, loc, getConstantIntValue(cast.getOperand(0)).value());
} else {
value = rewriter.clone(*value.getDefiningOp())->getResult(0);
}
rewriter.create<LLVM::StoreOp>(loc, value, elementPtr);
}
@@ -527,6 +568,8 @@ struct GpuToGpuOcl final : gc::impl::GpuToGpuOclBase<GpuToGpuOcl> {
return;
}

    if (helper.kernelNames.empty())
      return;
// Add gpuOclDestructor() function that destroys all the kernels
auto mod = llvm::dyn_cast<ModuleOp>(getOperation());
assert(mod);
15 changes: 8 additions & 7 deletions lib/gc/Transforms/GPU/Pipeline.cpp
@@ -138,7 +138,6 @@ void registerIMEXPipeline() {
#ifdef GC_USE_GPU
void populateGPUPipeline(OpPassManager &pm,
const GPUPipelineOptions &pipelineOpts) {

pm.addNestedPass<func::FuncOp>(createAddContextArg());

pm.addPass(createConvertSCFToCFPass());
@@ -148,20 +147,22 @@ void populateGPUPipeline(OpPassManager &pm,
pm.addPass(createArithToLLVMConversionPass());
pm.addPass(createConvertFuncToLLVMPass());
pm.addPass(createConvertMathToLLVMPass());
pm.addPass(createCSEPass());
pm.addPass(createReconcileUnrealizedCastsPass());

// Convert allocs, etc.
pm.addPass(createGpuToGpuOcl({pipelineOpts.callFinish}));
pm.addPass(createGpuKernelOutliningPass());
pm.addPass(createConvertXeVMToLLVMPass());
pm.addPass(createGpuXeVMAttachTarget());
pm.addPass(createConvertGpuOpsToLLVMSPVOps());
pm.addPass(createGpuToLLVMConversionPass());
pm.addNestedPass<gpu::GPUModuleOp>(createConvertGpuOpsToLLVMSPVOps());
pm.addNestedPass<gpu::GPUModuleOp>(createConvertIndexToLLVMPass());
pm.addNestedPass<gpu::GPUModuleOp>(createArithToLLVMConversionPass());
pm.addPass(createReconcileUnrealizedCastsPass());
pm.addPass(createCSEPass());
// Convert allocs, etc.
pm.addPass(createGpuToGpuOcl({pipelineOpts.callFinish}));
pm.addPass(createGpuModuleToBinaryPass());
// Convert launch given a binary.
pm.addPass(createGpuToGpuOcl({pipelineOpts.callFinish}));
pm.addPass(createFinalizeMemRefToLLVMConversionPass());
pm.addPass(createReconcileUnrealizedCastsPass());
}

void registerGPUPipeline() {
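With this reordering, GPU-dialect ops are lowered inside each gpu.module, the module is serialized by GpuModuleToBinary, and only then are launches converted (hence GpuToGpuOcl running twice). A minimal sketch of driving the pipeline from C++, assuming the declarations above (option values are illustrative):

mlir::PassManager pm(&context);
GPUPipelineOptions opts; // defaults; e.g. opts.callFinish left as-is
populateGPUPipeline(pm, opts);
if (mlir::failed(pm.run(module))) // `module` is the ModuleOp being compiled
  llvm::errs() << "GPU pipeline failed\n";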
98 changes: 98 additions & 0 deletions test/mlir/test/gc/Transforms/GPU/IMEX/gpu-to-gpuocl.mlir
@@ -0,0 +1,98 @@
// RUN: gc-opt %s --gpu-to-gpuocl | FileCheck %s

module @test attributes {gpu.container_module} {
llvm.func @entry(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64, %arg7: !llvm.ptr, %arg8: !llvm.ptr, %arg9: i64) attributes {llvm.emit_c_interface} {
%0 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
%1 = llvm.insertvalue %arg0, %0[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
%2 = llvm.insertvalue %arg1, %1[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
%3 = llvm.insertvalue %arg2, %2[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
%4 = llvm.insertvalue %arg3, %3[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
%5 = llvm.insertvalue %arg4, %4[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
%6 = llvm.insertvalue %arg5, %5[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
%7 = llvm.insertvalue %arg6, %6[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
%8 = builtin.unrealized_conversion_cast %7 : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to memref<64x64xf32>
%gpu_mem = gpu.alloc host_shared () : memref<64x64xf32>
gpu.memcpy %gpu_mem, %8 : memref<64x64xf32>, memref<64x64xf32>
%9 = llvm.mlir.constant(32 : index) : i64
%10 = builtin.unrealized_conversion_cast %9 : i64 to index
%11 = llvm.mlir.constant(2 : index) : i64
%12 = builtin.unrealized_conversion_cast %11 : i64 to index
%13 = llvm.mlir.constant(1 : index) : i64
%14 = builtin.unrealized_conversion_cast %13 : i64 to index
gpu.launch_func @entry_kernel::@entry_kernel blocks in (%12, %12, %14) threads in (%14, %14, %14) args(%10 : index, %gpu_mem : memref<64x64xf32>)
gpu.memcpy %8, %gpu_mem : memref<64x64xf32>, memref<64x64xf32>
gpu.dealloc %gpu_mem : memref<64x64xf32>
llvm.return
}

gpu.module @entry_kernel attributes {gpu.binary = "Some SPIRV here \00"} {
gpu.func @entry_kernel(%arg0: index, %arg1: memref<64x64xf32>) kernel attributes {} {
gpu.return
}
}
}

// CHECK: llvm.mlir.global internal constant @gcGpuOclKernel_entry_kernel_SPIRV
// CHECK: llvm.mlir.global internal constant @gcGpuOclKernel_entry_kernel_Name
// CHECK: llvm.mlir.global internal @gcGpuOclKernel_entry_kernel_Ptr

// CHECK: llvm.func @createGcGpuOclKernel_entry_kernel([[CTX:%.+]]: !llvm.ptr) -> !llvm.ptr
// CHECK: [[NEW_PTR:%.+]] = llvm.call @gcGpuOclKernelCreate([[CTX]]
// CHECK: [[ZERO:%.+]] = llvm.mlir.zero
// CHECK: [[PTR_ADDR:%.+]] = llvm.mlir.addressof @gcGpuOclKernel_entry_kernel_Ptr
// CHECK: [[CMPXCHG:%.+]] = llvm.cmpxchg [[PTR_ADDR]], [[ZERO]], [[NEW_PTR]]
// CHECK: [[FLAG:%.+]] = llvm.extractvalue [[CMPXCHG]][1]
// CHECK: llvm.cond_br [[FLAG]], [[BB1:\^.+]], [[BB2:\^.+]]
// CHECK: [[BB1]]:
// CHECK: llvm.return [[NEW_PTR]]
// CHECK: [[BB2]]:
// CHECK: [[ONE:%.+]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: [[ARRAY:%.+]] = llvm.alloca [[ONE]]
// CHECK: [[ADDR:%.+]] = llvm.getelementptr [[ARRAY]]
// CHECK: llvm.store [[NEW_PTR]], [[ADDR]]
// CHECK: llvm.call @gcGpuOclKernelDestroy([[ONE]], [[ARRAY]])
// CHECK: [[OLD_PTR:%.+]] = llvm.extractvalue [[CMPXCHG]][0]
// CHECK: llvm.return [[OLD_PTR]]

// CHECK: llvm.func internal @getGcGpuOclKernel_entry_kernel([[CTX:%.+]]: !llvm.ptr) -> !llvm.ptr attributes {always_inline}
// CHECK: [[ZERO:%.+]] = llvm.mlir.zero
// CHECK: [[PTR_ADDR:%.+]] = llvm.mlir.addressof @gcGpuOclKernel_entry_kernel_Ptr
// CHECK: [[PTR:%.+]] = llvm.load [[PTR_ADDR]]
// CHECK: [[ICMP:%.+]] = llvm.icmp "eq" [[PTR]], [[ZERO]]
// CHECK: llvm.cond_br [[ICMP]], [[BB1:\^.+]], [[BB2:\^.+]]
// CHECK: [[BB1]]:
// CHECK: [[NEW_PTR:%.+]] = llvm.call @createGcGpuOclKernel_entry_kernel([[CTX]])
// CHECK: llvm.return [[NEW_PTR]]
// CHECK: [[BB2]]:
// CHECK: llvm.return [[PTR]]

// CHECK: llvm.func @entry(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64, [[CTX:%.+]]: !llvm.ptr, %arg8: !llvm.ptr, %arg9: i64)
// CHECK: [[SIZE:%.+]] = llvm.mlir.constant(16384 : i64) : i64
// CHECK: llvm.call @gcGpuOclMallocShared([[CTX]], [[SIZE]])
// CHECK: [[SIZE:%.+]] = llvm.mlir.constant(16384 : i64) : i64
// CHECK: [[SRC:%.+]] = llvm.extractvalue
// CHECK: [[DST:%.+]] = llvm.extractvalue [[GPU_MEMREF:%.+]][1]
// CHECK: llvm.call @gcGpuOclMemcpy([[CTX]], [[SRC]], [[DST]], [[SIZE]])
// CHECK: [[KERNEL:%.+]] = llvm.call @getGcGpuOclKernel_entry_kernel([[CTX:%.+]]) : (!llvm.ptr) -> !llvm.ptr
// CHECK: llvm.call @gcGpuOclKernelLaunch([[CTX]], [[KERNEL]],
// CHECK: [[SIZE:%.+]] = llvm.mlir.constant(16384 : i64) : i64
// CHECK: [[SRC:%.+]] = llvm.extractvalue [[GPU_MEMREF:%.+]][1]
// CHECK: [[DST:%.+]] = llvm.extractvalue
// CHECK: llvm.call @gcGpuOclMemcpy([[CTX]], [[SRC]], [[DST]], [[SIZE]])
// CHECK: [[GPU_PTR:%.+]] = llvm.extractvalue [[GPU_MEMREF:%.+]][0]
// CHECK: llvm.call @gcGpuOclDealloc([[CTX]], [[GPU_PTR]])

// CHECK: llvm.func @gcGpuOclKernelCreate
// CHECK: llvm.func @gcGpuOclKernelDestroy
// CHECK: llvm.func @gcGpuOclKernelLaunch


// CHECK: llvm.func @gcGpuOclModuleDestructor()
// CHECK: llvm.fence acquire
// CHECK: [[PTR_ADDR:%.+]] = llvm.mlir.addressof @gcGpuOclKernel_entry_kernel_Ptr
// CHECK: [[PTR:%.+]] = llvm.load [[PTR_ADDR]]
// CHECK: [[ONE:%.+]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: [[ARRAY:%.+]] = llvm.alloca [[ONE]]
// CHECK: [[ADDR:%.+]] = llvm.getelementptr [[ARRAY]]
// CHECK: llvm.store [[PTR]], [[ADDR]]
// CHECK: llvm.call @gcGpuOclKernelDestroy([[ONE]], [[ARRAY]])
102 changes: 102 additions & 0 deletions test/mlir/test/gc/Transforms/GPU/OCL/gpu-to-gpuocl-inlined.mlir
@@ -0,0 +1,102 @@
// RUN: gc-opt %s --gc-gpu-pipeline | FileCheck %s

module @test attributes {gpu.container_module} {
llvm.func @entry(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64, %arg7: !llvm.ptr, %arg8: !llvm.ptr, %arg9: i64) attributes {llvm.emit_c_interface} {
%0 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
%1 = llvm.insertvalue %arg0, %0[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
%2 = llvm.insertvalue %arg1, %1[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
%3 = llvm.insertvalue %arg2, %2[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
%4 = llvm.insertvalue %arg3, %3[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
%5 = llvm.insertvalue %arg4, %4[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
%6 = llvm.insertvalue %arg5, %5[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
%7 = llvm.insertvalue %arg6, %6[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
%8 = builtin.unrealized_conversion_cast %7 : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to memref<64x64xf32>
%gpu_mem = gpu.alloc host_shared () : memref<64x64xf32>
gpu.memcpy %gpu_mem, %8 : memref<64x64xf32>, memref<64x64xf32>
%9 = llvm.mlir.constant(32 : index) : i64
%10 = builtin.unrealized_conversion_cast %9 : i64 to index
%11 = llvm.mlir.constant(2 : index) : i64
%12 = builtin.unrealized_conversion_cast %11 : i64 to index
%13 = llvm.mlir.constant(1 : index) : i64
%14 = builtin.unrealized_conversion_cast %13 : i64 to index

%floaat = llvm.mlir.constant(1.1 : f32) : f32
%a_ptr_as_idx = memref.extract_aligned_pointer_as_index %gpu_mem : memref<64x64xf32> -> index
%a_ptr_as_i64 = arith.index_cast %a_ptr_as_idx : index to i64
%a_ptr = llvm.inttoptr %a_ptr_as_i64 : i64 to !llvm.ptr
%a_ptr_casted = llvm.addrspacecast %a_ptr : !llvm.ptr to !llvm.ptr<1>

gpu.launch blocks(%arg10, %arg11, %arg12) in (%arg16 = %12, %arg17 = %12, %arg18 = %12) threads(%arg13, %arg14, %arg15) in (%arg19 = %14, %arg20 = %14, %arg21 = %14) {
llvm.store %floaat, %a_ptr_casted : f32, !llvm.ptr<1>
gpu.terminator
}
gpu.memcpy %8, %gpu_mem : memref<64x64xf32>, memref<64x64xf32>
gpu.dealloc %gpu_mem : memref<64x64xf32>
llvm.return
}
}

// CHECK: llvm.mlir.global internal constant @gcGpuOclKernel_entry_kernel_SPIRV
// CHECK: llvm.mlir.global internal constant @gcGpuOclKernel_entry_kernel_Name
// CHECK: llvm.mlir.global internal @gcGpuOclKernel_entry_kernel_Ptr

// CHECK: llvm.func @createGcGpuOclKernel_entry_kernel([[CTX:%.+]]: !llvm.ptr) -> !llvm.ptr
// CHECK: [[NEW_PTR:%.+]] = llvm.call @gcGpuOclKernelCreate([[CTX]]
// CHECK: [[ZERO:%.+]] = llvm.mlir.zero
// CHECK: [[PTR_ADDR:%.+]] = llvm.mlir.addressof @gcGpuOclKernel_entry_kernel_Ptr
// CHECK: [[CMPXCHG:%.+]] = llvm.cmpxchg [[PTR_ADDR]], [[ZERO]], [[NEW_PTR]]
// CHECK: [[FLAG:%.+]] = llvm.extractvalue [[CMPXCHG]][1]
// CHECK: llvm.cond_br [[FLAG]], [[BB1:\^.+]], [[BB2:\^.+]]
// CHECK: [[BB1]]:
// CHECK: llvm.return [[NEW_PTR]]
// CHECK: [[BB2]]:
// CHECK: [[ONE:%.+]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: [[ARRAY:%.+]] = llvm.alloca [[ONE]]
// CHECK: [[ADDR:%.+]] = llvm.getelementptr [[ARRAY]]
// CHECK: llvm.store [[NEW_PTR]], [[ADDR]]
// CHECK: llvm.call @gcGpuOclKernelDestroy([[ONE]], [[ARRAY]])
// CHECK: [[OLD_PTR:%.+]] = llvm.extractvalue [[CMPXCHG]][0]
// CHECK: llvm.return [[OLD_PTR]]

// CHECK: llvm.func internal @getGcGpuOclKernel_entry_kernel([[CTX:%.+]]: !llvm.ptr) -> !llvm.ptr attributes {always_inline}
// CHECK: [[ZERO:%.+]] = llvm.mlir.zero
// CHECK: [[PTR_ADDR:%.+]] = llvm.mlir.addressof @gcGpuOclKernel_entry_kernel_Ptr
// CHECK: [[PTR:%.+]] = llvm.load [[PTR_ADDR]]
// CHECK: [[ICMP:%.+]] = llvm.icmp "eq" [[PTR]], [[ZERO]]
// CHECK: llvm.cond_br [[ICMP]], [[BB1:\^.+]], [[BB2:\^.+]]
// CHECK: [[BB1]]:
// CHECK: [[NEW_PTR:%.+]] = llvm.call @createGcGpuOclKernel_entry_kernel([[CTX]])
// CHECK: llvm.return [[NEW_PTR]]
// CHECK: [[BB2]]:
// CHECK: llvm.return [[PTR]]

// CHECK: llvm.func @entry(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64, [[CTX:%.+]]: !llvm.ptr, %arg8: !llvm.ptr, %arg9: i64)
// CHECK: [[SIZE:%.+]] = llvm.mlir.constant(16384 : i64) : i64
// CHECK: llvm.call @gcGpuOclMallocShared([[CTX]], [[SIZE]])
// CHECK: [[SIZE:%.+]] = llvm.mlir.constant(16384 : i64) : i64
// CHECK: [[SRC:%.+]] = llvm.extractvalue
// CHECK: [[DST:%.+]] = llvm.extractvalue [[GPU_MEMREF:%.+]][1]
// CHECK: llvm.call @gcGpuOclMemcpy([[CTX]], [[SRC]], [[DST]], [[SIZE]])
// CHECK: [[KERNEL:%.+]] = llvm.call @getGcGpuOclKernel_entry_kernel([[CTX:%.+]]) : (!llvm.ptr) -> !llvm.ptr
// CHECK: llvm.call @gcGpuOclKernelLaunch([[CTX]], [[KERNEL]],
// CHECK: [[SIZE:%.+]] = llvm.mlir.constant(16384 : i64) : i64
// CHECK: [[SRC:%.+]] = llvm.extractvalue [[GPU_MEMREF:%.+]][1]
// CHECK: [[DST:%.+]] = llvm.extractvalue
// CHECK: llvm.call @gcGpuOclMemcpy([[CTX]], [[SRC]], [[DST]], [[SIZE]])
// CHECK: [[GPU_PTR:%.+]] = llvm.extractvalue [[GPU_MEMREF:%.+]][0]
// CHECK: llvm.call @gcGpuOclDealloc([[CTX]], [[GPU_PTR]])

// CHECK: llvm.func @gcGpuOclKernelCreate
// CHECK: llvm.func @gcGpuOclKernelDestroy
// CHECK: llvm.func @gcGpuOclKernelLaunch


// CHECK: llvm.func @gcGpuOclModuleDestructor()
// CHECK: llvm.fence acquire
// CHECK: [[PTR_ADDR:%.+]] = llvm.mlir.addressof @gcGpuOclKernel_entry_kernel_Ptr
// CHECK: [[PTR:%.+]] = llvm.load [[PTR_ADDR]]
// CHECK: [[ONE:%.+]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: [[ARRAY:%.+]] = llvm.alloca [[ONE]]
// CHECK: [[ADDR:%.+]] = llvm.getelementptr [[ARRAY]]
// CHECK: llvm.store [[PTR]], [[ADDR]]
// CHECK: llvm.call @gcGpuOclKernelDestroy([[ONE]], [[ARRAY]])