diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h index 25fbfc37691182..961db45d500309 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h +++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h @@ -1246,26 +1246,23 @@ class OpenMPIRBuilder { getTargetEntryUniqueInfo(FileIdentifierInfoCallbackTy CallBack, StringRef ParentName = ""); - // using ReductionGenTy = - // function_ref; - - // using AtomicReductionGenTy = - // function_ref; - /// Owning equivalents of OpenMPIRBuilder::(Atomic)ReductionGen that are used /// to /// store lambdas with capture. /// Functions used to generate reductions. Such functions take two Values /// representing LHS and RHS of the reduction, respectively, and a reference /// to the value that is updated to refer to the reduction result. - using ReductionGenTy = std::function; + using ReductionGenTy = + function_ref; + /// Functions used to generate atomic reductions. Such functions take two /// Values representing pointers to LHS and RHS of the reduction, as well as /// the element type of these pointers. They are expected to atomically /// update the LHS to the reduced value. - using AtomicReductionGenTy = std::function; + using AtomicReductionGenTy = + function_ref; + + /// Information about an OpenMP reduction. struct ReductionInfo { @@ -1275,10 +1272,6 @@ class OpenMPIRBuilder { : ElementType(ElementType), Variable(Variable), PrivateVariable(PrivateVariable), ReductionGen(ReductionGen), AtomicReductionGen(AtomicReductionGen) {} - ReductionInfo(Value *PrivateVariable) - : ElementType(nullptr), Variable(nullptr), - PrivateVariable(PrivateVariable), ReductionGen(), - AtomicReductionGen() {} /// Reduction element type, must match pointee type of variable. 
Type *ElementType; @@ -1301,56 +1294,6 @@ class OpenMPIRBuilder { AtomicReductionGenTy AtomicReductionGen; }; - /// A class that manages the reduction info to facilitate lowering of - /// reductions at multiple levels of parallelism. For example handling teams - /// and parallel reductions on GPUs - - class ReductionInfoManager { - private: - SmallVector ReductionInfos; - std::optional PrivateVarAllocaIP; - - public: - ReductionInfoManager() {}; - void clear() { - ReductionInfos.clear(); - PrivateVarAllocaIP.reset(); - } - - Value *allocatePrivateReductionVar( - IRBuilderBase &builder, - llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, - Type *VarType) { - llvm::Type *ptrTy = llvm::PointerType::getUnqual(builder.getContext()); - llvm::Value *var = builder.CreateAlloca(VarType); - var->setName("private_redvar"); - llvm::Value *castVar = - builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy); - ReductionInfos.push_back(ReductionInfo(castVar)); - return castVar; - } - - ReductionInfo getReductionInfo(unsigned Index) { - return ReductionInfos[Index]; - } - ReductionInfo setReductionInfo(unsigned Index, ReductionInfo &RI) { - return ReductionInfos[Index] = RI; - } - Value *getPrivateReductionVariable(unsigned Index) { - return ReductionInfos[Index].PrivateVariable; - } - SmallVector &getReductionInfos() { - return ReductionInfos; - } - - bool hasPrivateVarAllocaIP() { return PrivateVarAllocaIP.has_value(); } - InsertPointTy getPrivateVarAllocaIP() { - assert(PrivateVarAllocaIP.has_value() && "AllocaIP not set"); - return *PrivateVarAllocaIP; - } - void setPrivateVarAllocaIP(InsertPointTy IP) { PrivateVarAllocaIP = IP; } - }; - /// \param Loc The location where the reduction was /// encountered. Must be within the associate /// directive and after the last local access to the @@ -1573,9 +1516,6 @@ class OpenMPIRBuilder { /// Info manager to keep track of target regions. 
OffloadEntriesInfoManager OffloadInfoManager; - /// Info manager to keep track of reduction information; - ReductionInfoManager RIManager; - /// The target triple of the underlying module. const Triple T; diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 1dae7d4536ffb0..2db5a03c73399a 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -440,18 +440,27 @@ static LogicalResult inlineConvertOmpRegions( } namespace { +/// Owning equivalents of OpenMPIRBuilder::(Atomic)ReductionGen that are used to +/// store lambdas with capture. +using OwningReductionGen = std::function; +using OwningAtomicReductionGen = + std::function; } // namespace /// Create an OpenMPIRBuilder-compatible reduction generator for the given /// reduction declaration. The generator uses `builder` but ignores its /// insertion point. -static llvm::OpenMPIRBuilder::ReductionGenTy +static OwningReductionGen makeReductionGen(omp::ReductionDeclareOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation) { // The lambda is mutable because we need access to non-const methods of decl // (which aren't actually mutating it), and we must capture decl by-value to // avoid the dangling reference after the parent function returns. - llvm::OpenMPIRBuilder::ReductionGenTy gen = + OwningReductionGen gen = [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint, llvm::Value *lhs, llvm::Value *rhs, llvm::Value *&result) mutable { @@ -475,17 +484,17 @@ makeReductionGen(omp::ReductionDeclareOp decl, llvm::IRBuilderBase &builder, /// given reduction declaration. The generator uses `builder` but ignores its /// insertion point. Returns null if there is no atomic region available in the /// reduction declaration. 
-static llvm::OpenMPIRBuilder::AtomicReductionGenTy +static OwningAtomicReductionGen makeAtomicReductionGen(omp::ReductionDeclareOp decl, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation) { if (decl.getAtomicReductionRegion().empty()) - return llvm::OpenMPIRBuilder::AtomicReductionGenTy(); + return OwningAtomicReductionGen(); // The lambda is mutable because we need access to non-const methods of decl // (which aren't actually mutating it), and we must capture decl by-value to // avoid the dangling reference after the parent function returns. - llvm::OpenMPIRBuilder::AtomicReductionGenTy atomicGen = + OwningAtomicReductionGen atomicGen = [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint, llvm::Type *, llvm::Value *lhs, llvm::Value *rhs) mutable { Region &atomicRegion = decl.getAtomicReductionRegion(); @@ -774,48 +783,62 @@ convertOmpTaskgroupOp(omp::TaskGroupOp tgOp, llvm::IRBuilderBase &builder, template static void allocReductionVars(T loop, llvm::IRBuilderBase &builder, - llvm::OpenMPIRBuilder &ompBuilder, LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVector &reductionDecls, + SmallVector &privateReductionVariables, DenseMap &reductionVariableMap) { llvm::IRBuilderBase::InsertPointGuard guard(builder); - if (!ompBuilder.RIManager.hasPrivateVarAllocaIP()) - ompBuilder.RIManager.setPrivateVarAllocaIP(allocaIP); - builder.restoreIP(ompBuilder.RIManager.getPrivateVarAllocaIP()); + builder.restoreIP(allocaIP); + auto args = + loop.getRegion().getArguments().take_back(loop.getNumReductionVars()); - unsigned numReductions = loop.getNumReductionVars(); - auto args = loop.getRegion().getArguments().take_back(numReductions); - for (unsigned i = 0; i < numReductions; ++i) { - llvm::Value *var = ompBuilder.RIManager.allocatePrivateReductionVar( - builder, allocaIP, + for (std::size_t i = 0; i < loop.getNumReductionVars(); ++i) { + llvm::Value *var = builder.CreateAlloca( 
moduleTranslation.convertType(reductionDecls[i].getType())); - moduleTranslation.mapValue(args[i], var); - reductionVariableMap.try_emplace(loop.getReductionVars()[i], var); + + var->setName("private_redvar"); + llvm::Type *ptrTy = llvm::PointerType::getUnqual(builder.getContext()); + llvm::Value *castVar = + builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy); + + moduleTranslation.mapValue(args[i], castVar); + privateReductionVariables.push_back(castVar); + reductionVariableMap.try_emplace(loop.getReductionVars()[i], castVar); } } /// Collect reduction info template -static void -collectReductionInfo(T &loop, llvm::IRBuilderBase &builder, - llvm::OpenMPIRBuilder &ompBuilder, - LLVM::ModuleTranslation &moduleTranslation, - SmallVector &reductionDecls) { +static void collectReductionInfo( + T loop, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation, + SmallVector &reductionDecls, + SmallVector &owningReductionGens, + SmallVector &owningAtomicReductionGens, + const SmallVector &privateReductionVariables, + SmallVector &reductionInfos) { unsigned numReductions = loop.getNumReductionVars(); for (unsigned i = 0; i < numReductions; ++i) { + owningReductionGens.push_back( + makeReductionGen(reductionDecls[i], builder, moduleTranslation)); + owningAtomicReductionGens.push_back( + makeAtomicReductionGen(reductionDecls[i], builder, moduleTranslation)); + } + + // Collect the reduction information. 
+ reductionInfos.reserve(numReductions); + + for (unsigned i = 0; i < numReductions; ++i) { + llvm::OpenMPIRBuilder::AtomicReductionGenTy atomicGen = nullptr; + if (owningAtomicReductionGens[i]) + atomicGen = owningAtomicReductionGens[i]; + llvm::Value *variable = moduleTranslation.lookupValue(loop.getReductionVars()[i]); - llvm::OpenMPIRBuilder::ReductionInfo RI = - ompBuilder.RIManager.getReductionInfo(i); - RI.Variable = variable; - RI.ElementType = - moduleTranslation.convertType(reductionDecls[i].getType()); - RI.ReductionGen = - makeReductionGen(reductionDecls[i], builder, moduleTranslation); - RI.AtomicReductionGen = - makeAtomicReductionGen(reductionDecls[i], builder, moduleTranslation); - ompBuilder.RIManager.setReductionInfo(i, RI); + reductionInfos.push_back( + {moduleTranslation.convertType(reductionDecls[i].getType()), variable, + privateReductionVariables[i], owningReductionGens[i], atomicGen}); } } @@ -864,9 +887,13 @@ static void getSinkableAllocas(LLVM::ModuleTranslation &moduleTranslation, } /// Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder. -static LogicalResult -convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder, - LLVM::ModuleTranslation &moduleTranslation) { +static LogicalResult convertOmpWsLoop( + Operation &opInst, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation, + llvm::OpenMPIRBuilder::InsertPointTy redAllocaIP, + SmallVector &owningReductionGens, + SmallVector &owningAtomicReductionGens, + SmallVector &reductionInfos) { llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); auto loop = cast(opInst); // TODO: this should be in the op verifier instead. 
@@ -888,12 +915,12 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder, } SmallVector reductionDecls; collectReductionDecls(loop, reductionDecls); - llvm::OpenMPIRBuilder::InsertPointTy allocaIP = - findAllocaInsertPoint(builder, moduleTranslation); + SmallVector privateReductionVariables; DenseMap reductionVariableMap; - allocReductionVars(loop, builder, *ompBuilder, moduleTranslation, allocaIP, - reductionDecls, reductionVariableMap); + allocReductionVars(loop, builder, moduleTranslation, redAllocaIP, + reductionDecls, privateReductionVariables, + reductionVariableMap); // Store the mapping between reduction variables and their private copies on // ModuleTranslation stack. It can be then recovered when translating @@ -912,8 +939,7 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder, return failure(); assert(phis.size() == 1 && "expected one value to be yielded from the " "reduction neutral element declaration region"); - builder.CreateStore(phis[0], - ompBuilder->RIManager.getPrivateReductionVariable(i)); + builder.CreateStore(phis[0], privateReductionVariables[i]); } // Set up the source location value for OpenMP runtime. @@ -992,7 +1018,8 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder, llvm::CanonicalLoopInfo *loopInfo = ompBuilder->collapseLoops(ompLoc.DL, loopInfos, {}); - allocaIP = findAllocaInsertPoint(builder, moduleTranslation); + llvm::OpenMPIRBuilder::InsertPointTy allocaIP = + findAllocaInsertPoint(builder, moduleTranslation); // TODO: Handle doacross loops when the ordered clause has a parameter. bool isOrdered = loop.getOrderedVal().has_value(); @@ -1029,8 +1056,10 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder, // Create the reduction generators. We need to own them here because // ReductionInfo only accepts references to the generators. 
- collectReductionInfo(loop, builder, *ompBuilder, moduleTranslation, - reductionDecls); + collectReductionInfo(loop, builder, moduleTranslation, reductionDecls, + owningReductionGens, owningAtomicReductionGens, + privateReductionVariables, reductionInfos); + // The call to createReductions below expects the block to have a // terminator. Create an unreachable instruction to serve as terminator // and remove it later. @@ -1038,8 +1067,7 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder, builder.SetInsertPoint(tempTerminator); llvm::OpenMPIRBuilder::InsertPointTy contInsertPoint = - ompBuilder->createReductions(builder.saveIP(), allocaIP, - ompBuilder->RIManager.getReductionInfos(), + ompBuilder->createReductions(builder.saveIP(), allocaIP, reductionInfos, loop.getNowait(), /*IsTeamsReduction*/ false, /*HasDistribute*/ distributeCodeGen); if (!contInsertPoint.getBlock()) @@ -1049,12 +1077,24 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder, tempTerminator->eraseFromParent(); builder.restoreIP(nextInsertionPoint); - if (!ompBuilder->Config.isGPU()) - ompBuilder->RIManager.clear(); - return success(); } +static LogicalResult +convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) { + llvm::OpenMPIRBuilder::InsertPointTy redAllocaIP = + findAllocaInsertPoint(builder, moduleTranslation); + SmallVector owningReductionGens; + SmallVector owningAtomicReductionGens; + SmallVector reductionInfos; + + return convertOmpWsLoop(opInst, builder, moduleTranslation, redAllocaIP, + owningReductionGens, owningAtomicReductionGens, + reductionInfos); +} + + /// Converts the OpenMP parallel operation to LLVM IR. 
static LogicalResult convertOmpParallel(Operation &opInst1, llvm::IRBuilderBase &builder, @@ -1072,9 +1112,11 @@ convertOmpParallel(Operation &opInst1, llvm::IRBuilderBase &builder, collectReductionDecls(opInst, reductionDecls); // Allocate reduction vars + SmallVector privateReductionVariables; DenseMap reductionVariableMap; - allocReductionVars(opInst, builder, *ompBuilder, moduleTranslation, - allocaIP, reductionDecls, reductionVariableMap); + allocReductionVars(opInst, builder, moduleTranslation, allocaIP, + reductionDecls, privateReductionVariables, + reductionVariableMap); // Store the mapping between reduction variables and their private copies on // ModuleTranslation stack. It can be then recovered when translating @@ -1094,8 +1136,7 @@ convertOmpParallel(Operation &opInst1, llvm::IRBuilderBase &builder, "expected one value to be yielded from the " "reduction neutral element declaration region"); builder.restoreIP(allocaIP); - builder.CreateStore(phis[0], - ompBuilder->RIManager.getPrivateReductionVariable(i)); + builder.CreateStore(phis[0], privateReductionVariables[i]); } // Save the alloca insertion point on ModuleTranslation stack for use in @@ -1112,8 +1153,12 @@ convertOmpParallel(Operation &opInst1, llvm::IRBuilderBase &builder, // Process the reductions if required. 
if (opInst.getNumReductionVars() > 0) { // Collect reduction info - collectReductionInfo(opInst, builder, *ompBuilder, moduleTranslation, - reductionDecls); + SmallVector owningReductionGens; + SmallVector owningAtomicReductionGens; + SmallVector reductionInfos; + collectReductionInfo(opInst, builder, moduleTranslation, reductionDecls, + owningReductionGens, owningAtomicReductionGens, + privateReductionVariables, reductionInfos); // Move to region cont block builder.SetInsertPoint(regionBlock->getTerminator()); @@ -1122,9 +1167,8 @@ convertOmpParallel(Operation &opInst1, llvm::IRBuilderBase &builder, llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable(); builder.SetInsertPoint(tempTerminator); llvm::OpenMPIRBuilder::InsertPointTy contInsertPoint = - ompBuilder->createReductions( - builder.saveIP(), allocaIP, - ompBuilder->RIManager.getReductionInfos(), false, false, false); + ompBuilder->createReductions(builder.saveIP(), allocaIP, + reductionInfos, false, false, false); if (!contInsertPoint.getBlock()) { bodyGenStatus = opInst->emitOpError() << "failed to convert reductions"; return; @@ -1167,9 +1211,6 @@ convertOmpParallel(Operation &opInst1, llvm::IRBuilderBase &builder, ompBuilder->createParallel(ompLoc, allocaIP, bodyGenCB, privCB, finiCB, ifCond, numThreads, pbKind, isCancellable)); - if (!ompBuilder->Config.isGPU()) - ompBuilder->RIManager.clear(); - return bodyGenStatus; } @@ -2307,9 +2348,11 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, return bodyGenStatus; } -static LogicalResult -convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder, - LLVM::ModuleTranslation &moduleTranslation) { +static LogicalResult convertOmpDistribute( + Operation &opInst, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation, + llvm::OpenMPIRBuilder::InsertPointTy *redAllocaIP, + SmallVector &reductionInfos) { llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); using InsertPointTy = 
llvm::OpenMPIRBuilder::InsertPointTy; @@ -2325,11 +2368,11 @@ convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder, // DistributeOp has only one region associated with it. builder.restoreIP(codeGenIP); - ompBuilder->RIManager.setPrivateVarAllocaIP(allocaIP); + *redAllocaIP = allocaIP; + mlir::Region& reg = opInst.getRegion(0); auto regionBlock = - convertOmpOpRegions(opInst.getRegion(0), "omp.distribute.region", + convertOmpOpRegions(reg, "omp.distribute.region", builder, moduleTranslation, bodyGenStatus); - builder.SetInsertPoint(regionBlock->getTerminator()); // FIXME(JAN): We need to know if we are inside a distribute and @@ -2341,9 +2384,8 @@ convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder, auto IP = builder.saveIP(); if (ompBuilder->Config.isGPU()) { llvm::OpenMPIRBuilder::InsertPointTy contInsertPoint = - ompBuilder->createReductions( - IP, allocaIP, ompBuilder->RIManager.getReductionInfos(), false, - true, true); + ompBuilder->createReductions(IP, allocaIP, reductionInfos, false, + true, true); builder.restoreIP(contInsertPoint); } }; @@ -2352,10 +2394,19 @@ convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder, findAllocaInsertPoint(builder, moduleTranslation); llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder); builder.restoreIP(ompBuilder->createDistribute(ompLoc, allocaIP, bodyGenCB)); - return success(); } +static LogicalResult +convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) { + // No reductions are present so we just create dummy variables. 
+ llvm::OpenMPIRBuilder::InsertPointTy dummyRedAllocaIP; + SmallVector dummyReductionInfos; + return convertOmpDistribute(opInst, builder, moduleTranslation, + &dummyRedAllocaIP, dummyReductionInfos); +} + /// Lowers the FlagsAttr which is applied to the module on the device /// pass when offloading, this attribute contains OpenMP RTL globals that can /// be passed as flags to the frontend, otherwise they are set to default @@ -2943,8 +2994,6 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, if (isTargetDevice) handleDeclareTargetMapVar(mapData, moduleTranslation, builder); - // Clear any reduction information - ompBuilder->RIManager.clear(); return bodyGenStatus; } @@ -3039,126 +3088,53 @@ convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute, return success(); } -namespace { +/////////////////////////////////////////////////////////////////////////////// +// CombinedConstructs lowering forward declarations -/// Implementation of the dialect interface that converts operations belonging -/// to the OpenMP dialect to LLVM IR. -class OpenMPDialectLLVMIRTranslationInterface - : public LLVMTranslationDialectInterface { -public: - using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface; +class OpenMPDialectLLVMIRTranslationInterface; - /// Translates the given operation to LLVM IR using the provided IR builder - /// and saving the state in `moduleTranslation`. 
- LogicalResult +using ConvertFunctionTy = std::function( + Operation *, llvm::IRBuilderBase &, LLVM::ModuleTranslation &)>; + +class ConversionDispatchList { +private: + llvm::SmallVector functions; + +public: + std::pair convertOperation(Operation *op, llvm::IRBuilderBase &builder, - LLVM::ModuleTranslation &moduleTranslation) const final; + LLVM::ModuleTranslation &moduleTranslation) { + for (auto riter = functions.rbegin(); riter != functions.rend(); ++riter) { + bool match = false; + LogicalResult result = failure(); + std::tie(match, result) = (*riter)(op, builder, moduleTranslation); + if (match) + return { true, result }; + } + return {false, failure()}; + } - /// Given an OpenMP MLIR attribute, create the corresponding LLVM-IR, runtime - /// calls, or operation amendments - LogicalResult - amendOperation(Operation *op, ArrayRef instructions, - NamedAttribute attribute, - LLVM::ModuleTranslation &moduleTranslation) const final; + void pushConversionFunction(ConvertFunctionTy function) { + functions.push_back(function); + } + void popConversionFunction() { + functions.pop_back(); + } }; -} // namespace -LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation( - Operation *op, ArrayRef instructions, - NamedAttribute attribute, - LLVM::ModuleTranslation &moduleTranslation) const { - return llvm::StringSwitch>( - attribute.getName()) - .Case("omp.is_target_device", - [&](Attribute attr) { - if (auto deviceAttr = attr.dyn_cast()) { - llvm::OpenMPIRBuilderConfig &config = - moduleTranslation.getOpenMPBuilder()->Config; - config.setIsTargetDevice(deviceAttr.getValue()); - return success(); - } - return failure(); - }) - .Case("omp.is_gpu", - [&](Attribute attr) { - if (auto gpuAttr = attr.dyn_cast()) { - llvm::OpenMPIRBuilderConfig &config = - moduleTranslation.getOpenMPBuilder()->Config; - config.setIsGPU(gpuAttr.getValue()); - return success(); - } - return failure(); - }) - .Case("omp.host_ir_filepath", - [&](Attribute attr) { - if (auto 
filepathAttr = attr.dyn_cast()) { - llvm::OpenMPIRBuilder *ompBuilder = - moduleTranslation.getOpenMPBuilder(); - ompBuilder->loadOffloadInfoMetadata(filepathAttr.getValue()); - return success(); - } - return failure(); - }) - .Case("omp.flags", - [&](Attribute attr) { - if (auto rtlAttr = attr.dyn_cast()) - return convertFlagsAttr(op, rtlAttr, moduleTranslation); - return failure(); - }) - .Case("omp.version", - [&](Attribute attr) { - if (auto versionAttr = attr.dyn_cast()) { - llvm::OpenMPIRBuilder *ompBuilder = - moduleTranslation.getOpenMPBuilder(); - ompBuilder->M.addModuleFlag(llvm::Module::Max, "openmp", - versionAttr.getVersion()); - return success(); - } - return failure(); - }) - .Case("omp.declare_target", - [&](Attribute attr) { - if (auto declareTargetAttr = - attr.dyn_cast()) - return convertDeclareTargetAttr(op, declareTargetAttr, - moduleTranslation); - return failure(); - }) - .Case( - "omp.requires", - [&](Attribute attr) { - if (auto requiresAttr = attr.dyn_cast()) { - using Requires = omp::ClauseRequires; - Requires flags = requiresAttr.getValue(); - llvm::OpenMPIRBuilderConfig &config = - moduleTranslation.getOpenMPBuilder()->Config; - config.setHasRequiresReverseOffload( - bitEnumContainsAll(flags, Requires::reverse_offload)); - config.setHasRequiresUnifiedAddress( - bitEnumContainsAll(flags, Requires::unified_address)); - config.setHasRequiresUnifiedSharedMemory( - bitEnumContainsAll(flags, Requires::unified_shared_memory)); - config.setHasRequiresDynamicAllocators( - bitEnumContainsAll(flags, Requires::dynamic_allocators)); - return success(); - } - return failure(); - }) - .Default([](Attribute) { - // Fall through for omp attributes that do not require lowering. 
- return success(); - })(attribute.getValue()); - - return failure(); -} +static LogicalResult convertOmpDistributeParallelWsLoop( + Operation *op, + omp::DistributeOp distribute, omp::ParallelOp parallel, + omp::WsLoopOp wsloop, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation, + ConversionDispatchList &dispatchList); -/// Given an OpenMP MLIR operation, create the corresponding LLVM IR -/// (including OpenMP runtime calls). -LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation( +/////////////////////////////////////////////////////////////////////////////// +// Dispatch functions +static LogicalResult convertCommonOperation( Operation *op, llvm::IRBuilderBase &builder, - LLVM::ModuleTranslation &moduleTranslation) const { - + LLVM::ModuleTranslation &moduleTranslation) { llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); return llvm::TypeSwitch(op) @@ -3251,7 +3227,7 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation( }) .Case([&](omp::ThreadprivateOp) { return convertOmpThreadprivate(*op, builder, moduleTranslation); - }) + }) .Case( [&](auto op) { return convertOmpTargetData(op, builder, moduleTranslation); @@ -3274,6 +3250,309 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation( }); } +// Returns true if the given block has a single instruction. +static bool singleInstrBlock(Block &block) { + bool result = (block.getOperations().size() == 2); + if (!result) { + llvm::errs() << "Num ops: " << block.getOperations().size() << "\n"; + } + return result; +} + +// Returns the operation if it only contains one instruction otherwise +// return nullptr. 
+template <typename OpType>
+Operation *getContainedInstr(OpType op) {
+  Region &region = op.getRegion();
+  if (!region.hasOneBlock()) {
+    llvm::errs() << "Region has multiple blocks\n";
+    return nullptr;
+  }
+  Block &block = region.front();
+  if (!singleInstrBlock(block)) {
+    return nullptr;
+  }
+  return &(block.getOperations().front());
+}
+
+// Returns the first block of the region attached to `op`. `front()` is called
+// without an emptiness check, so the region must contain at least one block.
+template <typename OpType>
+Block &getContainedBlock(OpType op) {
+  Region &region = op.getRegion();
+  return region.front();
+}
+
+// Fallback overloads, selected only when the list of operation types still to
+// be matched is empty: an empty list always matches.
+template <typename... OpTypes>
+bool matchOpNestScan(Block &op, OpTypes... matchOp) {
+  return true;
+}
+
+template <typename... OpTypes>
+bool matchOpNest(Operation *op, OpTypes... matchOp) {
+  return true;
+}
+
+// Scans `block` for the first operation of type FirstOpType. When one is
+// found it is stored in `firstOp` and the remaining types are matched, in
+// order, against the first block of that operation's region. Only the first
+// operation of the requested type in each block is considered. Returns true
+// when the whole chain was found; handles assigned before a failed match are
+// left as they were set.
+//
+// This overload must be declared before the recursive matchOpNest below: it
+// lives outside the namespaces of its argument types, so it is not found by
+// argument-dependent lookup at instantiation time.
+template <typename FirstOpType, typename... RestOpTypes>
+bool matchOpNestScan(Block &block, FirstOpType &firstOp,
+                     RestOpTypes &...restOps) {
+  for (Operation &candidate : block) {
+    if (auto matched = mlir::dyn_cast<FirstOpType>(&candidate)) {
+      firstOp = matched;
+      if constexpr (sizeof...(RestOpTypes) == 0) {
+        return true;
+      } else {
+        Block &innerBlock = getContainedBlock(matched);
+        return matchOpNestScan(innerBlock, restOps...);
+      }
+    }
+  }
+  return false;
+}
+
+// Returns true if `op` is a FirstOpType whose region (first block) contains,
+// nested in the given order, one operation of each of the remaining types.
+// On success every handle passed in refers to the operation that matched.
+template <typename FirstOpType, typename... RestOpTypes>
+bool matchOpNest(Operation *op, FirstOpType &firstOp, RestOpTypes &...restOps) {
+  if (auto matched = mlir::dyn_cast<FirstOpType>(op)) {
+    firstOp = matched;
+    if constexpr (sizeof...(RestOpTypes) == 0) {
+      return true;
+    } else {
+      Block &innerBlock = getContainedBlock(matched);
+      return matchOpNestScan(innerBlock, restOps...);
+    }
+  }
+  return false;
+}
+
+// Converts an operation nested inside a target-device region. A
+// distribute -> parallel -> wsloop nest is handed to the combined lowering;
+// every other operation uses the common dispatch.
+static LogicalResult
+convertInternalTargetOp(Operation *op, llvm::IRBuilderBase &builder,
+                        LLVM::ModuleTranslation &moduleTranslation,
+                        ConversionDispatchList &dispatchList) {
+
+  omp::DistributeOp distribute;
+  omp::ParallelOp parallel;
+  omp::WsLoopOp wsloop;
+  // Match composite constructs
+  if (matchOpNest(op, distribute, parallel, wsloop)) {
+    return convertOmpDistributeParallelWsLoop(op, distribute, parallel, wsloop,
+                                              builder, moduleTranslation,
+                                              dispatchList);
+  }
+
+  return convertCommonOperation(op, builder, moduleTranslation);
+}
+
+// Converts `op` if it is itself an omp.target; otherwise walks `op` and
+// converts each omp.target found, stopping at the first conversion failure.
+// The result is failure exactly when the walk was interrupted.
+static LogicalResult
+convertTopLevelTargetOp(Operation *op, llvm::IRBuilderBase &builder,
+                        LLVM::ModuleTranslation &moduleTranslation) {
+  if (isa<omp::TargetOp>(op))
+    return convertOmpTarget(*op, builder, moduleTranslation);
+  bool interrupted =
+      op->walk([&](omp::TargetOp targetOp) {
+          if (failed(convertOmpTarget(*targetOp, builder, moduleTranslation)))
+            return WalkResult::interrupt();
+          return WalkResult::skip();
+        }).wasInterrupted();
+  return failure(interrupted);
+}
+
+/// Implementation of the dialect interface that converts operations belonging
+/// to the OpenMP dialect to LLVM IR.
+class OpenMPDialectLLVMIRTranslationInterface
+    : public LLVMTranslationDialectInterface {
+private:
+  mutable ConversionDispatchList dispatchList;
+
+public:
+  using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
+
+  /// Translates the given operation to LLVM IR using the provided IR builder
+  /// and saving the state in `moduleTranslation`.
+ LogicalResult + convertOperation(Operation *op, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) const final; + + /// Given an OpenMP MLIR attribute, create the corresponding LLVM-IR, runtime + /// calls, or operation amendments + LogicalResult + amendOperation(Operation *op, ArrayRef instructions, + NamedAttribute attribute, + LLVM::ModuleTranslation &moduleTranslation) const final; +}; + +// Implementation converting a nest of operations in a single function. This +// just overrides the parallel and wsloop dispatches but does the normal +// lowering for now. +static LogicalResult convertOmpDistributeParallelWsLoop( + Operation *op, omp::DistributeOp distribute, omp::ParallelOp parallel, + omp::WsLoopOp wsloop, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation, + ConversionDispatchList &dispatchList) { + + // Reduction related data structures + SmallVector owningReductionGens; + SmallVector owningAtomicReductionGens; + SmallVector reductionInfos; + llvm::OpenMPIRBuilder::InsertPointTy redAllocaIP; + + // Convert wsloop alternative implementation + ConvertFunctionTy convertWsLoop = [&redAllocaIP, &owningReductionGens, + &owningAtomicReductionGens, + &reductionInfos]( + Operation *op, + llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation + &moduleTranslation) { + if (!isa(op)) { + return std::make_pair(false, failure()); + } + + LogicalResult result = convertOmpWsLoop( + *op, builder, moduleTranslation, redAllocaIP, owningReductionGens, + owningAtomicReductionGens, reductionInfos); + return std::make_pair(true, result); + }; + + // Push the new alternative functions + dispatchList.pushConversionFunction(convertWsLoop); + + // Lower the current distribute operation + LogicalResult result = convertOmpDistribute(*op, builder, moduleTranslation, + &redAllocaIP, reductionInfos); + + // Pop the alternative functions + dispatchList.popConversionFunction(); + + return result; +} + +LogicalResult 
OpenMPDialectLLVMIRTranslationInterface::amendOperation(
+    Operation *op, ArrayRef instructions,
+    NamedAttribute attribute,
+    LLVM::ModuleTranslation &moduleTranslation) const {
+  // Dispatch on the attribute name. Every case receives the attribute value,
+  // returns failure() when the value is not of the attribute kind the case
+  // expects, and otherwise either updates the OpenMPIRBuilder (configuration,
+  // module flag, offload metadata) or forwards to a conversion helper.
+  // Names without a case are accepted by the default handler. The switch
+  // result is returned directly, so no statement after it can execute.
+  return llvm::StringSwitch>(
+             attribute.getName())
+      .Case("omp.is_target_device",
+            [&](Attribute attr) {
+              if (auto deviceAttr = attr.dyn_cast()) {
+                llvm::OpenMPIRBuilderConfig &config =
+                    moduleTranslation.getOpenMPBuilder()->Config;
+                config.setIsTargetDevice(deviceAttr.getValue());
+                return success();
+              }
+              return failure();
+            })
+      .Case("omp.is_gpu",
+            [&](Attribute attr) {
+              if (auto gpuAttr = attr.dyn_cast()) {
+                llvm::OpenMPIRBuilderConfig &config =
+                    moduleTranslation.getOpenMPBuilder()->Config;
+                config.setIsGPU(gpuAttr.getValue());
+                return success();
+              }
+              return failure();
+            })
+      .Case("omp.host_ir_filepath",
+            [&](Attribute attr) {
+              if (auto filepathAttr = attr.dyn_cast()) {
+                llvm::OpenMPIRBuilder *ompBuilder =
+                    moduleTranslation.getOpenMPBuilder();
+                ompBuilder->loadOffloadInfoMetadata(filepathAttr.getValue());
+                return success();
+              }
+              return failure();
+            })
+      .Case("omp.flags",
+            [&](Attribute attr) {
+              if (auto rtlAttr = attr.dyn_cast())
+                return convertFlagsAttr(op, rtlAttr, moduleTranslation);
+              return failure();
+            })
+      .Case("omp.version",
+            [&](Attribute attr) {
+              if (auto versionAttr = attr.dyn_cast()) {
+                llvm::OpenMPIRBuilder *ompBuilder =
+                    moduleTranslation.getOpenMPBuilder();
+                ompBuilder->M.addModuleFlag(llvm::Module::Max, "openmp",
+                                            versionAttr.getVersion());
+                return success();
+              }
+              return failure();
+            })
+      .Case("omp.declare_target",
+            [&](Attribute attr) {
+              if (auto declareTargetAttr =
+                      attr.dyn_cast())
+                return convertDeclareTargetAttr(op, declareTargetAttr,
+                                                moduleTranslation);
+              return failure();
+            })
+      .Case(
+          "omp.requires",
+          [&](Attribute attr) {
+            if (auto requiresAttr = attr.dyn_cast()) {
+              using Requires = omp::ClauseRequires;
+              Requires flags = requiresAttr.getValue();
+              llvm::OpenMPIRBuilderConfig &config =
+                  moduleTranslation.getOpenMPBuilder()->Config;
+              config.setHasRequiresReverseOffload(
+                  bitEnumContainsAll(flags, Requires::reverse_offload));
+              config.setHasRequiresUnifiedAddress(
+                  bitEnumContainsAll(flags, Requires::unified_address));
+              config.setHasRequiresUnifiedSharedMemory(
+                  bitEnumContainsAll(flags, Requires::unified_shared_memory));
+              config.setHasRequiresDynamicAllocators(
+                  bitEnumContainsAll(flags, Requires::dynamic_allocators));
+              return success();
+            }
+            return failure();
+          })
+      .Default([](Attribute) {
+        // Fall through for omp attributes that do not require lowering.
+        return success();
+      })(attribute.getValue());
+}
+
+/// Returns true when `op` is treated as internal to device code.
+///
+/// That is the case when the first `getParentOfType` query on `op` finds a
+/// parent, or when the parent bound to `parentFn` casts to the interface bound
+/// to `declareTargetIface`, reports isDeclareTarget(), and has a declare-target
+/// device type other than `host`. Returns false otherwise.
+static bool isInternalTargetDeviceOp(Operation *op) {
+  // Assumes no reverse offloading
+  if (op->getParentOfType())
+    return true;
+
+  if (auto parentFn = op->getParentOfType())
+    if (auto declareTargetIface =
+            llvm::dyn_cast(
+                parentFn.getOperation()))
+      if (declareTargetIface.isDeclareTarget() &&
+          declareTargetIface.getDeclareTargetDeviceType() !=
+              mlir::omp::DeclareTargetDeviceType::host)
+        return true;
+
+  return false;
+}
+
+/// Given an OpenMP MLIR operation, create the corresponding LLVM IR
+/// (including OpenMP runtime calls).
+///
+/// A conversion registered on `dispatchList` that reports a match takes
+/// precedence over the default lowering. Otherwise, when the OpenMPIRBuilder
+/// is configured for a target device, the operation is lowered by
+/// convertInternalTargetOp or convertTopLevelTargetOp depending on
+/// isInternalTargetDeviceOp; in every other case convertCommonOperation is
+/// used.
+LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
+    Operation *op, llvm::IRBuilderBase &builder,
+    LLVM::ModuleTranslation &moduleTranslation) const {
+
+  // Give any registered override the first chance to lower the operation; if
+  // none matches, use the default dispatch below.
+  auto [matched, overrideResult] =
+      dispatchList.convertOperation(op, builder, moduleTranslation);
+  if (matched)
+    return overrideResult;
+
+  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
+  if (!ompBuilder->Config.isTargetDevice())
+    return convertCommonOperation(op, builder, moduleTranslation);
+
+  if (isInternalTargetDeviceOp(op))
+    return convertInternalTargetOp(op, builder, moduleTranslation,
+                                   dispatchList);
+  return convertTopLevelTargetOp(op, builder, moduleTranslation);
+}
+
 void mlir::registerOpenMPDialectTranslation(DialectRegistry ®istry) { registry.insert(); registry.addExtension(+[](MLIRContext *ctx, omp::OpenMPDialect *dialect) { diff --git a/mlir/test/Target/LLVMIR/omptarget-parallel-wsloop.mlir b/mlir/test/Target/LLVMIR/omptarget-parallel-wsloop.mlir index afbf5f22246309..0cccf890782ec6 100644 --- a/mlir/test/Target/LLVMIR/omptarget-parallel-wsloop.mlir +++ b/mlir/test/Target/LLVMIR/omptarget-parallel-wsloop.mlir @@ -4,7 +4,7 @@ // for nested omp do loop inside omp target region module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memory_space", 5 : ui32>>, llvm.data_layout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-p7:160:256:256:32-p8:128:128-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7:8", llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_gpu = true, omp.is_target_device = true } { - llvm.func @target_parallel_wsloop(%arg0: !llvm.ptr) attributes { + llvm.func @target_parallel_wsloop(%arg0: !llvm.ptr) attributes { omp.declare_target = #omp.declaretarget, target_cpu = "gfx90a", target_features = #llvm.target_features<["+gfx9-insts", "+wavefrontsize64"]> } { diff --git a/mlir/test/Target/LLVMIR/omptarget-teams-llvm.mlir b/mlir/test/Target/LLVMIR/omptarget-teams-llvm.mlir index 96cced7a1d584b..c5f89eb2c3274c 100644 ---
a/mlir/test/Target/LLVMIR/omptarget-teams-llvm.mlir +++ b/mlir/test/Target/LLVMIR/omptarget-teams-llvm.mlir @@ -5,7 +5,7 @@ module attributes {omp.is_target_device = true} { llvm.func @foo(i32) - llvm.func @omp_target_teams_shared_simple(%arg0 : i32) { + llvm.func @omp_target_teams_shared_simple(%arg0 : i32) attributes {omp.declare_target = #omp.declaretarget} { omp.teams { llvm.call @foo(%arg0) : (i32) -> () omp.terminator diff --git a/mlir/test/Target/LLVMIR/omptarget-wsloop-collapsed.mlir b/mlir/test/Target/LLVMIR/omptarget-wsloop-collapsed.mlir index 435aca32450c2f..ebd39ed8601f98 100644 --- a/mlir/test/Target/LLVMIR/omptarget-wsloop-collapsed.mlir +++ b/mlir/test/Target/LLVMIR/omptarget-wsloop-collapsed.mlir @@ -4,7 +4,7 @@ // for nested omp do loop with collapse clause inside omp target region module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memory_space", 5 : ui32>>, llvm.data_layout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-p7:160:256:256:32-p8:128:128-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7:8", llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_gpu = true, omp.is_target_device = true } { - llvm.func @target_collapsed_wsloop(%arg0: !llvm.ptr) { + llvm.func @target_collapsed_wsloop(%arg0: !llvm.ptr) attributes {omp.declare_target = #omp.declaretarget} { %loop_ub = llvm.mlir.constant(99 : i32) : i32 %loop_lb = llvm.mlir.constant(0 : i32) : i32 %loop_step = llvm.mlir.constant(1 : index) : i32 diff --git a/mlir/test/Target/LLVMIR/omptarget-wsloop.mlir b/mlir/test/Target/LLVMIR/omptarget-wsloop.mlir index 4cfb7d4f695143..9246a1bdd85370 100644 --- a/mlir/test/Target/LLVMIR/omptarget-wsloop.mlir +++ b/mlir/test/Target/LLVMIR/omptarget-wsloop.mlir @@ -4,7 +4,7 @@ // for nested omp do loop inside omp target region module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memory_space", 5 : ui32>>, llvm.data_layout = 
"e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-p7:160:256:256:32-p8:128:128-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7:8", llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_gpu = true, omp.is_target_device = true } { - llvm.func @target_wsloop(%arg0: !llvm.ptr ){ + llvm.func @target_wsloop(%arg0: !llvm.ptr ) attributes {omp.declare_target = #omp.declaretarget} { %loop_ub = llvm.mlir.constant(9 : i32) : i32 %loop_lb = llvm.mlir.constant(0 : i32) : i32 %loop_step = llvm.mlir.constant(1 : i32) : i32 @@ -16,7 +16,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo llvm.return } - llvm.func @target_empty_wsloop(){ + llvm.func @target_empty_wsloop() attributes {omp.declare_target = #omp.declaretarget} { %loop_ub = llvm.mlir.constant(9 : i32) : i32 %loop_lb = llvm.mlir.constant(0 : i32) : i32 %loop_step = llvm.mlir.constant(1 : i32) : i32