Skip to content

Commit

Permalink
[spirv] support vk::ext_execution_mode_id(..) (#4190)
Browse files Browse the repository at this point in the history
As a part of HLSL version of GL_EXT_spirv_intrinsics, this commit adds
`vk::ext_execution_mode_id(..)` intrinsic function. In addition, it allows users
to enable capabilites and extensions via `vk::ext_execution_mode[_id](..)`
using `[[vk::ext_capability(..)]]` and `[[vk::ext_extension(..)]]`.

Related to #3919
  • Loading branch information
jaebaek authored Jan 21, 2022
1 parent 0d338ee commit 29874c9
Show file tree
Hide file tree
Showing 11 changed files with 133 additions and 61 deletions.
3 changes: 3 additions & 0 deletions include/dxc/HlslIntrinsicOp.h
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,9 @@ enum class IntrinsicOp { IOP_AcceptHitAndEndSearch,
#endif // ENABLE_SPIRV_CODEGEN
#ifdef ENABLE_SPIRV_CODEGEN
IOP_Vkext_execution_mode,
#endif // ENABLE_SPIRV_CODEGEN
#ifdef ENABLE_SPIRV_CODEGEN
IOP_Vkext_execution_mode_id,
#endif // ENABLE_SPIRV_CODEGEN
MOP_Append,
MOP_RestartStrip,
Expand Down
1 change: 1 addition & 0 deletions lib/HLSL/HLOperationLower.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5663,6 +5663,7 @@ IntrinsicLower gLowerTable[] = {
{ IntrinsicOp::IOP_VkReadClock, UnsupportedVulkanIntrinsic, DXIL::OpCode::NumOpCodes },
{ IntrinsicOp::IOP_VkRawBufferLoad, UnsupportedVulkanIntrinsic, DXIL::OpCode::NumOpCodes },
{ IntrinsicOp::IOP_Vkext_execution_mode, UnsupportedVulkanIntrinsic, DXIL::OpCode::NumOpCodes },
{ IntrinsicOp::IOP_Vkext_execution_mode_id, UnsupportedVulkanIntrinsic, DXIL::OpCode::NumOpCodes },
#endif // ENABLE_SPIRV_CODEGEN
{IntrinsicOp::MOP_Append, StreamOutputLower, DXIL::OpCode::EmitStream},
{IntrinsicOp::MOP_RestartStrip, StreamOutputLower, DXIL::OpCode::CutStream},
Expand Down
9 changes: 5 additions & 4 deletions tools/clang/include/clang/SPIRV/SpirvBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -585,7 +585,8 @@ class SpirvBuilder {
inline SpirvInstruction *addExecutionMode(SpirvFunction *entryPoint,
spv::ExecutionMode em,
llvm::ArrayRef<uint32_t> params,
SourceLocation);
SourceLocation,
bool useIdParams = false);

/// \brief Adds an OpModuleProcessed instruction to the module under
/// construction.
Expand Down Expand Up @@ -888,9 +889,9 @@ SpirvBuilder::setDebugSource(uint32_t major, uint32_t minor,
SpirvInstruction *
SpirvBuilder::addExecutionMode(SpirvFunction *entryPoint, spv::ExecutionMode em,
llvm::ArrayRef<uint32_t> params,
SourceLocation loc) {
auto mode =
new (context) SpirvExecutionMode(loc, entryPoint, em, params, false);
SourceLocation loc, bool useIdParams) {
auto mode = new (context)
SpirvExecutionMode(loc, entryPoint, em, params, useIdParams);
mod->addExecutionMode(mode);

return mode;
Expand Down
12 changes: 10 additions & 2 deletions tools/clang/lib/SPIRV/EmitVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -587,8 +587,16 @@ bool EmitVisitor::visit(SpirvExecutionMode *inst) {
initInstruction(inst);
curInst.push_back(getOrAssignResultId<SpirvFunction>(inst->getEntryPoint()));
curInst.push_back(static_cast<uint32_t>(inst->getExecutionMode()));
curInst.insert(curInst.end(), inst->getParams().begin(),
inst->getParams().end());
if (inst->getopcode() == spv::Op::OpExecutionMode) {
curInst.insert(curInst.end(), inst->getParams().begin(),
inst->getParams().end());
} else {
for (uint32_t param : inst->getParams()) {
curInst.push_back(typeHandler.getOrCreateConstantInt(
llvm::APInt(32, param), context.getUIntType(32),
/*isSpecConst */ false));
}
}
finalizeInstruction(&preambleBinary);
return true;
}
Expand Down
114 changes: 63 additions & 51 deletions tools/clang/lib/SPIRV/SpirvEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -924,11 +924,21 @@ void SpirvEmitter::doStmt(const Stmt *stmt,
doForStmt(forStmt, attrs);
} else if (dyn_cast<NullStmt>(stmt)) {
// For the null statement ";". We don't need to do anything.
} else if (const auto *expr = dyn_cast<Expr>(stmt)) {
// All cases for expressions used as statements
doExpr(expr);
} else if (const auto *attrStmt = dyn_cast<AttributedStmt>(stmt)) {
doStmt(attrStmt->getSubStmt(), attrStmt->getAttrs());
} else if (const auto *expr = dyn_cast<Expr>(stmt)) {
// All cases for expressions used as statements
SpirvInstruction *result = doExpr(expr);

if (result && result->getKind() == SpirvInstruction::IK_ExecutionMode &&
!attrs.empty()) {
// Handle [[vk::ext_capability(..)]] and [[vk::ext_extension(..)]]
// attributes for vk::ext_execution_mode[_id](..).
createSpirvIntrInstExt(
attrs, QualType(),
/*spvArgs*/ llvm::SmallVector<SpirvInstruction *, 1>{},
/*isInstr*/ false, expr->getExprLoc());
}
} else {
emitError("statement class '%0' unimplemented", stmt->getLocStart())
<< stmt->getStmtClassName() << stmt->getSourceRange();
Expand Down Expand Up @@ -7806,7 +7816,10 @@ SpirvEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
retVal = processRawBufferLoad(callExpr);
break;
case hlsl::IntrinsicOp::IOP_Vkext_execution_mode:
retVal = processIntrinsicExecutionMode(callExpr);
retVal = processIntrinsicExecutionMode(callExpr, false);
break;
case hlsl::IntrinsicOp::IOP_Vkext_execution_mode_id:
retVal = processIntrinsicExecutionMode(callExpr, true);
break;
case hlsl::IntrinsicOp::IOP_saturate:
retVal = processIntrinsicSaturate(callExpr);
Expand Down Expand Up @@ -12750,32 +12763,46 @@ SpirvEmitter::processRayQueryIntrinsics(const CXXMemberCallExpr *expr,
return retVal;
}

SpirvInstruction *
SpirvEmitter::processSpvIntrinsicCallExpr(const CallExpr *expr) {
auto funcDecl = expr->getDirectCallee();
auto &attrs = funcDecl->getAttrs();
QualType retType = funcDecl->getReturnType();

SpirvInstruction *SpirvEmitter::createSpirvIntrInstExt(
llvm::ArrayRef<const Attr *> attrs, QualType retType,
const llvm::SmallVectorImpl<SpirvInstruction *> &spvArgs, bool isInstr,
SourceLocation loc) {
llvm::SmallVector<uint32_t, 2> capbilities;
llvm::SmallVector<llvm::StringRef, 2> extensions;
llvm::StringRef instSet = "";
uint32_t op = 0;
// For [[vk::ext_type_def]], we use dummy OpNop with no semantic meaning,
// with possible extension and capabilities.
uint32_t op = static_cast<unsigned>(spv::Op::OpNop);
for (auto &attr : attrs) {
if (auto capAttr = dyn_cast<VKCapabilityExtAttr>(attr)) {
capbilities.push_back(capAttr->getCapability());
} else if (auto extAttr = dyn_cast<VKExtensionExtAttr>(attr)) {
extensions.push_back(extAttr->getName());
} else if (auto instAttr = dyn_cast<VKInstructionExtAttr>(attr)) {
}
if (!isInstr)
continue;
if (auto instAttr = dyn_cast<VKInstructionExtAttr>(attr)) {
op = instAttr->getOpcode();
instSet = instAttr->getInstruction_set();
}
}

llvm::SmallVector<SpirvInstruction *, 8> spvArgs;
SpirvInstruction *retVal = spvBuilder.createSpirvIntrInstExt(
op, retType, spvArgs, extensions, instSet, capbilities, loc);

// TODO: Revisit this r-value setting when handling vk::ext_result_id<T> ?
retVal->setRValue();

return retVal;
}

SpirvInstruction *
SpirvEmitter::processSpvIntrinsicCallExpr(const CallExpr *expr) {
const auto *funcDecl = expr->getDirectCallee();
llvm::SmallVector<SpirvInstruction *, 8> spvArgs;
const auto args = expr->getArgs();
for (uint32_t i = 0; i < expr->getNumArgs(); ++i) {
auto param = funcDecl->getParamDecl(i);
const auto *param = funcDecl->getParamDecl(i);
const Expr *arg = args[i]->IgnoreParenLValueCasts();
SpirvInstruction *argInst = doExpr(arg);
if (param->hasAttr<VKReferenceExtAttr>()) {
Expand All @@ -12796,14 +12823,9 @@ SpirvEmitter::processSpvIntrinsicCallExpr(const CallExpr *expr) {
}
}

const auto loc = expr->getExprLoc();

SpirvInstruction *retVal = spvBuilder.createSpirvIntrInstExt(
op, retType, spvArgs, extensions, instSet, capbilities, loc);

// TODO: Revisit this r-value setting when handling vk::ext_result_id<T> ?
retVal->setRValue();
return retVal;
return createSpirvIntrInstExt(funcDecl->getAttrs(), funcDecl->getReturnType(),
spvArgs,
/*isInstr*/ true, expr->getExprLoc());
}

SpirvInstruction *SpirvEmitter::processRawBufferLoad(const CallExpr *callExpr) {
Expand Down Expand Up @@ -12833,18 +12855,22 @@ SpirvInstruction *SpirvEmitter::processRawBufferLoad(const CallExpr *callExpr) {
}

SpirvInstruction *
SpirvEmitter::processIntrinsicExecutionMode(const CallExpr *expr) {
SpirvEmitter::processIntrinsicExecutionMode(const CallExpr *expr,
bool useIdParams) {
llvm::SmallVector<uint32_t, 2> execModesParams;
uint32_t exeMode = 0;
const auto args = expr->getArgs();
for (uint32_t i = 0; i < expr->getNumArgs(); ++i) {
SpirvConstantInteger *argInst =
dyn_cast<SpirvConstantInteger>(doExpr(args[i]));
if (argInst == nullptr) {
emitError("argument should be constant interger", expr->getExprLoc());
const auto *intLiteral =
dyn_cast<IntegerLiteral>(args[i]->IgnoreImplicit());
if (intLiteral == nullptr) {
emitError("argument should be constant integer", expr->getExprLoc());
return nullptr;
}
unsigned argInteger = argInst->getValue().getZExtValue();

uint32_t argInteger =
static_cast<uint32_t>(intLiteral->getValue().getZExtValue());

if (i > 0)
execModesParams.push_back(argInteger);
else
Expand All @@ -12853,26 +12879,14 @@ SpirvEmitter::processIntrinsicExecutionMode(const CallExpr *expr) {
assert(entryFunction != nullptr);
assert(exeMode != 0);

return spvBuilder.addExecutionMode(entryFunction,
static_cast<spv::ExecutionMode>(exeMode),
execModesParams, expr->getExprLoc());
return spvBuilder.addExecutionMode(
entryFunction, static_cast<spv::ExecutionMode>(exeMode), execModesParams,
expr->getExprLoc(), useIdParams);
}

SpirvInstruction *
SpirvEmitter::processSpvIntrinsicTypeDef(const CallExpr *expr) {
auto funcDecl = expr->getDirectCallee();
auto typeDefAttr = funcDecl->getAttr<VKTypeDefExtAttr>();
llvm::SmallVector<uint32_t, 2> capbilities;
llvm::SmallVector<llvm::StringRef, 2> extensions;

for (auto &attr : funcDecl->getAttrs()) {
if (auto capAttr = dyn_cast<VKCapabilityExtAttr>(attr)) {
capbilities.push_back(capAttr->getCapability());
} else if (auto extAttr = dyn_cast<VKExtensionExtAttr>(attr)) {
extensions.push_back(extAttr->getName());
}
}

SmallVector<SpvIntrinsicTypeOperand, 3> operands;
const auto args = expr->getArgs();
for (uint32_t i = 0; i < expr->getNumArgs(); ++i) {
Expand All @@ -12897,17 +12911,15 @@ SpirvEmitter::processSpvIntrinsicTypeDef(const CallExpr *expr) {
operands.emplace_back(loadIfGLValue(arg));
}
}

auto typeDefAttr = funcDecl->getAttr<VKTypeDefExtAttr>();
spvContext.getSpirvIntrinsicType(typeDefAttr->getId(),
typeDefAttr->getOpcode(), operands);

// Emit dummy OpNop with no semantic meaning, with possible extension and
// capabilities
SpirvInstruction *retVal = spvBuilder.createSpirvIntrInstExt(
static_cast<unsigned>(spv::Op::OpNop), QualType(), {}, extensions, {},
capbilities, expr->getExprLoc());
retVal->setRValue();

return retVal;
return createSpirvIntrInstExt(
funcDecl->getAttrs(), QualType(),
/*spvArgs*/ llvm::SmallVector<SpirvInstruction *, 1>{},
/*isInstr*/ false, expr->getExprLoc());
}

bool SpirvEmitter::spirvToolsValidate(std::vector<uint32_t> *mod,
Expand Down
10 changes: 9 additions & 1 deletion tools/clang/lib/SPIRV/SpirvEmitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,13 @@ class SpirvEmitter : public ASTConsumer {
/// Process ray query intrinsics
SpirvInstruction *processRayQueryIntrinsics(const CXXMemberCallExpr *expr,
hlsl::IntrinsicOp opcode);

/// Create SpirvIntrinsicInstruction for arbitrary SPIR-V instructions
/// specified by [[vk::ext_instruction(..)]] or [[vk::ext_type_def(..)]]
SpirvInstruction *createSpirvIntrInstExt(
llvm::ArrayRef<const Attr *> attrs, QualType retType,
const llvm::SmallVectorImpl<SpirvInstruction *> &spvArgs, bool isInstr,
SourceLocation loc);
/// Process spirv intrinsic instruction
SpirvInstruction *processSpvIntrinsicCallExpr(const CallExpr *expr);

Expand All @@ -622,7 +629,8 @@ class SpirvEmitter : public ASTConsumer {
/// Custom intrinsic to support basic buffer_reference use case
SpirvInstruction *processRawBufferLoad(const CallExpr *callExpr);
/// Process vk::ext_execution_mode intrinsic
SpirvInstruction *processIntrinsicExecutionMode(const CallExpr *expr);
SpirvInstruction *processIntrinsicExecutionMode(const CallExpr *expr,
bool useIdParams);

private:
/// Returns the <result-id> for constant value 0 of the given type.
Expand Down
18 changes: 18 additions & 0 deletions tools/clang/lib/Sema/SemaHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12186,6 +12186,24 @@ Attr *hlsl::ProcessStmtAttributeForHLSL(Sema &S, Stmt *St, const AttributeList &
Attr * result = nullptr;
Handled = true;

// SPIRV Change Starts
if (A.hasScope() && A.getScopeName()->getName().equals("vk")) {
switch (A.getKind()) {
case AttributeList::AT_VKCapabilityExt:
return ::new (S.Context) VKCapabilityExtAttr(
A.getRange(), S.Context, ValidateAttributeIntArg(S, A),
A.getAttributeSpellingListIndex());
case AttributeList::AT_VKExtensionExt:
return ::new (S.Context) VKExtensionExtAttr(
A.getRange(), S.Context, ValidateAttributeStringArg(S, A, nullptr),
A.getAttributeSpellingListIndex());
default:
Handled = false;
return nullptr;
}
}
// SPIRV Change Ends

switch (A.getKind())
{
case AttributeList::AT_HLSLUnroll:
Expand Down
7 changes: 5 additions & 2 deletions tools/clang/test/CodeGenSPIRV/spv.intrinsicExecutionMode.hlsl
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
// RUN: %dxc -T ps_6_0 -E main -spirv


// CHECK: OpCapability ShaderClockKHR
// CHECK: OpExtension "SPV_KHR_shader_clock"
// CHECK: OpExecutionMode {{%\w+}} StencilRefReplacingEXT
// CHECK: OpExecutionMode {{%\w+}} SubgroupSize 32
// CHECK: OpDecorate {{%\w+}} BuiltIn FragStencilRefEXT

[[vk::ext_decorate(11, 5014)]]
int main() : SV_Target0 {

[[vk::ext_capability(5055)]]
[[vk::ext_extension("SPV_KHR_shader_clock")]]
vk::ext_execution_mode(/*StencilRefReplacingEXT*/5027);

vk::ext_execution_mode(/*SubgroupSize*/35, 32);
return 3;
}
16 changes: 16 additions & 0 deletions tools/clang/test/CodeGenSPIRV/spv.intrinsicExecutionModeId.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// RUN: %dxc -T ps_6_0 -E main -spirv

// CHECK: OpCapability ShaderClockKHR
// CHECK: OpExtension "SPV_KHR_shader_clock"
// CHECK: OpExecutionModeId {{%\w+}} LocalSizeId %uint_8 %uint_8 %uint_8
// CHECK: OpExecutionModeId {{%\w+}} LocalSizeHintId %uint_4 %uint_4 %uint_4

int main() : SV_Target0 {
vk::ext_execution_mode_id(/*LocalSizeId*/38, 8, 8, 8);

[[vk::ext_capability(5055)]]
[[vk::ext_extension("SPV_KHR_shader_clock")]]
vk::ext_execution_mode_id(/*LocalSizeHintId*/39, 4, 4, 4);

return 3;
}
1 change: 1 addition & 0 deletions tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1349,6 +1349,7 @@ TEST_F(FileTest, IntrinsicsSpirv) {
runFileTest("spv.intrinsicLiteral.hlsl");
runFileTest("spv.intrinsicDecorate.hlsl", Expect::Success, false);
runFileTest("spv.intrinsicExecutionMode.hlsl", Expect::Success, false);
runFileTest("spv.intrinsicExecutionModeId.hlsl", Expect::Success, false);
runFileTest("spv.intrinsicStorageClass.hlsl", Expect::Success, false);
runFileTest("spv.intrinsicTypeInteger.hlsl");
runFileTest("spv.intrinsicTypeRayquery.hlsl", Expect::Success, false);
Expand Down
3 changes: 2 additions & 1 deletion utils/hct/gen_intrin_main.txt
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,8 @@ namespace VkIntrinsics {

u64 [[]] ReadClock(in uint scope);
uint [[ro]] RawBufferLoad(in u64 addr);
void [[]] ext_execution_mode(...);
void [[]] ext_execution_mode(in uint mode, ...);
void [[]] ext_execution_mode_id(in uint mode, ...);

} namespace
// SPIRV Change Ends
Expand Down

0 comments on commit 29874c9

Please sign in to comment.