From 2eae8d3d329ef88ce4cbd2ecdfb8ef0489e6316f Mon Sep 17 00:00:00 2001 From: JiaoluAMD Date: Mon, 29 Nov 2021 19:18:12 +0800 Subject: [PATCH] [SPIRV] Add support of [[vk::ext_type_def]] (#4068) Support [[vk::ext_type_def]] and vk::ext_type. This is related https://github.com/microsoft/DirectXShaderCompiler/issues/3919 Co-authored-by: Jaebaek Seo --- tools/clang/include/clang/AST/HlslTypes.h | 3 + tools/clang/include/clang/Basic/Attr.td | 8 +++ .../clang/include/clang/SPIRV/SpirvContext.h | 7 ++ .../include/clang/SPIRV/SpirvInstruction.h | 12 ++-- tools/clang/include/clang/SPIRV/SpirvType.h | 32 +++++++++ tools/clang/lib/AST/ASTContextHLSL.cpp | 9 ++- tools/clang/lib/SPIRV/CapabilityVisitor.cpp | 5 +- tools/clang/lib/SPIRV/EmitVisitor.cpp | 69 +++++++++++++++++-- tools/clang/lib/SPIRV/EmitVisitor.h | 6 ++ tools/clang/lib/SPIRV/LowerTypeVisitor.cpp | 5 ++ tools/clang/lib/SPIRV/SpirvBuilder.cpp | 7 +- tools/clang/lib/SPIRV/SpirvContext.cpp | 23 +++++++ tools/clang/lib/SPIRV/SpirvEmitter.cpp | 63 +++++++++++++++-- tools/clang/lib/SPIRV/SpirvEmitter.h | 3 + tools/clang/lib/SPIRV/SpirvInstruction.cpp | 16 +++-- tools/clang/lib/SPIRV/SpirvType.cpp | 7 +- tools/clang/lib/Sema/SemaHLSL.cpp | 44 +++++++++--- .../spv.intrinsicTypeInteger.hlsl | 21 ++++++ .../spv.intrinsicTypeRayquery.hlsl | 40 +++++++++++ .../unittests/SPIRV/CodeGenSpirvTest.cpp | 2 + 20 files changed, 345 insertions(+), 37 deletions(-) create mode 100644 tools/clang/test/CodeGenSPIRV/spv.intrinsicTypeInteger.hlsl create mode 100644 tools/clang/test/CodeGenSPIRV/spv.intrinsicTypeRayquery.hlsl diff --git a/tools/clang/include/clang/AST/HlslTypes.h b/tools/clang/include/clang/AST/HlslTypes.h index 73fd883d0c..9f01dd6183 100644 --- a/tools/clang/include/clang/AST/HlslTypes.h +++ b/tools/clang/include/clang/AST/HlslTypes.h @@ -340,6 +340,9 @@ clang::CXXRecordDecl* DeclareTemplateTypeWithHandle( clang::CXXRecordDecl* DeclareUIntTemplatedTypeWithHandle( clang::ASTContext& context, llvm::StringRef typeName, llvm::StringRef templateParamName); +clang::CXXRecordDecl *DeclareUIntTemplatedTypeWithHandleInDeclContext( + clang::ASTContext &context, clang::DeclContext *declContext, + llvm::StringRef typeName, llvm::StringRef templateParamName); clang::CXXRecordDecl *DeclareConstantBufferViewType(clang::ASTContext& context, bool bTBuf); clang::CXXRecordDecl* DeclareRayQueryType(clang::ASTContext& context); clang::CXXRecordDecl *DeclareResourceType(clang::ASTContext &context, diff --git a/tools/clang/include/clang/Basic/Attr.td b/tools/clang/include/clang/Basic/Attr.td index 9a0e89a32c..3a9f0cc8cc 100644 --- a/tools/clang/include/clang/Basic/Attr.td +++ b/tools/clang/include/clang/Basic/Attr.td @@ -1145,6 +1145,14 @@ def VKReferenceExt : InheritableAttr { let Documentation = [Undocumented]; } +def VKTypeDefExt : InheritableAttr { + let Spellings = [CXX11<"vk", "ext_type_def">]; + let Subjects = SubjectList<[Function], ErrorDiag>; + let Args = [UnsignedArgument<"id">, UnsignedArgument<"opcode">]; + let LangOpts = [SPIRV]; + let Documentation = [Undocumented]; +} + // Global variables that are of scalar type def ScalarGlobalVar : SubsetSubjecthasGlobalStorage() && S->getType()->isScalarType()}]>; diff --git a/tools/clang/include/clang/SPIRV/SpirvContext.h b/tools/clang/include/clang/SPIRV/SpirvContext.h index 7e730cc491..1d728353cc 100644 --- a/tools/clang/include/clang/SPIRV/SpirvContext.h +++ b/tools/clang/include/clang/SPIRV/SpirvContext.h @@ -288,6 +288,12 @@ class SpirvContext { return rayQueryTypeKHR; } + const SpirvIntrinsicType * + getSpirvIntrinsicType(unsigned typeId, unsigned typeOpCode, + llvm::ArrayRef operands); + + SpirvIntrinsicType *getCreatedSpirvIntrinsicType(unsigned typeId); + /// --- Hybrid type getter functions --- /// /// Concrete SpirvType objects represent a SPIR-V type completely. Hybrid @@ -467,6 +473,7 @@ class SpirvContext { llvm::DenseMap pointerTypes; llvm::SmallVector hybridPointerTypes; llvm::DenseSet functionTypes; + llvm::DenseMap spirvIntrinsicTypes; const AccelerationStructureTypeNV *accelerationStructureTypeNV; const RayQueryTypeKHR *rayQueryTypeKHR; diff --git a/tools/clang/include/clang/SPIRV/SpirvInstruction.h b/tools/clang/include/clang/SPIRV/SpirvInstruction.h index fba2316496..0e77458c81 100644 --- a/tools/clang/include/clang/SPIRV/SpirvInstruction.h +++ b/tools/clang/include/clang/SPIRV/SpirvInstruction.h @@ -1110,10 +1110,13 @@ class SpirvConstant : public SpirvInstruction { } bool isSpecConstant() const; + void setLiteral(bool literal = true) { literalConstant = literal; } + bool isLiteral() { return literalConstant; } protected: - SpirvConstant(Kind, spv::Op, const SpirvType *); - SpirvConstant(Kind, spv::Op, QualType); + SpirvConstant(Kind, spv::Op, const SpirvType *, bool literal = false); + SpirvConstant(Kind, spv::Op, QualType, bool literal = false); + bool literalConstant; }; class SpirvConstantBoolean : public SpirvConstant { @@ -1141,7 +1144,7 @@ class SpirvConstantBoolean : public SpirvConstant { class SpirvConstantInteger : public SpirvConstant { public: SpirvConstantInteger(QualType type, llvm::APInt value, - bool isSpecConst = false, bool literal = false); + bool isSpecConst = false); DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvConstantInteger) @@ -1155,12 +1158,9 @@ class SpirvConstantInteger : public SpirvConstant { bool invokeVisitor(Visitor *v) override; llvm::APInt getValue() const { return value; } - void setLiteral(bool l = true) { isLiteral = l; } - bool getLiteral() { return isLiteral; } private: llvm::APInt value; - bool isLiteral; }; class SpirvConstantFloat : public SpirvConstant { diff --git a/tools/clang/include/clang/SPIRV/SpirvType.h b/tools/clang/include/clang/SPIRV/SpirvType.h index 2158eb540a..baf9c304ab 100644 --- a/tools/clang/include/clang/SPIRV/SpirvType.h +++ b/tools/clang/include/clang/SPIRV/SpirvType.h @@ -49,6 +49,7 @@ class SpirvType { TK_Function, TK_AccelerationStructureNV, TK_RayQueryKHR, + TK_SpirvIntrinsicType, // Order matters: all the following are hybrid types TK_HybridStruct, TK_HybridPointer, @@ -412,6 +413,37 @@ class RayQueryTypeKHR : public SpirvType { } }; +class SpirvInstruction; +struct SpvIntrinsicTypeOperand { + SpvIntrinsicTypeOperand(SpirvType *type_operand) + : operand_as_type(type_operand), isTypeOperand(true) {} + SpvIntrinsicTypeOperand(SpirvInstruction *inst_operand) + : operand_as_inst(inst_operand), isTypeOperand(false) {} + union { + SpirvType *operand_as_type; + SpirvInstruction *operand_as_inst; + }; + bool isTypeOperand; +}; + +class SpirvIntrinsicType : public SpirvType { +public: + SpirvIntrinsicType(unsigned typeOp, + llvm::ArrayRef inOps); + + static bool classof(const SpirvType *t) { + return t->getKind() == TK_SpirvIntrinsicType; + } + unsigned getOpCode() const { return typeOpCode; } + llvm::ArrayRef getOperands() const { + return operands; + } + +private: + unsigned typeOpCode; + llvm::SmallVector operands; +}; + class HybridType : public SpirvType { public: static bool classof(const SpirvType *t) { diff --git a/tools/clang/lib/AST/ASTContextHLSL.cpp b/tools/clang/lib/AST/ASTContextHLSL.cpp index 6a6e9a8ded..f849e8ae63 100644 --- a/tools/clang/lib/AST/ASTContextHLSL.cpp +++ b/tools/clang/lib/AST/ASTContextHLSL.cpp @@ -840,8 +840,15 @@ CXXMethodDecl* hlsl::CreateObjectFunctionDeclarationWithParams( CXXRecordDecl* hlsl::DeclareUIntTemplatedTypeWithHandle( ASTContext& context, StringRef typeName, StringRef templateParamName) { + return DeclareUIntTemplatedTypeWithHandleInDeclContext( + context, context.getTranslationUnitDecl(), typeName, templateParamName); +} + +CXXRecordDecl *hlsl::DeclareUIntTemplatedTypeWithHandleInDeclContext( + ASTContext &context, DeclContext *declContext, StringRef typeName, + StringRef templateParamName) { // template FeedbackTexture2D[Array] { ... } - BuiltinTypeDeclBuilder typeDeclBuilder(context.getTranslationUnitDecl(), typeName); + BuiltinTypeDeclBuilder typeDeclBuilder(declContext, typeName); typeDeclBuilder.addIntegerTemplateParam(templateParamName, context.UnsignedIntTy); typeDeclBuilder.startDefinition(); typeDeclBuilder.addField("h", context.UnsignedIntTy); // Add an 'h' field to hold the handle. diff --git a/tools/clang/lib/SPIRV/CapabilityVisitor.cpp b/tools/clang/lib/SPIRV/CapabilityVisitor.cpp index a3250a7b83..aed9adf0a2 100644 --- a/tools/clang/lib/SPIRV/CapabilityVisitor.cpp +++ b/tools/clang/lib/SPIRV/CapabilityVisitor.cpp @@ -529,8 +529,9 @@ bool CapabilityVisitor::visitInstruction(SpirvInstruction *instr) { } case spv::Op::OpRayQueryInitializeKHR: { auto rayQueryInst = dyn_cast(instr); - if (rayQueryInst->hasCullFlags()) { - addCapability(spv::Capability::RayTraversalPrimitiveCullingKHR); + if (rayQueryInst && rayQueryInst->hasCullFlags()) { + addCapability( + spv::Capability::RayTraversalPrimitiveCullingKHR); } break; diff --git a/tools/clang/lib/SPIRV/EmitVisitor.cpp b/tools/clang/lib/SPIRV/EmitVisitor.cpp index cd72488a72..eece2db637 100644 --- a/tools/clang/lib/SPIRV/EmitVisitor.cpp +++ b/tools/clang/lib/SPIRV/EmitVisitor.cpp @@ -1884,10 +1884,9 @@ bool EmitVisitor::visit(SpirvIntrinsicInstruction *inst) { } for (const auto operand : inst->getOperands()) { - // TODO: Handle Literals with other types. - auto literalOperand = dyn_cast(operand); - if (literalOperand && literalOperand->getLiteral()) { - curInst.push_back(literalOperand->getValue().getZExtValue()); + auto literalOperand = dyn_cast(operand); + if (literalOperand && literalOperand->isLiteral()) { + typeHandler.emitLiteral(literalOperand, curInst); } else { curInst.push_back(getOrAssignResultId(operand)); } @@ -2451,6 +2450,24 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type) { initTypeInstruction(spv::Op::OpTypeRayQueryKHR); curTypeInst.push_back(id); finalizeTypeInstruction(); + } else if (const auto *spvIntrinsicType = + dyn_cast(type)) { + initTypeInstruction(static_cast(spvIntrinsicType->getOpCode())); + curTypeInst.push_back(id); + for (const SpvIntrinsicTypeOperand &operand : + spvIntrinsicType->getOperands()) { + if (operand.isTypeOperand) { + curTypeInst.push_back(emitType(operand.operand_as_type)); + } else { + auto *literal = dyn_cast(operand.operand_as_inst); + if (literal && literal->isLiteral()) { + emitLiteral(literal, curTypeInst); + } else { + curTypeInst.push_back(getOrAssignResultId(operand.operand_as_inst)); + } + } + } + finalizeTypeInstruction(); } // Hybrid Types // Note: The type lowering pass should lower all types to SpirvTypes. @@ -2467,6 +2484,50 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type) { return id; } +template +void EmitTypeHandler::emitIntLiteral(const SpirvConstantInteger *intLiteral, + vecType &outInst) { + const auto &literalVal = intLiteral->getValue(); + bool positive = !literalVal.isNegative(); + if (literalVal.getBitWidth() <= 32) { + outInst.push_back(positive ? literalVal.getZExtValue() + : literalVal.getSExtValue()); + } else { + assert(literalVal.getBitWidth() == 64); + uint64_t val = + positive ? literalVal.getZExtValue() : literalVal.getSExtValue(); + outInst.push_back(static_cast(val)); + outInst.push_back(static_cast(val >> 32)); + } +} + +template +void EmitTypeHandler::emitFloatLiteral(const SpirvConstantFloat *fLiteral, + vecType &outInst) { + const auto &literalVal = fLiteral->getValue(); + const auto bitwidth = llvm::APFloat::getSizeInBits(literalVal.getSemantics()); + if (bitwidth <= 32) { + outInst.push_back(literalVal.bitcastToAPInt().getZExtValue()); + } else { + assert(bitwidth == 64); + uint64_t val = literalVal.bitcastToAPInt().getZExtValue(); + outInst.push_back(static_cast(val)); + outInst.push_back(static_cast(val >> 32)); + } +} + +template +void EmitTypeHandler::emitLiteral(const SpirvConstant *literal, + VecType &outInst) { + if (auto boolLiteral = dyn_cast(literal)) { + outInst.push_back(static_cast(boolLiteral->getValue())); + } else if (auto intLiteral = dyn_cast(literal)) { + emitIntLiteral(intLiteral, outInst); + } else if (auto fLiteral = dyn_cast(literal)) { + emitFloatLiteral(fLiteral, outInst); + } +} + void EmitTypeHandler::emitDecoration(uint32_t typeResultId, spv::Decoration decoration, llvm::ArrayRef decorationParams, diff --git a/tools/clang/lib/SPIRV/EmitVisitor.h b/tools/clang/lib/SPIRV/EmitVisitor.h index 38ec433818..f979247d69 100644 --- a/tools/clang/lib/SPIRV/EmitVisitor.h +++ b/tools/clang/lib/SPIRV/EmitVisitor.h @@ -109,6 +109,12 @@ class EmitTypeHandler { uint32_t getOrCreateConstantComposite(SpirvConstantComposite *); uint32_t getOrCreateConstantNull(SpirvConstantNull *); uint32_t getOrCreateConstantBool(SpirvConstantBoolean *); + template + void emitLiteral(const SpirvConstant *, vecType &outInst); + template + void emitFloatLiteral(const SpirvConstantFloat *, vecType &outInst); + template + void emitIntLiteral(const SpirvConstantInteger *, vecType &outInst); private: void initTypeInstruction(spv::Op op); diff --git a/tools/clang/lib/SPIRV/LowerTypeVisitor.cpp b/tools/clang/lib/SPIRV/LowerTypeVisitor.cpp index 92aa8f67e1..ad4379ff2f 100644 --- a/tools/clang/lib/SPIRV/LowerTypeVisitor.cpp +++ b/tools/clang/lib/SPIRV/LowerTypeVisitor.cpp @@ -593,6 +593,11 @@ const SpirvType *LowerTypeVisitor::lowerResourceType(QualType type, if (name == "RayQuery") return spvContext.getRayQueryTypeKHR(); + if (name == "ext_type") { + auto typeId = hlsl::GetHLSLResourceTemplateUInt(type); + return spvContext.getCreatedSpirvIntrinsicType(typeId); + } + if (name == "StructuredBuffer" || name == "RWStructuredBuffer" || name == "AppendStructuredBuffer" || name == "ConsumeStructuredBuffer") { // StructureBuffer will be translated into an OpTypeStruct with one diff --git a/tools/clang/lib/SPIRV/SpirvBuilder.cpp b/tools/clang/lib/SPIRV/SpirvBuilder.cpp index ccf74b8c1c..cfef113a92 100644 --- a/tools/clang/lib/SPIRV/SpirvBuilder.cpp +++ b/tools/clang/lib/SPIRV/SpirvBuilder.cpp @@ -1001,10 +1001,13 @@ SpirvInstruction *SpirvBuilder::createSpirvIntrInstExt( SpirvExtInstImport *set = (instSet.size() == 0) ? nullptr : getExtInstSet(instSet); + + if (retType != QualType() && retType->isVoidType()) { + retType = QualType(); + } auto *inst = new (context) SpirvIntrinsicInstruction( - retType->isVoidType() ? QualType() : retType, opcode, operands, - extensions, set, capablities, loc); + retType, opcode, operands, extensions, set, capablities, loc); insertPoint->addInstruction(inst); return inst; } diff --git a/tools/clang/lib/SPIRV/SpirvContext.cpp b/tools/clang/lib/SPIRV/SpirvContext.cpp index 9f4c988120..9c180a3ef6 100644 --- a/tools/clang/lib/SPIRV/SpirvContext.cpp +++ b/tools/clang/lib/SPIRV/SpirvContext.cpp @@ -95,6 +95,11 @@ SpirvContext::~SpirvContext() { for (auto &typePair : typeTemplateParams) typePair.second->releaseMemory(); + + for (auto &pair : spirvIntrinsicTypes) { + assert(pair.second); + pair.second->~SpirvIntrinsicType(); + } } inline uint32_t log2ForBitwidth(uint32_t bitwidth) { @@ -527,5 +532,23 @@ void SpirvContext::moveDebugTypesToModule(SpirvModule *module) { typeTemplateParams.clear(); } +const SpirvIntrinsicType *SpirvContext::getSpirvIntrinsicType( + unsigned typeId, unsigned typeOpCode, + llvm::ArrayRef operands) { + if (spirvIntrinsicTypes[typeId] == nullptr) { + spirvIntrinsicTypes[typeId] = + new (this) SpirvIntrinsicType(typeOpCode, operands); + } + return spirvIntrinsicTypes[typeId]; +} + +SpirvIntrinsicType * +SpirvContext::getCreatedSpirvIntrinsicType(unsigned typeId) { + if (spirvIntrinsicTypes.find(typeId) == spirvIntrinsicTypes.end()){ + return nullptr; + } + return spirvIntrinsicTypes[typeId]; +} + } // end namespace spirv } // end namespace clang diff --git a/tools/clang/lib/SPIRV/SpirvEmitter.cpp b/tools/clang/lib/SPIRV/SpirvEmitter.cpp index eeea430520..3357e548b5 100644 --- a/tools/clang/lib/SPIRV/SpirvEmitter.cpp +++ b/tools/clang/lib/SPIRV/SpirvEmitter.cpp @@ -18,9 +18,9 @@ #include "dxc/HlslIntrinsicOp.h" #include "spirv-tools/optimizer.hpp" #include "clang/SPIRV/AstTypeProbe.h" +#include "clang/SPIRV/String.h" #include "clang/Sema/Sema.h" #include "llvm/ADT/StringExtras.h" - #include "InitListHandler.h" #include "dxc/DXIL/DxilConstants.h" @@ -2337,8 +2337,11 @@ SpirvInstruction *SpirvEmitter::doCallExpr(const CallExpr *callExpr) { return doCXXMemberCallExpr(memberCall); auto funcDecl = callExpr->getDirectCallee(); - if (funcDecl && funcDecl->hasAttr()) { - return processSpvIntrinsicCallExpr(callExpr); + if (funcDecl) { + if (funcDecl->hasAttr()) + return processSpvIntrinsicCallExpr(callExpr); + else if(funcDecl->hasAttr()) + return processSpvIntrinsicTypeDef(callExpr); } // Intrinsic functions such as 'dot' or 'mul' if (hlsl::IsIntrinsicOp(funcDecl)) { @@ -12530,7 +12533,7 @@ SpirvEmitter::processSpvIntrinsicCallExpr(const CallExpr *expr) { } spvArgs.push_back(argInst); } else if (param->hasAttr()) { - auto constArg = dyn_cast(argInst); + auto constArg = dyn_cast(argInst); assert(constArg != nullptr); constArg->setLiteral(); spvArgs.push_back(argInst); @@ -12601,6 +12604,58 @@ SpirvEmitter::processIntrinsicExecutionMode(const CallExpr *expr) { execModesParams, expr->getExprLoc()); } +SpirvInstruction * +SpirvEmitter::processSpvIntrinsicTypeDef(const CallExpr *expr) { + auto funcDecl = expr->getDirectCallee(); + auto typeDefAttr = funcDecl->getAttr(); + llvm::SmallVector capbilities; + llvm::SmallVector extensions; + + for (auto &attr : funcDecl->getAttrs()) { + if (auto capAttr = dyn_cast(attr)) { + capbilities.push_back(capAttr->getCapability()); + } else if (auto extAttr = dyn_cast(attr)) { + extensions.push_back(extAttr->getName()); + } + } + + SmallVector operands; + const auto args = expr->getArgs(); + for (uint32_t i = 0; i < expr->getNumArgs(); ++i) { + auto param = funcDecl->getParamDecl(i); + const Expr *arg = args[i]->IgnoreParenLValueCasts(); + if (param->hasAttr()) { + auto *recType = param->getType()->getAs(); + if (recType && recType->getDecl()->getName() == "ext_type") { + auto typeId = hlsl::GetHLSLResourceTemplateUInt(arg->getType()); + auto *typeArg = spvContext.getCreatedSpirvIntrinsicType(typeId); + operands.emplace_back(typeArg); + } else { + operands.emplace_back(doExpr(arg)); + } + } else if (param->hasAttr()) { + SpirvInstruction *argInst = doExpr(arg); + auto constArg = dyn_cast(argInst); + assert(constArg != nullptr); + constArg->setLiteral(); + operands.emplace_back(constArg); + } else { + operands.emplace_back(loadIfGLValue(arg)); + } + } + spvContext.getSpirvIntrinsicType(typeDefAttr->getId(), + typeDefAttr->getOpcode(), operands); + + // Emit dummy OpNop with no semantic meaning, with possible extension and + // capabilities + SpirvInstruction *retVal = spvBuilder.createSpirvIntrInstExt( + static_cast(spv::Op::OpNop), QualType(), {}, extensions, {}, + capbilities, expr->getExprLoc()); + retVal->setRValue(); + + return retVal; +} + bool SpirvEmitter::spirvToolsValidate(std::vector *mod, std::string *messages) { spvtools::SpirvTools tools(featureManager.getTargetEnv()); diff --git a/tools/clang/lib/SPIRV/SpirvEmitter.h b/tools/clang/lib/SPIRV/SpirvEmitter.h index 900c63da82..521d44bf76 100644 --- a/tools/clang/lib/SPIRV/SpirvEmitter.h +++ b/tools/clang/lib/SPIRV/SpirvEmitter.h @@ -596,6 +596,9 @@ class SpirvEmitter : public ASTConsumer { hlsl::IntrinsicOp opcode); /// Process spirv intrinsic instruction SpirvInstruction *processSpvIntrinsicCallExpr(const CallExpr *expr); + + /// Process spirv intrinsic type definition + SpirvInstruction *processSpvIntrinsicTypeDef(const CallExpr *expr); /// Custom intrinsic to support basic buffer_reference use case SpirvInstruction *processRawBufferLoad(const CallExpr *callExpr); diff --git a/tools/clang/lib/SPIRV/SpirvInstruction.cpp b/tools/clang/lib/SPIRV/SpirvInstruction.cpp index 1d4ee08c32..fb4c6cf887 100644 --- a/tools/clang/lib/SPIRV/SpirvInstruction.cpp +++ b/tools/clang/lib/SPIRV/SpirvInstruction.cpp @@ -476,15 +476,19 @@ SpirvCompositeConstruct::SpirvCompositeConstruct( resultType, loc), consituents(constituentsVec.begin(), constituentsVec.end()) {} -SpirvConstant::SpirvConstant(Kind kind, spv::Op op, const SpirvType *spvType) +SpirvConstant::SpirvConstant(Kind kind, spv::Op op, const SpirvType *spvType, + bool literal) : SpirvInstruction(kind, op, QualType(), - /*SourceLocation*/ {}) { + /*SourceLocation*/ {}), + literalConstant(literal) { setResultType(spvType); } -SpirvConstant::SpirvConstant(Kind kind, spv::Op op, QualType resultType) +SpirvConstant::SpirvConstant(Kind kind, spv::Op op, QualType resultType, + bool literal) : SpirvInstruction(kind, op, resultType, - /*SourceLocation*/ {}) {} + /*SourceLocation*/ {}), + literalConstant(literal) {} bool SpirvConstant::isSpecConstant() const { return opcode == spv::Op::OpSpecConstant || @@ -509,11 +513,11 @@ bool SpirvConstantBoolean::operator==(const SpirvConstantBoolean &that) const { } SpirvConstantInteger::SpirvConstantInteger(QualType type, llvm::APInt val, - bool isSpecConst, bool literal) + bool isSpecConst) : SpirvConstant(IK_ConstantInteger, isSpecConst ? spv::Op::OpSpecConstant : spv::Op::OpConstant, type), - value(val), isLiteral(literal) { + value(val) { assert(type->isIntegerType()); } diff --git a/tools/clang/lib/SPIRV/SpirvType.cpp b/tools/clang/lib/SPIRV/SpirvType.cpp index d155767714..b191fcd20b 100644 --- a/tools/clang/lib/SPIRV/SpirvType.cpp +++ b/tools/clang/lib/SPIRV/SpirvType.cpp @@ -11,7 +11,7 @@ //===----------------------------------------------------------------------===// #include "clang/SPIRV/SpirvType.h" - +#include "clang/SPIRV/SpirvInstruction.h" #include namespace clang { @@ -167,6 +167,11 @@ bool RuntimeArrayType::operator==(const RuntimeArrayType &that) const { (!stride.hasValue() || stride.getValue() == that.stride.getValue()); } +SpirvIntrinsicType::SpirvIntrinsicType( + unsigned typeOp, llvm::ArrayRef inOps) + : SpirvType(TK_SpirvIntrinsicType, "spirvIntrinsicType"), + typeOpCode(typeOp), operands(inOps.begin(), inOps.end()) {} + StructType::StructType(llvm::ArrayRef fieldsVec, llvm::StringRef name, bool isReadOnly, StructInterfaceType iface) diff --git a/tools/clang/lib/Sema/SemaHLSL.cpp b/tools/clang/lib/Sema/SemaHLSL.cpp index db94af2588..8379be1055 100644 --- a/tools/clang/lib/Sema/SemaHLSL.cpp +++ b/tools/clang/lib/Sema/SemaHLSL.cpp @@ -181,6 +181,7 @@ enum ArBasicKind { #ifdef ENABLE_SPIRV_CODEGEN AR_OBJECT_VK_SUBPASS_INPUT, AR_OBJECT_VK_SUBPASS_INPUT_MS, + AR_OBJECT_VK_SPV_INTRINSIC_TYPE, #endif // ENABLE_SPIRV_CODEGEN // SPIRV change ends @@ -472,6 +473,7 @@ const UINT g_uBasicKindProps[] = #ifdef ENABLE_SPIRV_CODEGEN BPROP_OBJECT | BPROP_RBUFFER, // AR_OBJECT_VK_SUBPASS_INPUT BPROP_OBJECT | BPROP_RBUFFER, // AR_OBJECT_VK_SUBPASS_INPUT_MS + BPROP_OBJECT, // AR_OBJECT_VK_SPV_INTRINSIC_TYPE use recordType #endif // ENABLE_SPIRV_CODEGEN // SPIRV change ends @@ -1395,6 +1397,7 @@ const ArBasicKind g_ArBasicKindsAsTypes[] = #ifdef ENABLE_SPIRV_CODEGEN AR_OBJECT_VK_SUBPASS_INPUT, AR_OBJECT_VK_SUBPASS_INPUT_MS, + AR_OBJECT_VK_SPV_INTRINSIC_TYPE, #endif // ENABLE_SPIRV_CODEGEN // SPIRV change ends @@ -1486,7 +1489,8 @@ const uint8_t g_ArBasicKindsTemplateCount[] = // SPIRV change starts #ifdef ENABLE_SPIRV_CODEGEN 1, // AR_OBJECT_VK_SUBPASS_INPUT - 1, // AR_OBJECT_VK_SUBPASS_INPUT_MS + 1, // AR_OBJECT_VK_SUBPASS_INPUT_MS, + 1, // AR_OBJECT_VK_SPV_INTRINSIC_TYPE #endif // ENABLE_SPIRV_CODEGEN // SPIRV change ends @@ -1587,6 +1591,7 @@ const SubscriptOperatorRecord g_ArBasicKindsSubscripts[] = #ifdef ENABLE_SPIRV_CODEGEN { 0, MipsFalse, SampleFalse }, // AR_OBJECT_VK_SUBPASS_INPUT (SubpassInput) { 0, MipsFalse, SampleFalse }, // AR_OBJECT_VK_SUBPASS_INPUT_MS (SubpassInputMS) + { 0, MipsFalse, SampleFalse }, // AR_OBJECT_VK_SPV_INTRINSIC_TYPE #endif // ENABLE_SPIRV_CODEGEN // SPIRV change ends @@ -1706,6 +1711,7 @@ const char* g_ArBasicTypeNames[] = #ifdef ENABLE_SPIRV_CODEGEN "SubpassInput", "SubpassInputMS", + "ext_type", #endif // ENABLE_SPIRV_CODEGEN // SPIRV change ends @@ -3588,11 +3594,16 @@ class HLSLExternalSource : public ExternalSemaSource { else if (kind == AR_OBJECT_FEEDBACKTEXTURE2D_ARRAY) { recordDecl = DeclareUIntTemplatedTypeWithHandle(*m_context, "FeedbackTexture2DArray", "kind"); } +#ifdef ENABLE_SPIRV_CODEGEN + else if (kind == AR_OBJECT_VK_SPV_INTRINSIC_TYPE && m_vkNSDecl) { + recordDecl = DeclareUIntTemplatedTypeWithHandleInDeclContext( + *m_context, m_vkNSDecl, typeName, "id"); + recordDecl->setImplicit(true); + } +#endif else if (templateArgCount == 0) { recordDecl = DeclareRecordTypeWithHandle(*m_context, typeName); - } - else - { + } else { DXASSERT(templateArgCount == 1 || templateArgCount == 2, "otherwise a new case has been added"); TypeSourceInfo* typeDefault = TemplateHasDefaultType(kind) ? float4TypeSourceInfo : nullptr; @@ -3712,12 +3723,6 @@ class HLSLExternalSource : public ExternalSemaSource { m_sema = &S; S.addExternalSource(this); - AddObjectTypes(); - AddStdIsEqualImplementation(context, S); - for (auto && intrinsic : m_intrinsicTables) { - AddIntrinsicTableMethods(intrinsic); - } - #ifdef ENABLE_SPIRV_CODEGEN if (m_sema->getLangOpts().SPIRV) { // Create the "vk" namespace which contains Vulkan-specific intrinsics. @@ -3727,7 +3732,17 @@ class HLSLExternalSource : public ExternalSemaSource { SourceLocation(), &context.Idents.get("vk"), /*PrevDecl*/ nullptr); context.getTranslationUnitDecl()->addDecl(m_vkNSDecl); + } +#endif // ENABLE_SPIRV_CODEGEN + + AddObjectTypes(); + AddStdIsEqualImplementation(context, S); + for (auto &&intrinsic : m_intrinsicTables) { + AddIntrinsicTableMethods(intrinsic); + } +#ifdef ENABLE_SPIRV_CODEGEN + if (m_sema->getLangOpts().SPIRV) { // Add Vulkan-specific intrinsics. AddVkIntrinsicFunctions(); AddVkIntrinsicConstants(); @@ -12106,6 +12121,12 @@ void hlsl::HandleDeclAttributeForHLSL(Sema &S, Decl *D, const AttributeList &A, A.getRange(), S.Context, unsigned(ValidateAttributeIntArg(S, A)), A.getAttributeSpellingListIndex()); break; + case AttributeList::AT_VKTypeDefExt: + declAttr = ::new (S.Context) VKTypeDefExtAttr( + A.getRange(), S.Context, unsigned(ValidateAttributeIntArg(S, A)), + unsigned(ValidateAttributeIntArg(S, A, 1)), + A.getAttributeSpellingListIndex()); + break; default: Handled = false; return; @@ -12888,7 +12909,8 @@ bool Sema::DiagnoseHLSLDecl(Declarator &D, DeclContext *DC, Expr *BitWidth, // Validate that Vulkan specific feature is only used when targeting SPIR-V if (!getLangOpts().SPIRV) { if (basicKind == ArBasicKind::AR_OBJECT_VK_SUBPASS_INPUT || - basicKind == ArBasicKind::AR_OBJECT_VK_SUBPASS_INPUT_MS) { + basicKind == ArBasicKind::AR_OBJECT_VK_SUBPASS_INPUT_MS || + basicKind == ArBasicKind::AR_OBJECT_VK_SPV_INTRINSIC_TYPE) { Diag(D.getLocStart(), diag::err_hlsl_vulkan_specific_feature) << g_ArBasicTypeNames[basicKind]; result = false; diff --git a/tools/clang/test/CodeGenSPIRV/spv.intrinsicTypeInteger.hlsl b/tools/clang/test/CodeGenSPIRV/spv.intrinsicTypeInteger.hlsl new file mode 100644 index 0000000000..8f0c5d5431 --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/spv.intrinsicTypeInteger.hlsl @@ -0,0 +1,21 @@ +// RUN: %dxc -T ps_6_0 -E main -spirv + +[[vk::ext_type_def(0, 21)]] +void createTypeInt([[vk::ext_literal]] int sizeInBits, + [[vk::ext_literal]] int signedness); + +[[vk::ext_type_def(1, 23)]] +void createTypeVector([[vk::ext_reference]] vk::ext_type<0> typeInt, + [[vk::ext_literal]] int componentCount); + +//CHECK: %spirvIntrinsicType = OpTypeInt 32 0 +//CHECK: %spirvIntrinsicType_0 = OpTypeVector %spirvIntrinsicType 4 + +vk::ext_type<0> foo1; +vk::ext_type<1> foo2; +float main() : SV_Target +{ + createTypeInt(32, 0); + createTypeVector(foo1, 4); + return 0.0; +} diff --git a/tools/clang/test/CodeGenSPIRV/spv.intrinsicTypeRayquery.hlsl b/tools/clang/test/CodeGenSPIRV/spv.intrinsicTypeRayquery.hlsl new file mode 100644 index 0000000000..bd0715828d --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/spv.intrinsicTypeRayquery.hlsl @@ -0,0 +1,40 @@ +// RUN: %dxc -T cs_6_5 -E main -spirv + +[[vk::ext_capability(/* RayQueryKHR */ 4472)]] +[[vk::ext_extension("SPV_KHR_ray_query")]] +[[vk::ext_type_def(/* Unique id for type */ 2, + /* OpTypeRayQueryKHR */ 4472)]] +void createTypeRayQueryKHR(); + +[[vk::ext_type_def(/* Unique id for type */ 3, + /* OpTypeAccelerationStructureKHR */ 5341)]] +void createAcceleStructureType(); + +[[vk::ext_instruction(/* OpRayQueryTerminateKHR */ 4474)]] +void rayQueryTerminateEXT( + [[vk::ext_reference]] vk::ext_type<2> rq); + +vk::ext_type<3> as : register(t0); + +[[vk::ext_instruction(/* OpRayQueryInitializeKHR */ 4473)]] +void rayQueryInitializeEXT([[vk::ext_reference]] vk::ext_type<2> rayQuery, vk::ext_type<3> as, uint rayFlags, uint cullMask, float3 origin, float tMin, float3 direction, float tMax); + +[[vk::ext_instruction(/* OpRayQueryTerminateKHR */ 4474)]] +void rayQueryTerminateEXT( + [[vk::ext_reference]] vk::ext_type<2> rq ); + +//CHECK: %spirvIntrinsicType = OpTypeAccelerationStructureKHR +//CHECK: %spirvIntrinsicType_0 = OpTypeRayQueryKHR + +//CHECK: OpRayQueryInitializeKHR %rq {{%\w+}} {{%\w+}} {{%\w+}} {{%\w+}} {{%\w+}} {{%\w+}} {{%\w+}} +//CHECK: OpRayQueryTerminateKHR %rq + +[numthreads(64, 1, 1)] +void main() +{ + createTypeRayQueryKHR(); + createAcceleStructureType(); + vk::ext_type<2> rq; + rayQueryInitializeEXT(rq, as, 0, 0, float3(0, 0, 0), 0.0, float3(1,1,1), 1.0); + rayQueryTerminateEXT(rq); +} diff --git a/tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp b/tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp index c368ec5f41..17442719f5 100644 --- a/tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp +++ b/tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp @@ -1348,6 +1348,8 @@ TEST_F(FileTest, IntrinsicsSpirv) { runFileTest("spv.intrinsicDecorate.hlsl", Expect::Success, false); runFileTest("spv.intrinsicExecutionMode.hlsl", Expect::Success, false); runFileTest("spv.intrinsicStorageClass.hlsl", Expect::Success, false); + runFileTest("spv.intrinsicTypeInteger.hlsl"); + runFileTest("spv.intrinsicTypeRayquery.hlsl", Expect::Success, false); runFileTest("spv.intrinsic.reference.error.hlsl", Expect::Failure); } TEST_F(FileTest, IntrinsicsVkReadClock) {