diff --git a/tools/clang/include/clang/Basic/Attr.td b/tools/clang/include/clang/Basic/Attr.td index 8a8402a545..aae4a46a31 100644 --- a/tools/clang/include/clang/Basic/Attr.td +++ b/tools/clang/include/clang/Basic/Attr.td @@ -1010,6 +1010,14 @@ def VKBinding : InheritableAttr { let Documentation = [Undocumented]; } +def VKCapabilityExt : InheritableAttr { + let Spellings = [CXX11<"vk", "ext_capability">]; + let Subjects = SubjectList<[Function], ErrorDiag>; + let Args = [IntArgument<"capability">]; + let LangOpts = [SPIRV]; + let Documentation = [Undocumented]; +} + def VKCounterBinding : InheritableAttr { let Spellings = [CXX11<"vk", "counter_binding">]; let Subjects = SubjectList<[CounterStructuredBuffer], ErrorDiag, "ExpectedCounterStructuredBuffer">; @@ -1018,6 +1026,14 @@ def VKCounterBinding : InheritableAttr { let Documentation = [Undocumented]; } +def VKExtensionExt : InheritableAttr { + let Spellings = [CXX11<"vk", "ext_extension">]; + let Subjects = SubjectList<[Function], ErrorDiag>; + let Args = [StringArgument<"name">]; + let LangOpts = [SPIRV]; + let Documentation = [Undocumented]; +} + // Global variables that are of struct type def StructGlobalVar : SubsetSubjecthasGlobalStorage() && S->getType()->isStructureType()}]>; @@ -1083,6 +1099,14 @@ def VKInputAttachmentIndex : InheritableAttr { let Documentation = [Undocumented]; } +def VKInstructionExt : InheritableAttr { + let Spellings = [CXX11<"vk", "ext_instruction">]; + let Subjects = SubjectList<[Function], ErrorDiag>; + let Args = [IntArgument<"opcode">, StringArgument<"instruction_set", 1>]; + let LangOpts = [SPIRV]; + let Documentation = [Undocumented]; +} + // Global variables that are of scalar type def ScalarGlobalVar : SubsetSubjecthasGlobalStorage() && S->getType()->isScalarType()}]>; diff --git a/tools/clang/include/clang/SPIRV/SpirvBuilder.h b/tools/clang/include/clang/SPIRV/SpirvBuilder.h index 357d5a1d45..5b615c2fe7 100644 --- a/tools/clang/include/clang/SPIRV/SpirvBuilder.h +++ b/tools/clang/include/clang/SPIRV/SpirvBuilder.h @@ -507,6 +507,13 @@ class SpirvBuilder { /// OpIgnoreIntersectionKHR/OpTerminateIntersectionKHR void createRaytracingTerminateKHR(spv::Op opcode, SourceLocation loc); + /// \brief Create spirv intrinsic instructions + SpirvInstruction * + createSpirvIntrInstExt(uint32_t opcode, QualType retType, + llvm::ArrayRef operands, + llvm::StringRef ext, llvm::StringRef instSet, + llvm::ArrayRef capts, SourceLocation loc); + /// \brief Returns a clone SPIR-V variable for CTBuffer with FXC memory layout /// and creates copy instructions from the CTBuffer to the clone variable in /// module.init if it contains HLSL matrix 1xN. Otherwise, returns nullptr. diff --git a/tools/clang/include/clang/SPIRV/SpirvInstruction.h b/tools/clang/include/clang/SPIRV/SpirvInstruction.h index f3f9feaadc..2fbcf391d6 100644 --- a/tools/clang/include/clang/SPIRV/SpirvInstruction.h +++ b/tools/clang/include/clang/SPIRV/SpirvInstruction.h @@ -127,6 +127,7 @@ class SpirvInstruction { IK_Store, // OpStore IK_UnaryOp, // Unary operations IK_VectorShuffle, // OpVectorShuffle + IK_SpirvIntrinsicInstruction, // Spirv Instructions for no particular op // For DebugInfo instructions defined in OpenCL.DebugInfo.100 IK_DebugInfoNone, @@ -1999,6 +2000,36 @@ class SpirvDemoteToHelperInvocationEXT : public SpirvInstruction { bool invokeVisitor(Visitor *v) override; }; +class SpirvIntrinsicInstruction : public SpirvInstruction { +public: + SpirvIntrinsicInstruction(QualType resultType, uint32_t opcode, + llvm::ArrayRef operands, + llvm::StringRef ext, SpirvExtInstImport *set, + llvm::ArrayRef capts, SourceLocation loc); + + DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvIntrinsicInstruction) + + // For LLVM-style RTTI + static bool classof(const SpirvInstruction *inst) { + return inst->getKind() == IK_SpirvIntrinsicInstruction; + } + + bool invokeVisitor(Visitor *v) override; + + llvm::ArrayRef getOperands() const { return operands; } + llvm::ArrayRef getCapabilities() const { return capabilities; } + llvm::StringRef getExtension() const { return extension; } + SpirvExtInstImport *getInstructionSet() const { return instructionSet; } + uint32_t getInstruction() const { return instruction; } + +private: + uint32_t instruction; + llvm::SmallVector operands; + llvm::SmallVector capabilities; + std::string extension; + SpirvExtInstImport *instructionSet; +}; + /// \breif Base class for all OpenCL.DebugInfo.100 extension instructions. /// Note that all of these instructions should be added to the SPIR-V module as /// an OpExtInst instructions. So, all of these instructions must: diff --git a/tools/clang/include/clang/SPIRV/SpirvVisitor.h b/tools/clang/include/clang/SPIRV/SpirvVisitor.h index 471d6ba908..819c531d8d 100644 --- a/tools/clang/include/clang/SPIRV/SpirvVisitor.h +++ b/tools/clang/include/clang/SPIRV/SpirvVisitor.h @@ -139,6 +139,7 @@ class Visitor { DEFINE_VISIT_METHOD(SpirvRayQueryOpKHR) DEFINE_VISIT_METHOD(SpirvReadClock) DEFINE_VISIT_METHOD(SpirvRayTracingTerminateOpKHR) + DEFINE_VISIT_METHOD(SpirvIntrinsicInstruction) #undef DEFINE_VISIT_METHOD protected: diff --git a/tools/clang/lib/SPIRV/CapabilityVisitor.cpp b/tools/clang/lib/SPIRV/CapabilityVisitor.cpp index a08e4d300d..c72007401e 100644 --- a/tools/clang/lib/SPIRV/CapabilityVisitor.cpp +++ b/tools/clang/lib/SPIRV/CapabilityVisitor.cpp @@ -448,6 +448,16 @@ bool CapabilityVisitor::visitInstruction(SpirvInstruction *instr) { addCapability(getNonUniformCapability(resultType)); } + if (instr->getKind() == SpirvInstruction::IK_SpirvIntrinsicInstruction) { + SpirvIntrinsicInstruction *pSpvInst = + dyn_cast(instr); + for (auto &cap : pSpvInst->getCapabilities()) { + addCapability(static_cast(cap)); + } + if (pSpvInst->getExtension().size() > 0) + spvBuilder.requireExtension(pSpvInst->getExtension(), loc); + } + // Add opcode-specific capabilities switch (opcode) { case spv::Op::OpDPdxCoarse: diff --git a/tools/clang/lib/SPIRV/EmitVisitor.cpp b/tools/clang/lib/SPIRV/EmitVisitor.cpp index 9203229409..69761c4d48 100644 --- a/tools/clang/lib/SPIRV/EmitVisitor.cpp +++ b/tools/clang/lib/SPIRV/EmitVisitor.cpp @@ -1659,6 +1659,23 @@ bool EmitVisitor::visit(SpirvRayTracingTerminateOpKHR *inst) { return true; } +bool EmitVisitor::visit(SpirvIntrinsicInstruction *inst) { + initInstruction(inst); + curInst.push_back(inst->getResultTypeId()); + curInst.push_back(getOrAssignResultId(inst)); + if (inst->getInstructionSet()) { + curInst.push_back( + getOrAssignResultId(inst->getInstructionSet())); + curInst.push_back(inst->getInstruction()); + } + + for (const auto operand : inst->getOperands()) + curInst.push_back(getOrAssignResultId(operand)); + + finalizeInstruction(&mainBinary); + return true; +} + // EmitTypeHandler ------ void EmitTypeHandler::initTypeInstruction(spv::Op op) { diff --git a/tools/clang/lib/SPIRV/EmitVisitor.h b/tools/clang/lib/SPIRV/EmitVisitor.h index c7438a6f25..dd2f2ded30 100644 --- a/tools/clang/lib/SPIRV/EmitVisitor.h +++ b/tools/clang/lib/SPIRV/EmitVisitor.h @@ -289,6 +289,7 @@ class EmitVisitor : public Visitor { bool visit(SpirvDebugTypeMember *) override; bool visit(SpirvDebugTypeTemplate *) override; bool visit(SpirvDebugTypeTemplateParameter *) override; + bool visit(SpirvIntrinsicInstruction *) override; using Visitor::visit; diff --git a/tools/clang/lib/SPIRV/SpirvBuilder.cpp b/tools/clang/lib/SPIRV/SpirvBuilder.cpp index 8baaa3ed9b..8c4bdf5d2e 100644 --- a/tools/clang/lib/SPIRV/SpirvBuilder.cpp +++ b/tools/clang/lib/SPIRV/SpirvBuilder.cpp @@ -962,6 +962,23 @@ SpirvInstruction *SpirvBuilder::createReadClock(SpirvInstruction *scope, return inst; } +SpirvInstruction *SpirvBuilder::createSpirvIntrInstExt( + uint32_t opcode, QualType retType, + llvm::ArrayRef operands, llvm::StringRef ext, + llvm::StringRef instSet, llvm::ArrayRef capts, + SourceLocation loc) { + assert(insertPoint && "null insert point"); + + SpirvExtInstImport *set = + (instSet.size() == 0) ? nullptr : getExtInstSet(instSet); + + auto *inst = new (context) SpirvIntrinsicInstruction( + retType, opcode, operands, ext, set, capts, loc); + + insertPoint->addInstruction(inst); + return inst; +} + void SpirvBuilder::createRaytracingTerminateKHR(spv::Op opcode, SourceLocation loc) { assert(insertPoint && "null insert point"); diff --git a/tools/clang/lib/SPIRV/SpirvEmitter.cpp b/tools/clang/lib/SPIRV/SpirvEmitter.cpp index 49e2a36a35..16def28f29 100644 --- a/tools/clang/lib/SPIRV/SpirvEmitter.cpp +++ b/tools/clang/lib/SPIRV/SpirvEmitter.cpp @@ -2288,11 +2288,16 @@ SpirvInstruction *SpirvEmitter::doCallExpr(const CallExpr *callExpr) { if (const auto *memberCall = dyn_cast(callExpr)) return doCXXMemberCallExpr(memberCall); + auto funcDecl = callExpr->getDirectCallee(); // Intrinsic functions such as 'dot' or 'mul' - if (hlsl::IsIntrinsicOp(callExpr->getDirectCallee())) { + if (hlsl::IsIntrinsicOp(funcDecl)) { return processIntrinsicCallExpr(callExpr); } + if (funcDecl && funcDecl->hasAttr()) { + return processSpvIntrinsicCallExpr(callExpr); + } + // Normal standalone functions return processCall(callExpr); } @@ -12392,6 +12397,42 @@ SpirvEmitter::processRayQueryIntrinsics(const CXXMemberCallExpr *expr, return retVal; } +SpirvInstruction * +SpirvEmitter::processSpvIntrinsicCallExpr(const CallExpr *expr) { + auto funcDecl = expr->getDirectCallee(); + auto &attrs = funcDecl->getAttrs(); + QualType retType = funcDecl->getReturnType(); + llvm::SmallVector capbilities; + llvm::StringRef extExtension = ""; + llvm::StringRef instSet = ""; + uint32_t op = 0; + for (auto &attr : attrs) { + if (auto capAttr = dyn_cast(attr)) { + capbilities.push_back(capAttr->getCapability()); + } else if (auto extAttr = dyn_cast(attr)) { + extExtension = extAttr->getName(); + } else if (auto instAttr = dyn_cast(attr)) { + op = instAttr->getOpcode(); + instSet = instAttr->getInstruction_set(); + } + } + + llvm::SmallVector spvArgs; + + const auto args = expr->getArgs(); + for (uint32_t i = 0; i < expr->getNumArgs(); ++i) { + spvArgs.push_back(doExpr(args[i])); + } + + const auto loc = expr->getExprLoc(); + + SpirvInstruction *retVal = spvBuilder.createSpirvIntrInstExt( + op, retType, spvArgs, extExtension, instSet, capbilities, loc); + + retVal->setRValue(); + return retVal; +} + bool SpirvEmitter::spirvToolsValidate(std::vector *mod, std::string *messages) { spvtools::SpirvTools tools(featureManager.getTargetEnv()); diff --git a/tools/clang/lib/SPIRV/SpirvEmitter.h b/tools/clang/lib/SPIRV/SpirvEmitter.h index 17246656ff..2001660c49 100644 --- a/tools/clang/lib/SPIRV/SpirvEmitter.h +++ b/tools/clang/lib/SPIRV/SpirvEmitter.h @@ -590,6 +590,8 @@ class SpirvEmitter : public ASTConsumer { /// Process ray query intrinsics SpirvInstruction *processRayQueryIntrinsics(const CXXMemberCallExpr *expr, hlsl::IntrinsicOp opcode); + /// Process spirv intrinsic instruction + SpirvInstruction *processSpvIntrinsicCallExpr(const CallExpr *expr); private: /// Returns the for constant value 0 of the given type. diff --git a/tools/clang/lib/SPIRV/SpirvInstruction.cpp b/tools/clang/lib/SPIRV/SpirvInstruction.cpp index 68f0a35696..0221efb269 100644 --- a/tools/clang/lib/SPIRV/SpirvInstruction.cpp +++ b/tools/clang/lib/SPIRV/SpirvInstruction.cpp @@ -106,6 +106,7 @@ DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvDebugTypeTemplateParameter) DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvRayQueryOpKHR) DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvReadClock) DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvRayTracingTerminateOpKHR) +DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvIntrinsicInstruction) #undef DEFINE_INVOKE_VISITOR_FOR_CLASS @@ -1004,5 +1005,17 @@ SpirvRayTracingTerminateOpKHR::SpirvRayTracingTerminateOpKHR(spv::Op opcode, opcode == spv::Op::OpIgnoreIntersectionKHR); } +SpirvIntrinsicInstruction::SpirvIntrinsicInstruction( + QualType resultType, uint32_t opcode, + llvm::ArrayRef vecOperands, llvm::StringRef ext, + SpirvExtInstImport *set, llvm::ArrayRef capts, SourceLocation loc) + : SpirvInstruction(IK_SpirvIntrinsicInstruction, + set != nullptr ? spv::Op::OpExtInst + : static_cast(opcode), + resultType, loc), + instruction(opcode), operands(vecOperands.begin(), vecOperands.end()), + capabilities(capts.begin(), capts.end()), extension(ext.data()), + instructionSet(set) {} + } // namespace spirv } // namespace clang diff --git a/tools/clang/lib/Sema/SemaHLSL.cpp b/tools/clang/lib/Sema/SemaHLSL.cpp index 5332d74b4c..11b6a0deb5 100644 --- a/tools/clang/lib/Sema/SemaHLSL.cpp +++ b/tools/clang/lib/Sema/SemaHLSL.cpp @@ -11945,6 +11945,22 @@ void hlsl::HandleDeclAttributeForHLSL(Sema &S, Decl *D, const AttributeList &A, case AttributeList::AT_VKShaderRecordEXT: declAttr = ::new (S.Context) VKShaderRecordEXTAttr(A.getRange(), S.Context, A.getAttributeSpellingListIndex()); break; + case AttributeList::AT_VKCapabilityExt: + declAttr = ::new (S.Context) VKCapabilityExtAttr( + A.getRange(), S.Context, ValidateAttributeIntArg(S, A), + A.getAttributeSpellingListIndex()); + break; + case AttributeList::AT_VKExtensionExt: + declAttr = ::new (S.Context) VKExtensionExtAttr( + A.getRange(), S.Context, ValidateAttributeStringArg(S, A, nullptr), + A.getAttributeSpellingListIndex()); + break; + case AttributeList::AT_VKInstructionExt: + declAttr = ::new (S.Context) VKInstructionExtAttr( + A.getRange(), S.Context, ValidateAttributeIntArg(S, A), + ValidateAttributeStringArg(S, A, nullptr, 1), + A.getAttributeSpellingListIndex()); + break; default: Handled = false; return; diff --git a/tools/clang/test/CodeGenSPIRV/spv.intrinsicInstruction.hlsl b/tools/clang/test/CodeGenSPIRV/spv.intrinsicInstruction.hlsl new file mode 100644 index 0000000000..6714299562 --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/spv.intrinsicInstruction.hlsl @@ -0,0 +1,34 @@ +// Run: %dxc -T vs_6_0 -E main + +struct SInstanceData { + float4x3 VisualToWorld; + float4 Output; +}; + +struct VS_INPUT { + float3 Position : POSITION; + SInstanceData InstanceData : TEXCOORD4; +}; + +[[vk::ext_capability(5055)]] +[[vk::ext_extension("SPV_KHR_shader_clock")]] +[[vk::ext_instruction(/* OpReadClockKHR */ 5056)]] +uint64_t ReadClock(uint scope); + +[[vk::ext_instruction(/* Sin*/ 13, "GLSL.std.450")]] +float4 spv_sin(float4 v); + +// CHECK: OpCapability ShaderClockKHR +// CHECK-NEXT: OpExtension "SPV_KHR_shader_clock" +// CHECK-NEXT: {{%\d+}} = OpExtInstImport "GLSL.std.450" + +float4 main(const VS_INPUT v) : SV_Position { + SInstanceData I = v.InstanceData; + uint64_t clock; +// CHECK: {{%\d+}} = OpExtInst %v4float {{%\d+}} Sin {{%\d+}} + I.Output = spv_sin(v.InstanceData.Output); +// CHECK: {{%\d+}} = OpReadClockKHR %ulong %uint_1 + clock = ReadClock(vk::DeviceScope); + + return I.Output; +} diff --git a/tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp b/tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp index c48ab6c1c4..5fe749b429 100644 --- a/tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp +++ b/tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp @@ -1320,6 +1320,10 @@ TEST_F(FileTest, IntrinsicsVkReadClock) { runFileTest("intrinsics.vkreadclock.hlsl"); } +TEST_F(FileTest, IntrinsicsSpirv) { + runFileTest("spv.intrinsicInstruction.hlsl"); +} + // Intrinsics added in SM 6.6 TEST_F(FileTest, IntrinsicsSM66PackU8S8) { runFileTest("intrinsics.sm6_6.pack_s8u8.hlsl");