diff --git a/include/dxc/HlslIntrinsicOp.h b/include/dxc/HlslIntrinsicOp.h index e342255f2d..a3897b2d0c 100644 --- a/include/dxc/HlslIntrinsicOp.h +++ b/include/dxc/HlslIntrinsicOp.h @@ -229,6 +229,9 @@ enum class IntrinsicOp { IOP_AcceptHitAndEndSearch, #endif // ENABLE_SPIRV_CODEGEN #ifdef ENABLE_SPIRV_CODEGEN IOP_Vkext_execution_mode, +#endif // ENABLE_SPIRV_CODEGEN +#ifdef ENABLE_SPIRV_CODEGEN + IOP_Vkext_execution_mode_id, #endif // ENABLE_SPIRV_CODEGEN MOP_Append, MOP_RestartStrip, diff --git a/lib/HLSL/HLOperationLower.cpp b/lib/HLSL/HLOperationLower.cpp index 7c4d2217f1..affb544e6c 100644 --- a/lib/HLSL/HLOperationLower.cpp +++ b/lib/HLSL/HLOperationLower.cpp @@ -5663,6 +5663,7 @@ IntrinsicLower gLowerTable[] = { { IntrinsicOp::IOP_VkReadClock, UnsupportedVulkanIntrinsic, DXIL::OpCode::NumOpCodes }, { IntrinsicOp::IOP_VkRawBufferLoad, UnsupportedVulkanIntrinsic, DXIL::OpCode::NumOpCodes }, { IntrinsicOp::IOP_Vkext_execution_mode, UnsupportedVulkanIntrinsic, DXIL::OpCode::NumOpCodes }, + { IntrinsicOp::IOP_Vkext_execution_mode_id, UnsupportedVulkanIntrinsic, DXIL::OpCode::NumOpCodes }, #endif // ENABLE_SPIRV_CODEGEN {IntrinsicOp::MOP_Append, StreamOutputLower, DXIL::OpCode::EmitStream}, {IntrinsicOp::MOP_RestartStrip, StreamOutputLower, DXIL::OpCode::CutStream}, diff --git a/tools/clang/include/clang/SPIRV/SpirvBuilder.h b/tools/clang/include/clang/SPIRV/SpirvBuilder.h index 772145212b..010e670afe 100644 --- a/tools/clang/include/clang/SPIRV/SpirvBuilder.h +++ b/tools/clang/include/clang/SPIRV/SpirvBuilder.h @@ -585,7 +585,8 @@ class SpirvBuilder { inline SpirvInstruction *addExecutionMode(SpirvFunction *entryPoint, spv::ExecutionMode em, llvm::ArrayRef params, - SourceLocation); + SourceLocation, + bool useIdParams = false); /// \brief Adds an OpModuleProcessed instruction to the module under /// construction. @@ -888,9 +889,9 @@ SpirvBuilder::setDebugSource(uint32_t major, uint32_t minor, SpirvInstruction * SpirvBuilder::addExecutionMode(SpirvFunction *entryPoint, spv::ExecutionMode em, llvm::ArrayRef params, - SourceLocation loc) { - auto mode = - new (context) SpirvExecutionMode(loc, entryPoint, em, params, false); + SourceLocation loc, bool useIdParams) { + auto mode = new (context) + SpirvExecutionMode(loc, entryPoint, em, params, useIdParams); mod->addExecutionMode(mode); return mode; diff --git a/tools/clang/lib/SPIRV/EmitVisitor.cpp b/tools/clang/lib/SPIRV/EmitVisitor.cpp index 87663e650e..507e150dc0 100644 --- a/tools/clang/lib/SPIRV/EmitVisitor.cpp +++ b/tools/clang/lib/SPIRV/EmitVisitor.cpp @@ -587,8 +587,16 @@ bool EmitVisitor::visit(SpirvExecutionMode *inst) { initInstruction(inst); curInst.push_back(getOrAssignResultId(inst->getEntryPoint())); curInst.push_back(static_cast(inst->getExecutionMode())); - curInst.insert(curInst.end(), inst->getParams().begin(), - inst->getParams().end()); + if (inst->getopcode() == spv::Op::OpExecutionMode) { + curInst.insert(curInst.end(), inst->getParams().begin(), + inst->getParams().end()); + } else { + for (uint32_t param : inst->getParams()) { + curInst.push_back(typeHandler.getOrCreateConstantInt( + llvm::APInt(32, param), context.getUIntType(32), + /*isSpecConst */ false)); + } + } finalizeInstruction(&preambleBinary); return true; } diff --git a/tools/clang/lib/SPIRV/SpirvEmitter.cpp b/tools/clang/lib/SPIRV/SpirvEmitter.cpp index 008bd73172..3a0e525ffe 100644 --- a/tools/clang/lib/SPIRV/SpirvEmitter.cpp +++ b/tools/clang/lib/SPIRV/SpirvEmitter.cpp @@ -924,11 +924,21 @@ void SpirvEmitter::doStmt(const Stmt *stmt, doForStmt(forStmt, attrs); } else if (dyn_cast(stmt)) { // For the null statement ";". We don't need to do anything. - } else if (const auto *expr = dyn_cast(stmt)) { - // All cases for expressions used as statements - doExpr(expr); } else if (const auto *attrStmt = dyn_cast(stmt)) { doStmt(attrStmt->getSubStmt(), attrStmt->getAttrs()); + } else if (const auto *expr = dyn_cast(stmt)) { + // All cases for expressions used as statements + SpirvInstruction *result = doExpr(expr); + + if (result && result->getKind() == SpirvInstruction::IK_ExecutionMode && + !attrs.empty()) { + // Handle [[vk::ext_capability(..)]] and [[vk::ext_extension(..)]] + // attributes for vk::ext_execution_mode[_id](..). + createSpirvIntrInstExt( + attrs, QualType(), + /*spvArgs*/ llvm::SmallVector{}, + /*isInstr*/ false, expr->getExprLoc()); + } } else { emitError("statement class '%0' unimplemented", stmt->getLocStart()) << stmt->getStmtClassName() << stmt->getSourceRange(); @@ -7776,7 +7786,10 @@ SpirvEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) { retVal = processRawBufferLoad(callExpr); break; case hlsl::IntrinsicOp::IOP_Vkext_execution_mode: - retVal = processIntrinsicExecutionMode(callExpr); + retVal = processIntrinsicExecutionMode(callExpr, false); + break; + case hlsl::IntrinsicOp::IOP_Vkext_execution_mode_id: + retVal = processIntrinsicExecutionMode(callExpr, true); break; case hlsl::IntrinsicOp::IOP_saturate: retVal = processIntrinsicSaturate(callExpr); @@ -12720,32 +12733,46 @@ SpirvEmitter::processRayQueryIntrinsics(const CXXMemberCallExpr *expr, return retVal; } -SpirvInstruction * -SpirvEmitter::processSpvIntrinsicCallExpr(const CallExpr *expr) { - auto funcDecl = expr->getDirectCallee(); - auto &attrs = funcDecl->getAttrs(); - QualType retType = funcDecl->getReturnType(); - +SpirvInstruction *SpirvEmitter::createSpirvIntrInstExt( + llvm::ArrayRef attrs, QualType retType, + const llvm::SmallVectorImpl &spvArgs, bool isInstr, + SourceLocation loc) { llvm::SmallVector capbilities; llvm::SmallVector extensions; llvm::StringRef instSet = ""; - uint32_t op = 0; + // For [[vk::ext_type_def]], we use dummy OpNop with no semantic meaning, + // with possible extension and capabilities. + uint32_t op = static_cast(spv::Op::OpNop); for (auto &attr : attrs) { if (auto capAttr = dyn_cast(attr)) { capbilities.push_back(capAttr->getCapability()); } else if (auto extAttr = dyn_cast(attr)) { extensions.push_back(extAttr->getName()); - } else if (auto instAttr = dyn_cast(attr)) { + } + if (!isInstr) + continue; + if (auto instAttr = dyn_cast(attr)) { op = instAttr->getOpcode(); instSet = instAttr->getInstruction_set(); } } - llvm::SmallVector spvArgs; + SpirvInstruction *retVal = spvBuilder.createSpirvIntrInstExt( + op, retType, spvArgs, extensions, instSet, capbilities, loc); + // TODO: Revisit this r-value setting when handling vk::ext_result_id ? + retVal->setRValue(); + + return retVal; +} + +SpirvInstruction * +SpirvEmitter::processSpvIntrinsicCallExpr(const CallExpr *expr) { + const auto *funcDecl = expr->getDirectCallee(); + llvm::SmallVector spvArgs; const auto args = expr->getArgs(); for (uint32_t i = 0; i < expr->getNumArgs(); ++i) { - auto param = funcDecl->getParamDecl(i); + const auto *param = funcDecl->getParamDecl(i); const Expr *arg = args[i]->IgnoreParenLValueCasts(); SpirvInstruction *argInst = doExpr(arg); if (param->hasAttr()) { @@ -12766,14 +12793,9 @@ SpirvEmitter::processSpvIntrinsicCallExpr(const CallExpr *expr) { } } - const auto loc = expr->getExprLoc(); - - SpirvInstruction *retVal = spvBuilder.createSpirvIntrInstExt( - op, retType, spvArgs, extensions, instSet, capbilities, loc); - - // TODO: Revisit this r-value setting when handling vk::ext_result_id ? - retVal->setRValue(); - return retVal; + return createSpirvIntrInstExt(funcDecl->getAttrs(), funcDecl->getReturnType(), + spvArgs, + /*isInstr*/ true, expr->getExprLoc()); } SpirvInstruction *SpirvEmitter::processRawBufferLoad(const CallExpr *callExpr) { @@ -12803,18 +12825,22 @@ SpirvInstruction *SpirvEmitter::processRawBufferLoad(const CallExpr *callExpr) { } SpirvInstruction * -SpirvEmitter::processIntrinsicExecutionMode(const CallExpr *expr) { +SpirvEmitter::processIntrinsicExecutionMode(const CallExpr *expr, + bool useIdParams) { llvm::SmallVector execModesParams; uint32_t exeMode = 0; const auto args = expr->getArgs(); for (uint32_t i = 0; i < expr->getNumArgs(); ++i) { - SpirvConstantInteger *argInst = - dyn_cast(doExpr(args[i])); - if (argInst == nullptr) { - emitError("argument should be constant interger", expr->getExprLoc()); + const auto *intLiteral = + dyn_cast(args[i]->IgnoreImplicit()); + if (intLiteral == nullptr) { + emitError("argument should be constant integer", expr->getExprLoc()); return nullptr; } - unsigned argInteger = argInst->getValue().getZExtValue(); + + uint32_t argInteger = + static_cast(intLiteral->getValue().getZExtValue()); + if (i > 0) execModesParams.push_back(argInteger); else @@ -12823,26 +12849,14 @@ SpirvEmitter::processIntrinsicExecutionMode(const CallExpr *expr) { assert(entryFunction != nullptr); assert(exeMode != 0); - return spvBuilder.addExecutionMode(entryFunction, - static_cast(exeMode), - execModesParams, expr->getExprLoc()); + return spvBuilder.addExecutionMode( + entryFunction, static_cast(exeMode), execModesParams, + expr->getExprLoc(), useIdParams); } SpirvInstruction * SpirvEmitter::processSpvIntrinsicTypeDef(const CallExpr *expr) { auto funcDecl = expr->getDirectCallee(); - auto typeDefAttr = funcDecl->getAttr(); - llvm::SmallVector capbilities; - llvm::SmallVector extensions; - - for (auto &attr : funcDecl->getAttrs()) { - if (auto capAttr = dyn_cast(attr)) { - capbilities.push_back(capAttr->getCapability()); - } else if (auto extAttr = dyn_cast(attr)) { - extensions.push_back(extAttr->getName()); - } - } - SmallVector operands; const auto args = expr->getArgs(); for (uint32_t i = 0; i < expr->getNumArgs(); ++i) { @@ -12867,17 +12881,15 @@ SpirvEmitter::processSpvIntrinsicTypeDef(const CallExpr *expr) { operands.emplace_back(loadIfGLValue(arg)); } } + + auto typeDefAttr = funcDecl->getAttr(); spvContext.getSpirvIntrinsicType(typeDefAttr->getId(), typeDefAttr->getOpcode(), operands); - // Emit dummy OpNop with no semantic meaning, with possible extension and - // capabilities - SpirvInstruction *retVal = spvBuilder.createSpirvIntrInstExt( - static_cast(spv::Op::OpNop), QualType(), {}, extensions, {}, - capbilities, expr->getExprLoc()); - retVal->setRValue(); - - return retVal; + return createSpirvIntrInstExt( + funcDecl->getAttrs(), QualType(), + /*spvArgs*/ llvm::SmallVector{}, + /*isInstr*/ false, expr->getExprLoc()); } bool SpirvEmitter::spirvToolsValidate(std::vector *mod, diff --git a/tools/clang/lib/SPIRV/SpirvEmitter.h b/tools/clang/lib/SPIRV/SpirvEmitter.h index 4b77e6f0fb..b1873a3905 100644 --- a/tools/clang/lib/SPIRV/SpirvEmitter.h +++ b/tools/clang/lib/SPIRV/SpirvEmitter.h @@ -613,6 +613,13 @@ class SpirvEmitter : public ASTConsumer { /// Process ray query intrinsics SpirvInstruction *processRayQueryIntrinsics(const CXXMemberCallExpr *expr, hlsl::IntrinsicOp opcode); + + /// Create SpirvIntrinsicInstruction for arbitrary SPIR-V instructions + /// specified by [[vk::ext_instruction(..)]] or [[vk::ext_type_def(..)]] + SpirvInstruction *createSpirvIntrInstExt( + llvm::ArrayRef attrs, QualType retType, + const llvm::SmallVectorImpl &spvArgs, bool isInstr, + SourceLocation loc); /// Process spirv intrinsic instruction SpirvInstruction *processSpvIntrinsicCallExpr(const CallExpr *expr); @@ -622,7 +629,8 @@ class SpirvEmitter : public ASTConsumer { /// Custom intrinsic to support basic buffer_reference use case SpirvInstruction *processRawBufferLoad(const CallExpr *callExpr); /// Process vk::ext_execution_mode intrinsic - SpirvInstruction *processIntrinsicExecutionMode(const CallExpr *expr); + SpirvInstruction *processIntrinsicExecutionMode(const CallExpr *expr, + bool useIdParams); private: /// Returns the for constant value 0 of the given type. diff --git a/tools/clang/lib/Sema/SemaHLSL.cpp b/tools/clang/lib/Sema/SemaHLSL.cpp index 0aa9767454..645fe7f3ba 100644 --- a/tools/clang/lib/Sema/SemaHLSL.cpp +++ b/tools/clang/lib/Sema/SemaHLSL.cpp @@ -12173,6 +12173,24 @@ Attr *hlsl::ProcessStmtAttributeForHLSL(Sema &S, Stmt *St, const AttributeList & Attr * result = nullptr; Handled = true; + // SPIRV Change Starts + if (A.hasScope() && A.getScopeName()->getName().equals("vk")) { + switch (A.getKind()) { + case AttributeList::AT_VKCapabilityExt: + return ::new (S.Context) VKCapabilityExtAttr( + A.getRange(), S.Context, ValidateAttributeIntArg(S, A), + A.getAttributeSpellingListIndex()); + case AttributeList::AT_VKExtensionExt: + return ::new (S.Context) VKExtensionExtAttr( + A.getRange(), S.Context, ValidateAttributeStringArg(S, A, nullptr), + A.getAttributeSpellingListIndex()); + default: + Handled = false; + return nullptr; + } + } + // SPIRV Change Ends + switch (A.getKind()) { case AttributeList::AT_HLSLUnroll: diff --git a/tools/clang/test/CodeGenSPIRV/spv.intrinsicExecutionMode.hlsl b/tools/clang/test/CodeGenSPIRV/spv.intrinsicExecutionMode.hlsl index 9a410b9e65..2b33c05942 100644 --- a/tools/clang/test/CodeGenSPIRV/spv.intrinsicExecutionMode.hlsl +++ b/tools/clang/test/CodeGenSPIRV/spv.intrinsicExecutionMode.hlsl @@ -1,14 +1,17 @@ // RUN: %dxc -T ps_6_0 -E main -spirv - +// CHECK: OpCapability ShaderClockKHR +// CHECK: OpExtension "SPV_KHR_shader_clock" // CHECK: OpExecutionMode {{%\w+}} StencilRefReplacingEXT // CHECK: OpExecutionMode {{%\w+}} SubgroupSize 32 // CHECK: OpDecorate {{%\w+}} BuiltIn FragStencilRefEXT [[vk::ext_decorate(11, 5014)]] int main() : SV_Target0 { - + [[vk::ext_capability(5055)]] + [[vk::ext_extension("SPV_KHR_shader_clock")]] vk::ext_execution_mode(/*StencilRefReplacingEXT*/5027); + vk::ext_execution_mode(/*SubgroupSize*/35, 32); return 3; } diff --git a/tools/clang/test/CodeGenSPIRV/spv.intrinsicExecutionModeId.hlsl b/tools/clang/test/CodeGenSPIRV/spv.intrinsicExecutionModeId.hlsl new file mode 100644 index 0000000000..f24d3e4812 --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/spv.intrinsicExecutionModeId.hlsl @@ -0,0 +1,16 @@ +// RUN: %dxc -T ps_6_0 -E main -spirv + +// CHECK: OpCapability ShaderClockKHR +// CHECK: OpExtension "SPV_KHR_shader_clock" +// CHECK: OpExecutionModeId {{%\w+}} LocalSizeId %uint_8 %uint_8 %uint_8 +// CHECK: OpExecutionModeId {{%\w+}} LocalSizeHintId %uint_4 %uint_4 %uint_4 + +int main() : SV_Target0 { + vk::ext_execution_mode_id(/*LocalSizeId*/38, 8, 8, 8); + + [[vk::ext_capability(5055)]] + [[vk::ext_extension("SPV_KHR_shader_clock")]] + vk::ext_execution_mode_id(/*LocalSizeHintId*/39, 4, 4, 4); + + return 3; +} diff --git a/tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp b/tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp index 8c675fd7a6..0ed513c008 100644 --- a/tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp +++ b/tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp @@ -1344,6 +1344,7 @@ TEST_F(FileTest, IntrinsicsSpirv) { runFileTest("spv.intrinsicLiteral.hlsl"); runFileTest("spv.intrinsicDecorate.hlsl", Expect::Success, false); runFileTest("spv.intrinsicExecutionMode.hlsl", Expect::Success, false); + runFileTest("spv.intrinsicExecutionModeId.hlsl", Expect::Success, false); runFileTest("spv.intrinsicStorageClass.hlsl", Expect::Success, false); runFileTest("spv.intrinsicTypeInteger.hlsl"); runFileTest("spv.intrinsicTypeRayquery.hlsl", Expect::Success, false); diff --git a/utils/hct/gen_intrin_main.txt b/utils/hct/gen_intrin_main.txt index c80a3f8aa9..2b88596c00 100644 --- a/utils/hct/gen_intrin_main.txt +++ b/utils/hct/gen_intrin_main.txt @@ -379,7 +379,8 @@ namespace VkIntrinsics { u64 [[]] ReadClock(in uint scope); uint [[ro]] RawBufferLoad(in u64 addr); -void [[]] ext_execution_mode(...); +void [[]] ext_execution_mode(in uint mode, ...); +void [[]] ext_execution_mode_id(in uint mode, ...); } namespace // SPIRV Change Ends