Skip to content

Commit

Permalink
[SPIRV] Add support of [[vk::ext_type_def]] (#4068)
Browse files Browse the repository at this point in the history
Support [[vk::ext_type_def]] and vk::ext_type.
This is related
#3919

Co-authored-by: Jaebaek Seo <jaebaek@google.com>
  • Loading branch information
jiaolu and jaebaek authored Nov 29, 2021
1 parent 676fe64 commit 2eae8d3
Show file tree
Hide file tree
Showing 20 changed files with 345 additions and 37 deletions.
3 changes: 3 additions & 0 deletions tools/clang/include/clang/AST/HlslTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,9 @@ clang::CXXRecordDecl* DeclareTemplateTypeWithHandle(

clang::CXXRecordDecl* DeclareUIntTemplatedTypeWithHandle(
clang::ASTContext& context, llvm::StringRef typeName, llvm::StringRef templateParamName);
clang::CXXRecordDecl *DeclareUIntTemplatedTypeWithHandleInDeclContext(
clang::ASTContext &context, clang::DeclContext *declContext,
llvm::StringRef typeName, llvm::StringRef templateParamName);
clang::CXXRecordDecl *DeclareConstantBufferViewType(clang::ASTContext& context, bool bTBuf);
clang::CXXRecordDecl* DeclareRayQueryType(clang::ASTContext& context);
clang::CXXRecordDecl *DeclareResourceType(clang::ASTContext &context,
Expand Down
8 changes: 8 additions & 0 deletions tools/clang/include/clang/Basic/Attr.td
Original file line number Diff line number Diff line change
Expand Up @@ -1145,6 +1145,14 @@ def VKReferenceExt : InheritableAttr {
let Documentation = [Undocumented];
}

def VKTypeDefExt : InheritableAttr {
let Spellings = [CXX11<"vk", "ext_type_def">];
let Subjects = SubjectList<[Function], ErrorDiag>;
let Args = [UnsignedArgument<"id">, UnsignedArgument<"opcode">];
let LangOpts = [SPIRV];
let Documentation = [Undocumented];
}

// Global variables that are of scalar type
def ScalarGlobalVar : SubsetSubject<Var, [{S->hasGlobalStorage() && S->getType()->isScalarType()}]>;

Expand Down
7 changes: 7 additions & 0 deletions tools/clang/include/clang/SPIRV/SpirvContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,12 @@ class SpirvContext {
return rayQueryTypeKHR;
}

const SpirvIntrinsicType *
getSpirvIntrinsicType(unsigned typeId, unsigned typeOpCode,
llvm::ArrayRef<SpvIntrinsicTypeOperand> operands);

SpirvIntrinsicType *getCreatedSpirvIntrinsicType(unsigned typeId);

/// --- Hybrid type getter functions ---
///
/// Concrete SpirvType objects represent a SPIR-V type completely. Hybrid
Expand Down Expand Up @@ -467,6 +473,7 @@ class SpirvContext {
llvm::DenseMap<const SpirvType *, SCToPtrTyMap> pointerTypes;
llvm::SmallVector<const HybridPointerType *, 8> hybridPointerTypes;
llvm::DenseSet<FunctionType *, FunctionTypeMapInfo> functionTypes;
llvm::DenseMap<unsigned, SpirvIntrinsicType*> spirvIntrinsicTypes;
const AccelerationStructureTypeNV *accelerationStructureTypeNV;
const RayQueryTypeKHR *rayQueryTypeKHR;

Expand Down
12 changes: 6 additions & 6 deletions tools/clang/include/clang/SPIRV/SpirvInstruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -1110,10 +1110,13 @@ class SpirvConstant : public SpirvInstruction {
}

bool isSpecConstant() const;
void setLiteral(bool literal = true) { literalConstant = literal; }
bool isLiteral() { return literalConstant; }

protected:
SpirvConstant(Kind, spv::Op, const SpirvType *);
SpirvConstant(Kind, spv::Op, QualType);
SpirvConstant(Kind, spv::Op, const SpirvType *, bool literal = false);
SpirvConstant(Kind, spv::Op, QualType, bool literal = false);
bool literalConstant;
};

class SpirvConstantBoolean : public SpirvConstant {
Expand Down Expand Up @@ -1141,7 +1144,7 @@ class SpirvConstantBoolean : public SpirvConstant {
class SpirvConstantInteger : public SpirvConstant {
public:
SpirvConstantInteger(QualType type, llvm::APInt value,
bool isSpecConst = false, bool literal = false);
bool isSpecConst = false);

DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvConstantInteger)

Expand All @@ -1155,12 +1158,9 @@ class SpirvConstantInteger : public SpirvConstant {
bool invokeVisitor(Visitor *v) override;

llvm::APInt getValue() const { return value; }
void setLiteral(bool l = true) { isLiteral = l; }
bool getLiteral() { return isLiteral; }

private:
llvm::APInt value;
bool isLiteral;
};

class SpirvConstantFloat : public SpirvConstant {
Expand Down
32 changes: 32 additions & 0 deletions tools/clang/include/clang/SPIRV/SpirvType.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class SpirvType {
TK_Function,
TK_AccelerationStructureNV,
TK_RayQueryKHR,
TK_SpirvIntrinsicType,
// Order matters: all the following are hybrid types
TK_HybridStruct,
TK_HybridPointer,
Expand Down Expand Up @@ -412,6 +413,37 @@ class RayQueryTypeKHR : public SpirvType {
}
};

class SpirvInstruction;
struct SpvIntrinsicTypeOperand {
SpvIntrinsicTypeOperand(SpirvType *type_operand)
: operand_as_type(type_operand), isTypeOperand(true) {}
SpvIntrinsicTypeOperand(SpirvInstruction *inst_operand)
: operand_as_inst(inst_operand), isTypeOperand(false) {}
union {
SpirvType *operand_as_type;
SpirvInstruction *operand_as_inst;
};
bool isTypeOperand;
};

class SpirvIntrinsicType : public SpirvType {
public:
SpirvIntrinsicType(unsigned typeOp,
llvm::ArrayRef<SpvIntrinsicTypeOperand> inOps);

static bool classof(const SpirvType *t) {
return t->getKind() == TK_SpirvIntrinsicType;
}
unsigned getOpCode() const { return typeOpCode; }
llvm::ArrayRef<SpvIntrinsicTypeOperand> getOperands() const {
return operands;
}

private:
unsigned typeOpCode;
llvm::SmallVector<SpvIntrinsicTypeOperand, 3> operands;
};

class HybridType : public SpirvType {
public:
static bool classof(const SpirvType *t) {
Expand Down
9 changes: 8 additions & 1 deletion tools/clang/lib/AST/ASTContextHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -840,8 +840,15 @@ CXXMethodDecl* hlsl::CreateObjectFunctionDeclarationWithParams(

CXXRecordDecl* hlsl::DeclareUIntTemplatedTypeWithHandle(
ASTContext& context, StringRef typeName, StringRef templateParamName) {
return DeclareUIntTemplatedTypeWithHandleInDeclContext(
context, context.getTranslationUnitDecl(), typeName, templateParamName);
}

CXXRecordDecl *hlsl::DeclareUIntTemplatedTypeWithHandleInDeclContext(
ASTContext &context, DeclContext *declContext, StringRef typeName,
StringRef templateParamName) {
// template<uint kind> FeedbackTexture2D[Array] { ... }
BuiltinTypeDeclBuilder typeDeclBuilder(context.getTranslationUnitDecl(), typeName);
BuiltinTypeDeclBuilder typeDeclBuilder(declContext, typeName);
typeDeclBuilder.addIntegerTemplateParam(templateParamName, context.UnsignedIntTy);
typeDeclBuilder.startDefinition();
typeDeclBuilder.addField("h", context.UnsignedIntTy); // Add an 'h' field to hold the handle.
Expand Down
5 changes: 3 additions & 2 deletions tools/clang/lib/SPIRV/CapabilityVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -529,8 +529,9 @@ bool CapabilityVisitor::visitInstruction(SpirvInstruction *instr) {
}
case spv::Op::OpRayQueryInitializeKHR: {
auto rayQueryInst = dyn_cast<SpirvRayQueryOpKHR>(instr);
if (rayQueryInst->hasCullFlags()) {
addCapability(spv::Capability::RayTraversalPrimitiveCullingKHR);
if (rayQueryInst && rayQueryInst->hasCullFlags()) {
addCapability(
spv::Capability::RayTraversalPrimitiveCullingKHR);
}

break;
Expand Down
69 changes: 65 additions & 4 deletions tools/clang/lib/SPIRV/EmitVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1884,10 +1884,9 @@ bool EmitVisitor::visit(SpirvIntrinsicInstruction *inst) {
}

for (const auto operand : inst->getOperands()) {
// TODO: Handle Literals with other types.
auto literalOperand = dyn_cast<SpirvConstantInteger>(operand);
if (literalOperand && literalOperand->getLiteral()) {
curInst.push_back(literalOperand->getValue().getZExtValue());
auto literalOperand = dyn_cast<SpirvConstant>(operand);
if (literalOperand && literalOperand->isLiteral()) {
typeHandler.emitLiteral(literalOperand, curInst);
} else {
curInst.push_back(getOrAssignResultId<SpirvInstruction>(operand));
}
Expand Down Expand Up @@ -2451,6 +2450,24 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type) {
initTypeInstruction(spv::Op::OpTypeRayQueryKHR);
curTypeInst.push_back(id);
finalizeTypeInstruction();
} else if (const auto *spvIntrinsicType =
dyn_cast<SpirvIntrinsicType>(type)) {
initTypeInstruction(static_cast<spv::Op>(spvIntrinsicType->getOpCode()));
curTypeInst.push_back(id);
for (const SpvIntrinsicTypeOperand &operand :
spvIntrinsicType->getOperands()) {
if (operand.isTypeOperand) {
curTypeInst.push_back(emitType(operand.operand_as_type));
} else {
auto *literal = dyn_cast<SpirvConstant>(operand.operand_as_inst);
if (literal && literal->isLiteral()) {
emitLiteral(literal, curTypeInst);
} else {
curTypeInst.push_back(getOrAssignResultId(operand.operand_as_inst));
}
}
}
finalizeTypeInstruction();
}
// Hybrid Types
// Note: The type lowering pass should lower all types to SpirvTypes.
Expand All @@ -2467,6 +2484,50 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type) {
return id;
}

template <typename vecType>
void EmitTypeHandler::emitIntLiteral(const SpirvConstantInteger *intLiteral,
vecType &outInst) {
const auto &literalVal = intLiteral->getValue();
bool positive = !literalVal.isNegative();
if (literalVal.getBitWidth() <= 32) {
outInst.push_back(positive ? literalVal.getZExtValue()
: literalVal.getSExtValue());
} else {
assert(literalVal.getBitWidth() == 64);
uint64_t val =
positive ? literalVal.getZExtValue() : literalVal.getSExtValue();
outInst.push_back(static_cast<unsigned>(val));
outInst.push_back(static_cast<unsigned>(val >> 32));
}
}

template <typename vecType>
void EmitTypeHandler::emitFloatLiteral(const SpirvConstantFloat *fLiteral,
vecType &outInst) {
const auto &literalVal = fLiteral->getValue();
const auto bitwidth = llvm::APFloat::getSizeInBits(literalVal.getSemantics());
if (bitwidth <= 32) {
outInst.push_back(literalVal.bitcastToAPInt().getZExtValue());
} else {
assert(bitwidth == 64);
uint64_t val = literalVal.bitcastToAPInt().getZExtValue();
outInst.push_back(static_cast<unsigned>(val));
outInst.push_back(static_cast<unsigned>(val >> 32));
}
}

template <typename VecType>
void EmitTypeHandler::emitLiteral(const SpirvConstant *literal,
VecType &outInst) {
if (auto boolLiteral = dyn_cast<SpirvConstantBoolean>(literal)) {
outInst.push_back(static_cast<unsigned>(boolLiteral->getValue()));
} else if (auto intLiteral = dyn_cast<SpirvConstantInteger>(literal)) {
emitIntLiteral(intLiteral, outInst);
} else if (auto fLiteral = dyn_cast<SpirvConstantFloat>(literal)) {
emitFloatLiteral(fLiteral, outInst);
}
}

void EmitTypeHandler::emitDecoration(uint32_t typeResultId,
spv::Decoration decoration,
llvm::ArrayRef<uint32_t> decorationParams,
Expand Down
6 changes: 6 additions & 0 deletions tools/clang/lib/SPIRV/EmitVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,12 @@ class EmitTypeHandler {
uint32_t getOrCreateConstantComposite(SpirvConstantComposite *);
uint32_t getOrCreateConstantNull(SpirvConstantNull *);
uint32_t getOrCreateConstantBool(SpirvConstantBoolean *);
template <typename vecType>
void emitLiteral(const SpirvConstant *, vecType &outInst);
template <typename vecType>
void emitFloatLiteral(const SpirvConstantFloat *, vecType &outInst);
template <typename vecType>
void emitIntLiteral(const SpirvConstantInteger *, vecType &outInst);

private:
void initTypeInstruction(spv::Op op);
Expand Down
5 changes: 5 additions & 0 deletions tools/clang/lib/SPIRV/LowerTypeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -593,6 +593,11 @@ const SpirvType *LowerTypeVisitor::lowerResourceType(QualType type,
if (name == "RayQuery")
return spvContext.getRayQueryTypeKHR();

if (name == "ext_type") {
auto typeId = hlsl::GetHLSLResourceTemplateUInt(type);
return spvContext.getCreatedSpirvIntrinsicType(typeId);
}

if (name == "StructuredBuffer" || name == "RWStructuredBuffer" ||
name == "AppendStructuredBuffer" || name == "ConsumeStructuredBuffer") {
// StructureBuffer<S> will be translated into an OpTypeStruct with one
Expand Down
7 changes: 5 additions & 2 deletions tools/clang/lib/SPIRV/SpirvBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1001,10 +1001,13 @@ SpirvInstruction *SpirvBuilder::createSpirvIntrInstExt(

SpirvExtInstImport *set =
(instSet.size() == 0) ? nullptr : getExtInstSet(instSet);

if (retType != QualType() && retType->isVoidType()) {
retType = QualType();
}

auto *inst = new (context) SpirvIntrinsicInstruction(
retType->isVoidType() ? QualType() : retType, opcode, operands,
extensions, set, capablities, loc);
retType, opcode, operands, extensions, set, capablities, loc);
insertPoint->addInstruction(inst);
return inst;
}
Expand Down
23 changes: 23 additions & 0 deletions tools/clang/lib/SPIRV/SpirvContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,11 @@ SpirvContext::~SpirvContext() {

for (auto &typePair : typeTemplateParams)
typePair.second->releaseMemory();

for (auto &pair : spirvIntrinsicTypes) {
assert(pair.second);
pair.second->~SpirvIntrinsicType();
}
}

inline uint32_t log2ForBitwidth(uint32_t bitwidth) {
Expand Down Expand Up @@ -527,5 +532,23 @@ void SpirvContext::moveDebugTypesToModule(SpirvModule *module) {
typeTemplateParams.clear();
}

const SpirvIntrinsicType *SpirvContext::getSpirvIntrinsicType(
unsigned typeId, unsigned typeOpCode,
llvm::ArrayRef<SpvIntrinsicTypeOperand> operands) {
if (spirvIntrinsicTypes[typeId] == nullptr) {
spirvIntrinsicTypes[typeId] =
new (this) SpirvIntrinsicType(typeOpCode, operands);
}
return spirvIntrinsicTypes[typeId];
}

SpirvIntrinsicType *
SpirvContext::getCreatedSpirvIntrinsicType(unsigned typeId) {
if (spirvIntrinsicTypes.find(typeId) == spirvIntrinsicTypes.end()){
return nullptr;
}
return spirvIntrinsicTypes[typeId];
}

} // end namespace spirv
} // end namespace clang
Loading

0 comments on commit 2eae8d3

Please sign in to comment.