Skip to content

Commit

Permalink
[NFC][SYCL] Refactor kernel wrapper generation.
Browse files Browse the repository at this point in the history
Signed-off-by: Vladimir Lazarev <vladimir.lazarev@intel.com>
  • Loading branch information
vladimirlaz committed Jan 22, 2019
1 parent f1d20ee commit 120b4b5
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 63 deletions.
2 changes: 1 addition & 1 deletion clang/include/clang/Sema/Sema.h
Original file line number Diff line number Diff line change
Expand Up @@ -10852,7 +10852,7 @@ class Sema {
void AddSyclKernel(Decl * d) { SyclKernel.push_back(d); }
SmallVector<Decl*, 4> &SyclKernels() { return SyclKernel; }

void ConstructSYCLKernel(CXXMemberCallExpr* e);
void ConstructSYCLKernel(FunctionDecl* KernelHelper);
};

/// RAII object that enters a new expression evaluation context.
Expand Down
9 changes: 0 additions & 9 deletions clang/lib/Sema/SemaOverload.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13012,15 +13012,6 @@ Sema::BuildCallToMemberFunction(Scope *S, Expr *MemExprE,
CXXMemberCallExpr::Create(Context, MemExprE, Args, ResultType, VK,
RParenLoc, Proto->getNumParams());

if (getLangOpts().SYCL) {
auto Func = TheCall->getMethodDecl();
auto Name = Func->getQualifiedNameAsString();
if (Name == "cl::sycl::handler::parallel_for" ||
Name == "cl::sycl::handler::single_task") {
ConstructSYCLKernel(TheCall);
}
}

// Check for a valid return type.
if (CheckCallReturnType(Method->getReturnType(), MemExpr->getMemberLoc(),
TheCall, Method))
Expand Down
106 changes: 58 additions & 48 deletions clang/lib/Sema/SemaSYCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,44 @@
#include "clang/AST/AST.h"
#include "clang/Sema/Sema.h"
#include "llvm/ADT/SmallVector.h"
#include "TreeTransform.h"

using namespace clang;

LambdaExpr *getBodyAsLambda(CXXMemberCallExpr *e) {
auto LastArg = e->getArg(e->getNumArgs() - 1);
return dyn_cast<LambdaExpr>(LastArg);
typedef llvm::DenseMap<DeclaratorDecl *, DeclaratorDecl *> DeclMap;

/// Tree transform that rewrites a statement so that every reference to a
/// declaration present in the supplied map refers to the mapped replacement
/// declaration instead. References to unmapped declarations are returned
/// unchanged.
class KernelBodyTransform : public TreeTransform<KernelBodyTransform> {
public:
  /// \param Map  original declaration -> replacement declaration. The map is
  ///             copied, so later changes by the caller are not observed.
  /// \param S    the Sema instance used to build the new nodes.
  KernelBodyTransform(const DeclMap &Map, Sema &S)
      : TreeTransform<KernelBodyTransform>(S), DMap(Map), SemaRef(S) {}

  /// Ask TreeTransform to rebuild every node rather than reuse unchanged
  /// subtrees.
  bool AlwaysRebuild() { return true; }

  /// Replace a reference to a mapped declaration with a freshly created
  /// reference to its replacement; leave every other reference alone.
  ExprResult TransformDeclRefExpr(DeclRefExpr *DRE) {
    if (auto *Ref = dyn_cast<DeclaratorDecl>(DRE->getDecl())) {
      // lookup() returns nullptr for a missing key without inserting a new
      // entry, unlike operator[], which would grow the map on every miss.
      if (DeclaratorDecl *NewDecl = DMap.lookup(Ref)) {
        return DeclRefExpr::Create(
            SemaRef.getASTContext(), DRE->getQualifierLoc(),
            DRE->getTemplateKeywordLoc(), NewDecl, false, DRE->getNameInfo(),
            NewDecl->getType(), DRE->getValueKind());
      }
    }
    return DRE;
  }

private:
  DeclMap DMap;  // private copy of the declaration replacement map
  Sema &SemaRef; // Sema used to obtain the ASTContext for new nodes
};

/// Returns the lambda closure class used as the type of the first parameter
/// of \p FD, or nullptr when \p FD has no parameters or its first parameter's
/// type is not a lambda closure type.
CXXRecordDecl* getBodyAsLambda(FunctionDecl *FD) {
  // A function without parameters has nothing to inspect; dereferencing
  // param_begin() in that case would read past the end of the parameter list.
  if (FD->param_empty())
    return nullptr;

  ParmVarDecl *FirstArg = *FD->param_begin();
  if (!FirstArg)
    return nullptr;

  // getAsCXXRecordDecl() yields nullptr for non-record types, so it must be
  // checked before isLambda() is queried.
  CXXRecordDecl *Record = FirstArg->getType()->getAsCXXRecordDecl();
  if (Record && Record->isLambda())
    return Record;
  return nullptr;
}

FunctionDecl *CreateSYCLKernelFunction(ASTContext &Context, StringRef Name,
Expand Down Expand Up @@ -54,17 +86,16 @@ FunctionDecl *CreateSYCLKernelFunction(ASTContext &Context, StringRef Name,
return Result;
}

CompoundStmt *CreateSYCLKernelBody(Sema &S, CXXMemberCallExpr *e,
CompoundStmt *CreateSYCLKernelBody(Sema &S, FunctionDecl *KernelHelper,
DeclContext *DC) {

llvm::SmallVector<Stmt *, 16> BodyStmts;

// TODO: case when kernel is functor
// TODO: possible refactoring when functor case will be completed
LambdaExpr *LE = getBodyAsLambda(e);
if (LE) {
CXXRecordDecl *LC = getBodyAsLambda(KernelHelper);
if (LC) {
// Create Lambda object
CXXRecordDecl *LC = LE->getLambdaClass();
auto LambdaVD = VarDecl::Create(
S.Context, DC, SourceLocation(), SourceLocation(), LC->getIdentifier(),
QualType(LC->getTypeForDecl(), 0), LC->getLambdaTypeInfo(), SC_None);
Expand Down Expand Up @@ -137,43 +168,23 @@ CompoundStmt *CreateSYCLKernelBody(Sema &S, CXXMemberCallExpr *e,
TargetFuncParam++;
}

// Create Lambda operator () call
FunctionDecl *LO = LE->getCallOperator();
ArrayRef<ParmVarDecl *> Args = LO->parameters();
llvm::SmallVector<Expr *, 16> ParamStmts(1);
ParamStmts[0] = dyn_cast<Expr>(LambdaDRE);

// Collect arguments for () operator
for (auto Arg : Args) {
QualType ArgType = Arg->getOriginalType();
// Declare variable for parameter and pass it to call
auto param_VD =
VarDecl::Create(S.Context, DC, SourceLocation(), SourceLocation(),
Arg->getIdentifier(), ArgType,
S.Context.getTrivialTypeSourceInfo(ArgType), SC_None);
Stmt *param_DS = new (S.Context)
DeclStmt(DeclGroupRef(param_VD), SourceLocation(), SourceLocation());
BodyStmts.push_back(param_DS);
auto DRE = DeclRefExpr::Create(S.Context, NestedNameSpecifierLoc(),
SourceLocation(), param_VD, false,
DeclarationNameInfo(), ArgType, VK_LValue);
Expr *Res = ImplicitCastExpr::Create(
S.Context, ArgType, CK_LValueToRValue, DRE, nullptr, VK_RValue);
ParamStmts.push_back(Res);
}
// In the function from the headers the lambda is a function parameter, so we
// need to replace all references to this lambda with our VarDecl.
// TreeTransform is used here, but it is not clear that it is a good solution.
// A map is used as well, and it is not clear that this is a good choice either.
Stmt* FunctionBody = KernelHelper->getBody();
DeclMap DMap;
ParmVarDecl* LambdaParam = *(KernelHelper->param_begin());
// A DeclRefExpr with a valid source location whose decl is not marked as
// used is invalid.
LambdaVD->setIsUsed();
DMap[LambdaParam] = LambdaVD;
// Without PushFunctionScope I had segfault. Maybe we also need to do pop.
S.PushFunctionScope();
KernelBodyTransform KBT(DMap, S);
Stmt* NewBody = KBT.TransformStmt(FunctionBody).get();
BodyStmts.push_back(NewBody);

// Create ref for call operator
DeclRefExpr *DRE = new (S.Context)
DeclRefExpr(S.Context, LO, false, LO->getType(), VK_LValue,
SourceLocation());
QualType ResultTy = LO->getReturnType();
ExprValueKind VK = Expr::getValueKindForType(ResultTy);
ResultTy = ResultTy.getNonLValueExprType(S.Context);

CXXOperatorCallExpr *TheCall = CXXOperatorCallExpr::Create(
S.Context, OO_Call, DRE, ParamStmts, ResultTy, VK, SourceLocation(),
FPOptions(), clang::CallExpr::ADLCallKind::NotADL );
BodyStmts.push_back(TheCall);
}
return CompoundStmt::Create(S.Context, BodyStmts, SourceLocation(),
SourceLocation());
Expand Down Expand Up @@ -222,9 +233,9 @@ void BuildArgTys(ASTContext &Context,
}
}

void Sema::ConstructSYCLKernel(CXXMemberCallExpr *e) {
void Sema::ConstructSYCLKernel(FunctionDecl *KernelHelper) {
// TODO: Case when kernel is functor
LambdaExpr *LE = getBodyAsLambda(e);
CXXRecordDecl *LE = getBodyAsLambda(KernelHelper);
if (LE) {

llvm::SmallVector<DeclaratorDecl *, 16> ArgDecls;
Expand All @@ -238,9 +249,8 @@ void Sema::ConstructSYCLKernel(CXXMemberCallExpr *e) {
BuildArgTys(getASTContext(), ArgDecls, NewArgDecls, ArgTys);

// Get Name for our kernel.
FunctionDecl *FuncDecl = e->getMethodDecl();
const TemplateArgumentList *TemplateArgs =
FuncDecl->getTemplateSpecializationArgs();
KernelHelper->getTemplateSpecializationArgs();
QualType KernelNameType = TemplateArgs->get(0).getAsType();
std::string Name = KernelNameType.getBaseTypeIdentifier()->getName().str();

Expand All @@ -256,7 +266,7 @@ void Sema::ConstructSYCLKernel(CXXMemberCallExpr *e) {
FunctionDecl *SYCLKernel =
CreateSYCLKernelFunction(getASTContext(), Name, ArgTys, NewArgDecls);

CompoundStmt *SYCLKernelBody = CreateSYCLKernelBody(*this, e, SYCLKernel);
CompoundStmt *SYCLKernelBody = CreateSYCLKernelBody(*this, KernelHelper, SYCLKernel);
SYCLKernel->setBody(SYCLKernelBody);

AddSyclKernel(SYCLKernel);
Expand Down
18 changes: 16 additions & 2 deletions clang/lib/Sema/SemaTemplateInstantiateDecl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5231,14 +5231,28 @@ void Sema::PerformPendingInstantiations(bool LocalOnly) {
Function, [this, Inst, DefinitionRequired](FunctionDecl *CurFD) {
InstantiateFunctionDefinition(/*FIXME:*/ Inst.second, CurFD, true,
DefinitionRequired, true);
if (CurFD->isDefined())
if (CurFD->isDefined()) {
// Because all SYCL kernel functions are template functions, their
// instantiation is deferred. We need the bodies of these functions,
// so we check for the SYCL kernel attribute after instantiation.
if (getLangOpts().SYCL && CurFD->hasAttr<SYCLKernelAttr>()) {
ConstructSYCLKernel(CurFD);
}
CurFD->setInstantiationIsPending(false);
}
});
} else {
InstantiateFunctionDefinition(/*FIXME:*/ Inst.second, Function, true,
DefinitionRequired, true);
if (Function->isDefined())
if (Function->isDefined()) {
// Because all SYCL kernel functions are template functions, their
// instantiation is deferred. We need the bodies of these functions,
// so we check for the SYCL kernel attribute after instantiation.
if (getLangOpts().SYCL && Function->hasAttr<SYCLKernelAttr>()) {
ConstructSYCLKernel(Function);
}
Function->setInstantiationIsPending(false);
}
}
continue;
}
Expand Down
5 changes: 2 additions & 3 deletions clang/test/CodeGenSYCL/kernel-with-id.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,9 @@ int main() {

deviceQueue.submit([&](cl::sycl::handler &cgh) {
auto accessorA = bufferA.template get_access<cl::sycl::access::mode::read_write>(cgh);
// CHECK: %wiID = alloca %"struct.cl::sycl::id", align 8
// CHECK: call spir_func void @_ZN2cl4sycl8accessorIiLi1ELNS0_6access4modeE1026ELNS2_6targetE2014ELNS2_11placeholderE0EE13__set_pointerEPU3AS1i(%"class.cl::sycl::accessor"* %1, i32 addrspace(1)* %2)
// CHECK: call spir_func void @"_ZZZ4mainENK3$_0clERN2cl4sycl7handlerEENKUlNS1_2idILm1EEEE_clES5_"(%class.anon* %0, %"struct.cl::sycl::id"* byval align 8 %wiID)
// CHECK: %call = call spir_func i64 @_Z13get_global_idj(i32 0)
// CHECK: %call = call spir_func i64 @_Z13get_global_idj(i32 %{{.*}})
// CHECK: call spir_func void @"_ZZZ4mainENK3$_0clERN2cl4sycl7handlerEENKUlNS1_2idILm1EEEE_clES5_"(%class.anon* %0, %"struct.cl::sycl::id"* byval align 8 %{{.*}})
cgh.parallel_for<class kernel_function>(numOfItems,
[=](cl::sycl::id<1> wiID) {
accessorA[wiID] = accessorA[wiID] * accessorA[wiID];
Expand Down

0 comments on commit 120b4b5

Please sign in to comment.