Skip to content

Commit

Permalink
add an entry point wrapper around functions
Browse files Browse the repository at this point in the history
SPIR-V spec states:
"It is invalid for any function to be targeted by both an OpEntryPoint instruction
 and an OpFunctionCall instruction."

In order to satisfy SPIR-V that entrypoints and functions
must be different, this introduces an entrypoint wrapper around
functions at the LLVM IR level, then fixes up a few things like
naming at the SPIRV translation.

It's necessary to fixup the spirv metadata to point at the new kernel entrypoint
  • Loading branch information
airlied committed Sep 10, 2021
1 parent 4ccb1b2 commit 1ed4a5d
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 6 deletions.
1 change: 1 addition & 0 deletions lib/SPIRV/SPIRVInternal.h
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,7 @@ const static char TranslateOCLMemScope[] = "__translate_ocl_memory_scope";
const static char TranslateSPIRVMemOrder[] = "__translate_spirv_memory_order";
const static char TranslateSPIRVMemScope[] = "__translate_spirv_memory_scope";
const static char TranslateSPIRVMemFence[] = "__translate_spirv_memory_fence";
const static char EntrypointPrefix[] = "__spirv_entry_";
} // namespace kSPIRVName

namespace kSPIRVPostfix {
Expand Down
65 changes: 65 additions & 0 deletions lib/SPIRV/SPIRVRegularizeLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@

#include "OCLUtil.h"
#include "SPIRVInternal.h"
#include "SPIRVMDWalker.h"
#include "libSPIRV/SPIRVDebug.h"

#include "llvm/ADT/StringExtras.h" // llvm::isDigit
Expand Down Expand Up @@ -71,6 +72,11 @@ class SPIRVRegularizeLLVMBase {
// Lower functions
bool regularize();

// SPIR-V disallows functions being entrypoints and called
// LLVM doesn't. This adds a wrapper around the entry point
// that later SPIR-V writer renames.
void addKernelEntryPoint(Module *M);

/// Erase cast inst of function and replace with the function.
/// Assuming F is a SPIR-V builtin function with op code \param OC.
void lowerFuncPtr(Function *F, Op OC);
Expand Down Expand Up @@ -362,6 +368,7 @@ bool SPIRVRegularizeLLVMBase::runRegularizeLLVM(Module &Module) {
bool SPIRVRegularizeLLVMBase::regularize() {
eraseUselessFunctions(M);
lowerFuncPtr(M);
addKernelEntryPoint(M);

for (auto I = M->begin(), E = M->end(); I != E;) {
Function *F = &(*I++);
Expand Down Expand Up @@ -529,6 +536,64 @@ void SPIRVRegularizeLLVMBase::lowerFuncPtr(Module *M) {
lowerFuncPtr(I.first, I.second);
}

void SPIRVRegularizeLLVMBase::addKernelEntryPoint(Module *M) {
std::vector<Function *> Work;

// Get a list of all functions that have SPIR kernel calling conv
for (auto &F : *M) {
if (F.getCallingConv() == CallingConv::SPIR_KERNEL)
Work.push_back(&F);
}
for (auto &F: Work) {

// for declarations just make them into SPIR functions.
F->setCallingConv(CallingConv::SPIR_FUNC);
if (F->isDeclaration())
continue;

// Otherwise add a wrapper around the function to act as an entry point.
FunctionType *FType = F->getFunctionType();
std::string wrap_name = kSPIRVName::EntrypointPrefix + static_cast<std::string>(F->getName());
Function *WrapFn = getOrCreateFunction(M, F->getReturnType(),
FType->params(), wrap_name);

auto *CallBB = BasicBlock::Create(M->getContext(), "", WrapFn);
IRBuilder<> Builder(CallBB);

SmallVector<Value *, 1> Args;
for (auto AI = WrapFn->arg_begin(), End = WrapFn->arg_end(); AI != End; ++AI) {
Value *A = AI;
Args.emplace_back(A);
}
auto CI = CallInst::Create(F, ArrayRef<Value *>(Args), F->getReturnType()->isVoidTy() ? "" : "call", CallBB);
CI->setCallingConv(F->getCallingConv());
CI->setAttributes(F->getAttributes());

// copy over all the metadata (should it be removed from F?)
SmallVector<std::pair<unsigned, MDNode*>> MDs;
F->getAllMetadata(MDs);
WrapFn->setAttributes(F->getAttributes());
for (auto MD = MDs.begin(), End = MDs.end(); MD != End; ++MD) {
WrapFn->addMetadata(MD->first, *MD->second);
}
WrapFn->setCallingConv(CallingConv::SPIR_KERNEL);
WrapFn->setLinkage(llvm::GlobalValue::InternalLinkage);

Builder.CreateRet(F->getReturnType()->isVoidTy() ? nullptr : CI);

/* have to find the spir-v metadata for execution mode and transfer it to the wrapper */
if (auto NMD = SPIRVMDWalker(*M).getNamedMD(kSPIRVMD::ExecutionMode)) {
while (!NMD.atEnd()) {
Function *MDF = nullptr;
auto N = NMD.nextOp(); /* execution mode MDNode */
N.get(MDF);
if (MDF == F)
N.M->replaceOperandWith(0, ValueAsMetadata::get(WrapFn));
}
}
}
}

} // namespace SPIRV

INITIALIZE_PASS(SPIRVRegularizeLLVMLegacy, "spvregular",
Expand Down
20 changes: 14 additions & 6 deletions lib/SPIRV/SPIRVWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -596,11 +596,16 @@ SPIRVFunction *LLVMToSPIRVBase::transFunctionDecl(Function *F) {
SPIRVFunction *BF =
static_cast<SPIRVFunction *>(mapValue(F, BM->addFunction(BFT)));
BF->setFunctionControlMask(transFunctionControlMask(F));
if (F->hasName())
BM->setName(BF, F->getName().str());
if (isKernel(F))
BM->addEntryPoint(ExecutionModelKernel, BF->getId());
else if (F->getLinkage() != GlobalValue::InternalLinkage)
if (F->hasName()) {
if (isKernel(F)) {
/* strip the prefix as the runtime will be looking for this name */
std::string prefix = kSPIRVName::EntrypointPrefix;
std::string name = F->getName().str();
BM->setName(BF, name.substr(prefix.size()));
} else
BM->setName(BF, F->getName().str());
}
if (F->getLinkage() != GlobalValue::InternalLinkage && !isKernel(F))
BF->setLinkageType(transLinkageType(F));

// Translate OpenCL/SYCL buffer_location metadata if it's attached to the
Expand Down Expand Up @@ -3630,6 +3635,7 @@ void LLVMToSPIRVBase::transFunction(Function *I) {
bool IsKernelEntryPoint = isKernel(I);

if (IsKernelEntryPoint) {
BM->addEntryPoint(ExecutionModelKernel, BF->getId());
collectInputOutputVariables(BF, I);
}
}
Expand Down Expand Up @@ -3957,8 +3963,10 @@ bool LLVMToSPIRVBase::transMetadata() {
// Work around to translate kernel_arg_type and kernel_arg_type_qual metadata
static void transKernelArgTypeMD(SPIRVModule *BM, Function *F, MDNode *MD,
std::string MDName) {
std::string prefix = "__spirv_entry_";
std::string name = F->getName().str().substr(prefix.size());
std::string KernelArgTypesMDStr =
std::string(MDName) + "." + F->getName().str() + ".";
std::string(MDName) + "." + name + ".";
for (const auto &TyOp : MD->operands())
KernelArgTypesMDStr += cast<MDString>(TyOp)->getString().str() + ",";
BM->getString(KernelArgTypesMDStr);
Expand Down

0 comments on commit 1ed4a5d

Please sign in to comment.