diff --git a/tools/clang/include/clang/SPIRV/SpirvContext.h b/tools/clang/include/clang/SPIRV/SpirvContext.h index d3baef3c2b..1d728353cc 100644 --- a/tools/clang/include/clang/SPIRV/SpirvContext.h +++ b/tools/clang/include/clang/SPIRV/SpirvContext.h @@ -290,8 +290,7 @@ class SpirvContext { const SpirvIntrinsicType * getSpirvIntrinsicType(unsigned typeId, unsigned typeOpCode, - llvm::ArrayRef constants, - SpirvIntrinsicType *elementTy); + llvm::ArrayRef operands); SpirvIntrinsicType *getCreatedSpirvIntrinsicType(unsigned typeId); diff --git a/tools/clang/include/clang/SPIRV/SpirvType.h b/tools/clang/include/clang/SPIRV/SpirvType.h index 44a877517c..baf9c304ab 100644 --- a/tools/clang/include/clang/SPIRV/SpirvType.h +++ b/tools/clang/include/clang/SPIRV/SpirvType.h @@ -413,22 +413,35 @@ class RayQueryTypeKHR : public SpirvType { } }; -class SpirvConstant; +class SpirvInstruction; +struct SpvIntrinsicTypeOperand { + SpvIntrinsicTypeOperand(SpirvType *type_operand) + : operand_as_type(type_operand), isTypeOperand(true) {} + SpvIntrinsicTypeOperand(SpirvInstruction *inst_operand) + : operand_as_inst(inst_operand), isTypeOperand(false) {} + union { + SpirvType *operand_as_type; + SpirvInstruction *operand_as_inst; + }; + bool isTypeOperand; +}; + class SpirvIntrinsicType : public SpirvType { public: - SpirvIntrinsicType(unsigned typeOp, llvm::ArrayRef constants, - SpirvIntrinsicType *elementTy); + SpirvIntrinsicType(unsigned typeOp, + llvm::ArrayRef inOps); + static bool classof(const SpirvType *t) { return t->getKind() == TK_SpirvIntrinsicType; } unsigned getOpCode() const { return typeOpCode; } - llvm::ArrayRef getLiterals() const { return literals; } - SpirvIntrinsicType *getElemType() const { return elementType; } + llvm::ArrayRef getOperands() const { + return operands; + } private: unsigned typeOpCode; - llvm::SmallVector literals; - SpirvIntrinsicType *elementType; + llvm::SmallVector operands; }; class HybridType : public SpirvType { diff --git a/tools/clang/lib/SPIRV/EmitVisitor.cpp b/tools/clang/lib/SPIRV/EmitVisitor.cpp index 56459d539d..eece2db637 100644 --- a/tools/clang/lib/SPIRV/EmitVisitor.cpp +++ b/tools/clang/lib/SPIRV/EmitVisitor.cpp @@ -2450,15 +2450,22 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type) { initTypeInstruction(spv::Op::OpTypeRayQueryKHR); curTypeInst.push_back(id); finalizeTypeInstruction(); - } - else if (const auto *spvIntrinsicType = dyn_cast(type)) { + } else if (const auto *spvIntrinsicType = + dyn_cast(type)) { initTypeInstruction(static_cast(spvIntrinsicType->getOpCode())); curTypeInst.push_back(id); - if (spvIntrinsicType->getElemType()) { - curTypeInst.push_back(emitType(spvIntrinsicType->getElemType())); - } - for (auto& literal : spvIntrinsicType->getLiterals()) { - emitLiteral(literal, curTypeInst); + for (const SpvIntrinsicTypeOperand &operand : + spvIntrinsicType->getOperands()) { + if (operand.isTypeOperand) { + curTypeInst.push_back(emitType(operand.operand_as_type)); + } else { + auto *literal = dyn_cast(operand.operand_as_inst); + if (literal && literal->isLiteral()) { + emitLiteral(literal, curTypeInst); + } else { + curTypeInst.push_back(getOrAssignResultId(operand.operand_as_inst)); + } + } } finalizeTypeInstruction(); } @@ -2477,36 +2484,47 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type) { return id; } +template +void EmitTypeHandler::emitIntLiteral(const SpirvConstantInteger *intLiteral, + vecType &outInst) { + const auto &literalVal = intLiteral->getValue(); + bool positive = !literalVal.isNegative(); + if (literalVal.getBitWidth() <= 32) { + outInst.push_back(positive ? literalVal.getZExtValue() + : literalVal.getSExtValue()); + } else { + assert(literalVal.getBitWidth() == 64); + uint64_t val = + positive ? literalVal.getZExtValue() : literalVal.getSExtValue(); + outInst.push_back(static_cast(val)); + outInst.push_back(static_cast(val >> 32)); + } +} + +template +void EmitTypeHandler::emitFloatLiteral(const SpirvConstantFloat *fLiteral, + vecType &outInst) { + const auto &literalVal = fLiteral->getValue(); + const auto bitwidth = llvm::APFloat::getSizeInBits(literalVal.getSemantics()); + if (bitwidth <= 32) { + outInst.push_back(literalVal.bitcastToAPInt().getZExtValue()); + } else { + assert(bitwidth == 64); + uint64_t val = literalVal.bitcastToAPInt().getZExtValue(); + outInst.push_back(static_cast(val)); + outInst.push_back(static_cast(val >> 32)); + } +} + template void EmitTypeHandler::emitLiteral(const SpirvConstant *literal, VecType &outInst) { if (auto boolLiteral = dyn_cast(literal)) { outInst.push_back(static_cast(boolLiteral->getValue())); } else if (auto intLiteral = dyn_cast(literal)) { - const auto &literalVal = intLiteral->getValue(); - bool positive = !literalVal.isNegative(); - if (literalVal.getBitWidth() <= 32) { - outInst.push_back(positive ? literalVal.getZExtValue() - : literalVal.getSExtValue()); - } else { - assert(literalVal.getBitWidth() == 64); - uint64_t val = - positive ? literalVal.getZExtValue() : literalVal.getSExtValue(); - outInst.push_back(static_cast(val)); - outInst.push_back(static_cast(val >> 32)); - } + emitIntLiteral(intLiteral, outInst); } else if (auto fLiteral = dyn_cast(literal)) { - const auto &literalVal = fLiteral->getValue(); - const auto bitwidth = - llvm::APFloat::getSizeInBits(literalVal.getSemantics()); - if (bitwidth <= 32) { - outInst.push_back(literalVal.bitcastToAPInt().getZExtValue()); - } else { - assert(bitwidth == 64); - uint64_t val = literalVal.bitcastToAPInt().getZExtValue(); - outInst.push_back(static_cast(val)); - outInst.push_back(static_cast(val >> 32)); - } + emitFloatLiteral(fLiteral, outInst); } } diff --git a/tools/clang/lib/SPIRV/EmitVisitor.h b/tools/clang/lib/SPIRV/EmitVisitor.h index 38b2431fc1..f979247d69 100644 --- a/tools/clang/lib/SPIRV/EmitVisitor.h +++ b/tools/clang/lib/SPIRV/EmitVisitor.h @@ -111,6 +111,10 @@ class EmitTypeHandler { uint32_t getOrCreateConstantBool(SpirvConstantBoolean *); template void emitLiteral(const SpirvConstant *, vecType &outInst); + template + void emitFloatLiteral(const SpirvConstantFloat *, vecType &outInst); + template + void emitIntLiteral(const SpirvConstantInteger *, vecType &outInst); private: void initTypeInstruction(spv::Op op); diff --git a/tools/clang/lib/SPIRV/SpirvContext.cpp b/tools/clang/lib/SPIRV/SpirvContext.cpp index bccb47dd43..3db12aac54 100644 --- a/tools/clang/lib/SPIRV/SpirvContext.cpp +++ b/tools/clang/lib/SPIRV/SpirvContext.cpp @@ -527,13 +527,12 @@ void SpirvContext::moveDebugTypesToModule(SpirvModule *module) { typeTemplateParams.clear(); } -const SpirvIntrinsicType * -SpirvContext::getSpirvIntrinsicType(unsigned typeId, unsigned typeOpCode, - llvm::ArrayRef constants, - SpirvIntrinsicType *elementTy) { +const SpirvIntrinsicType *SpirvContext::getSpirvIntrinsicType( + unsigned typeId, unsigned typeOpCode, + llvm::ArrayRef operands) { if (spirvIntrinsicTypes[typeId] == nullptr) { spirvIntrinsicTypes[typeId] = - new (this) SpirvIntrinsicType(typeOpCode, constants, elementTy); + new (this) SpirvIntrinsicType(typeOpCode, operands); } return spirvIntrinsicTypes[typeId]; } diff --git a/tools/clang/lib/SPIRV/SpirvEmitter.cpp b/tools/clang/lib/SPIRV/SpirvEmitter.cpp index 666c217dbe..cae6c6a68f 100644 --- a/tools/clang/lib/SPIRV/SpirvEmitter.cpp +++ b/tools/clang/lib/SPIRV/SpirvEmitter.cpp @@ -12609,7 +12609,6 @@ SpirvEmitter::processSpvIntrinsicTypeDef(const CallExpr *expr) { auto funcDecl = expr->getDirectCallee(); auto typeDefAttr = funcDecl->getAttr(); SpirvIntrinsicType *elementType = nullptr; - SmallVector constants; llvm::SmallVector capbilities; llvm::SmallVector extensions; @@ -12621,23 +12620,32 @@ SpirvEmitter::processSpvIntrinsicTypeDef(const CallExpr *expr) { } } + SmallVector operands; const auto args = expr->getArgs(); for (uint32_t i = 0; i < expr->getNumArgs(); ++i) { auto param = funcDecl->getParamDecl(i); const Expr *arg = args[i]->IgnoreParenLValueCasts(); if (param->hasAttr()) { - auto typeId = hlsl::GetHLSLResourceTemplateUInt(arg->getType()); - elementType = spvContext.getCreatedSpirvIntrinsicType(typeId); + auto *recType = param->getType()->getAs(); + if (recType && recType->getDecl()->getName() == "ext_type") { + auto typeId = hlsl::GetHLSLResourceTemplateUInt(arg->getType()); + auto *typeArg = spvContext.getCreatedSpirvIntrinsicType(typeId); + operands.emplace_back(typeArg); + } else { + operands.emplace_back(doExpr(arg)); + } } else if (param->hasAttr()) { SpirvInstruction *argInst = doExpr(arg); auto constArg = dyn_cast(argInst); assert(constArg != nullptr); constArg->setLiteral(); - constants.push_back(constArg); + operands.emplace_back(constArg); + } else { + operands.emplace_back(loadIfGLValue(arg)); } } - spvContext.getSpirvIntrinsicType( - typeDefAttr->getId(), typeDefAttr->getOpcode(), constants, elementType); + spvContext.getSpirvIntrinsicType(typeDefAttr->getId(), + typeDefAttr->getOpcode(), operands); // Emit dummy OpNop with no semantic meaning, with possible extension and // capabilities diff --git a/tools/clang/lib/SPIRV/SpirvType.cpp b/tools/clang/lib/SPIRV/SpirvType.cpp index 91ba944416..b191fcd20b 100644 --- a/tools/clang/lib/SPIRV/SpirvType.cpp +++ b/tools/clang/lib/SPIRV/SpirvType.cpp @@ -168,11 +168,9 @@ bool RuntimeArrayType::operator==(const RuntimeArrayType &that) const { } SpirvIntrinsicType::SpirvIntrinsicType( - unsigned typeOp, llvm::ArrayRef constants, - SpirvIntrinsicType *eleTy) + unsigned typeOp, llvm::ArrayRef inOps) : SpirvType(TK_SpirvIntrinsicType, "spirvIntrinsicType"), - typeOpCode(typeOp), literals(constants.begin(), constants.end()), - elementType(eleTy) {} + typeOpCode(typeOp), operands(inOps.begin(), inOps.end()) {} StructType::StructType(llvm::ArrayRef fieldsVec, llvm::StringRef name, bool isReadOnly,