Skip to content

Commit

Permalink
[SPIRV] Add support of [[vk::ext_type_def]]
Browse files Browse the repository at this point in the history
this is related
#3919
  • Loading branch information
jiaolu committed Nov 24, 2021
1 parent cc50c79 commit 374b714
Show file tree
Hide file tree
Showing 18 changed files with 280 additions and 30 deletions.
8 changes: 8 additions & 0 deletions tools/clang/include/clang/Basic/Attr.td
Original file line number Diff line number Diff line change
Expand Up @@ -1145,6 +1145,14 @@ def VKReferenceExt : InheritableAttr {
let Documentation = [Undocumented];
}

def VKTypeDefExt : InheritableAttr {
let Spellings = [CXX11<"vk", "ext_type_def">];
let Subjects = SubjectList<[Function], ErrorDiag>;
let Args = [UnsignedArgument<"id">, UnsignedArgument<"opcode">];
let LangOpts = [SPIRV];
let Documentation = [Undocumented];
}

// Global variables that are of scalar type
def ScalarGlobalVar : SubsetSubject<Var, [{S->hasGlobalStorage() && S->getType()->isScalarType()}]>;

Expand Down
8 changes: 8 additions & 0 deletions tools/clang/include/clang/SPIRV/SpirvContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,13 @@ class SpirvContext {
return rayQueryTypeKHR;
}

const SpirvIntrinsicType *
getSpirvIntrinsicType(unsigned typeId, unsigned typeOpCode,
llvm::ArrayRef<SpirvConstant *> constants,
SpirvIntrinsicType *elementTy);

SpirvIntrinsicType *getCreatedSpirvIntrinsicType(unsigned typeId);

/// --- Hybrid type getter functions ---
///
/// Concrete SpirvType objects represent a SPIR-V type completely. Hybrid
Expand Down Expand Up @@ -467,6 +474,7 @@ class SpirvContext {
llvm::DenseMap<const SpirvType *, SCToPtrTyMap> pointerTypes;
llvm::SmallVector<const HybridPointerType *, 8> hybridPointerTypes;
llvm::DenseSet<FunctionType *, FunctionTypeMapInfo> functionTypes;
llvm::DenseMap<unsigned, SpirvIntrinsicType*> spirvIntrinsicTypes;
const AccelerationStructureTypeNV *accelerationStructureTypeNV;
const RayQueryTypeKHR *rayQueryTypeKHR;

Expand Down
12 changes: 6 additions & 6 deletions tools/clang/include/clang/SPIRV/SpirvInstruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -1110,10 +1110,13 @@ class SpirvConstant : public SpirvInstruction {
}

bool isSpecConstant() const;
void setLiteral(bool literal = true) { literalConstant = literal; }
bool isLiteral() { return literalConstant; }

protected:
SpirvConstant(Kind, spv::Op, const SpirvType *);
SpirvConstant(Kind, spv::Op, QualType);
SpirvConstant(Kind, spv::Op, const SpirvType *, bool literal = false);
SpirvConstant(Kind, spv::Op, QualType, bool literal = false);
bool literalConstant;
};

class SpirvConstantBoolean : public SpirvConstant {
Expand Down Expand Up @@ -1141,7 +1144,7 @@ class SpirvConstantBoolean : public SpirvConstant {
class SpirvConstantInteger : public SpirvConstant {
public:
SpirvConstantInteger(QualType type, llvm::APInt value,
bool isSpecConst = false, bool literal = false);
bool isSpecConst = false);

DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvConstantInteger)

Expand All @@ -1155,12 +1158,9 @@ class SpirvConstantInteger : public SpirvConstant {
bool invokeVisitor(Visitor *v) override;

llvm::APInt getValue() const { return value; }
void setLiteral(bool l = true) { isLiteral = l; }
bool getLiteral() { return isLiteral; }

private:
llvm::APInt value;
bool isLiteral;
};

class SpirvConstantFloat : public SpirvConstant {
Expand Down
19 changes: 19 additions & 0 deletions tools/clang/include/clang/SPIRV/SpirvType.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class SpirvType {
TK_Function,
TK_AccelerationStructureNV,
TK_RayQueryKHR,
TK_SpirvIntrinsicType,
// Order matters: all the following are hybrid types
TK_HybridStruct,
TK_HybridPointer,
Expand Down Expand Up @@ -412,6 +413,24 @@ class RayQueryTypeKHR : public SpirvType {
}
};

class SpirvConstant;
class SpirvIntrinsicType : public SpirvType {
public:
SpirvIntrinsicType(unsigned typeOp, llvm::ArrayRef<SpirvConstant *> constants,
SpirvIntrinsicType *elementTy);
static bool classof(const SpirvType *t) {
return t->getKind() == TK_SpirvIntrinsicType;
}
unsigned getOpCode() const { return typeOpCode; }
llvm::ArrayRef<SpirvConstant *> getLiterals() const { return literals; }
SpirvIntrinsicType *getElemType() const { return elementType; }

private:
unsigned typeOpCode;
llvm::SmallVector<SpirvConstant *, 3> literals;
SpirvIntrinsicType *elementType;
};

class HybridType : public SpirvType {
public:
static bool classof(const SpirvType *t) {
Expand Down
5 changes: 3 additions & 2 deletions tools/clang/lib/SPIRV/CapabilityVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -529,8 +529,9 @@ bool CapabilityVisitor::visitInstruction(SpirvInstruction *instr) {
}
case spv::Op::OpRayQueryInitializeKHR: {
auto rayQueryInst = dyn_cast<SpirvRayQueryOpKHR>(instr);
if (rayQueryInst->hasCullFlags()) {
addCapability(spv::Capability::RayTraversalPrimitiveCullingKHR);
if (rayQueryInst && rayQueryInst->hasCullFlags()) {
addCapability(
spv::Capability::RayTraversalPrimitiveCullingKHR);
}

break;
Expand Down
51 changes: 47 additions & 4 deletions tools/clang/lib/SPIRV/EmitVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1884,10 +1884,9 @@ bool EmitVisitor::visit(SpirvIntrinsicInstruction *inst) {
}

for (const auto operand : inst->getOperands()) {
// TODO: Handle Literals with other types.
auto literalOperand = dyn_cast<SpirvConstantInteger>(operand);
if (literalOperand && literalOperand->getLiteral()) {
curInst.push_back(literalOperand->getValue().getZExtValue());
auto literalOperand = dyn_cast<SpirvConstant>(operand);
if (literalOperand && literalOperand->isLiteral()) {
typeHandler.emitLiteral(literalOperand, curInst);
} else {
curInst.push_back(getOrAssignResultId<SpirvInstruction>(operand));
}
Expand Down Expand Up @@ -2452,6 +2451,17 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type) {
curTypeInst.push_back(id);
finalizeTypeInstruction();
}
else if (const auto *spvIntrinsicType = dyn_cast<SpirvIntrinsicType>(type)) {
initTypeInstruction(static_cast<spv::Op>(spvIntrinsicType->getOpCode()));
curTypeInst.push_back(id);
if (spvIntrinsicType->getElemType()) {
curTypeInst.push_back(emitType(spvIntrinsicType->getElemType()));
}
for (auto& literal : spvIntrinsicType->getLiterals()) {
emitLiteral(literal, curTypeInst);
}
finalizeTypeInstruction();
}
// Hybrid Types
// Note: The type lowering pass should lower all types to SpirvTypes.
// Therefore, if we find a hybrid type when going through the emitting pass,
Expand All @@ -2467,6 +2477,39 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type) {
return id;
}

template <typename VecType>
void EmitTypeHandler::emitLiteral(const SpirvConstant *literal,
VecType &outInst) {
if (auto boolLiteral = dyn_cast<SpirvConstantBoolean>(literal)) {
outInst.push_back(static_cast<unsigned>(boolLiteral->getValue()));
} else if (auto intLiteral = dyn_cast<SpirvConstantInteger>(literal)) {
const auto &literalVal = intLiteral->getValue();
bool positive = !literalVal.isNegative();
if (literalVal.getBitWidth() <= 32) {
outInst.push_back(positive ? literalVal.getZExtValue()
: literalVal.getSExtValue());
} else {
assert(literalVal.getBitWidth() == 64);
uint64_t val =
positive ? literalVal.getZExtValue() : literalVal.getSExtValue();
outInst.push_back(static_cast<unsigned>(val));
outInst.push_back(static_cast<unsigned>(val >> 32));
}
} else if (auto fLiteral = dyn_cast<SpirvConstantFloat>(literal)) {
const auto &literalVal = fLiteral->getValue();
const auto bitwidth =
llvm::APFloat::getSizeInBits(literalVal.getSemantics());
if (bitwidth <= 32) {
outInst.push_back(literalVal.bitcastToAPInt().getZExtValue());
} else {
assert(bitwidth == 64);
uint64_t val = literalVal.bitcastToAPInt().getZExtValue();
outInst.push_back(static_cast<unsigned>(val));
outInst.push_back(static_cast<unsigned>(val >> 32));
}
}
}

void EmitTypeHandler::emitDecoration(uint32_t typeResultId,
spv::Decoration decoration,
llvm::ArrayRef<uint32_t> decorationParams,
Expand Down
2 changes: 2 additions & 0 deletions tools/clang/lib/SPIRV/EmitVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ class EmitTypeHandler {
uint32_t getOrCreateConstantComposite(SpirvConstantComposite *);
uint32_t getOrCreateConstantNull(SpirvConstantNull *);
uint32_t getOrCreateConstantBool(SpirvConstantBoolean *);
template <typename vecType>
void emitLiteral(const SpirvConstant *, vecType &outInst);

private:
void initTypeInstruction(spv::Op op);
Expand Down
5 changes: 5 additions & 0 deletions tools/clang/lib/SPIRV/LowerTypeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -593,6 +593,11 @@ const SpirvType *LowerTypeVisitor::lowerResourceType(QualType type,
if (name == "RayQuery")
return spvContext.getRayQueryTypeKHR();

if (name == "ext_type") {
auto typeId = hlsl::GetHLSLResourceTemplateUInt(type);
return spvContext.getCreatedSpirvIntrinsicType(typeId);
}

if (name == "StructuredBuffer" || name == "RWStructuredBuffer" ||
name == "AppendStructuredBuffer" || name == "ConsumeStructuredBuffer") {
// StructureBuffer<S> will be translated into an OpTypeStruct with one
Expand Down
7 changes: 5 additions & 2 deletions tools/clang/lib/SPIRV/SpirvBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1001,10 +1001,13 @@ SpirvInstruction *SpirvBuilder::createSpirvIntrInstExt(

SpirvExtInstImport *set =
(instSet.size() == 0) ? nullptr : getExtInstSet(instSet);

if (retType != QualType() && retType->isVoidType()) {
retType = QualType();
}

auto *inst = new (context) SpirvIntrinsicInstruction(
retType->isVoidType() ? QualType() : retType, opcode, operands,
extensions, set, capablities, loc);
retType, opcode, operands, extensions, set, capablities, loc);
insertPoint->addInstruction(inst);
return inst;
}
Expand Down
19 changes: 19 additions & 0 deletions tools/clang/lib/SPIRV/SpirvContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -527,5 +527,24 @@ void SpirvContext::moveDebugTypesToModule(SpirvModule *module) {
typeTemplateParams.clear();
}

const SpirvIntrinsicType *
SpirvContext::getSpirvIntrinsicType(unsigned typeId, unsigned typeOpCode,
llvm::ArrayRef<SpirvConstant *> constants,
SpirvIntrinsicType *elementTy) {
if (spirvIntrinsicTypes[typeId] == nullptr) {
spirvIntrinsicTypes[typeId] =
new (this) SpirvIntrinsicType(typeOpCode, constants, elementTy);
}
return spirvIntrinsicTypes[typeId];
}

SpirvIntrinsicType *
SpirvContext::getCreatedSpirvIntrinsicType(unsigned typeId) {
if (spirvIntrinsicTypes.find(typeId) == spirvIntrinsicTypes.end()){
return nullptr;
}
return spirvIntrinsicTypes[typeId];
}

} // end namespace spirv
} // end namespace clang
57 changes: 53 additions & 4 deletions tools/clang/lib/SPIRV/SpirvEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
#include "dxc/HlslIntrinsicOp.h"
#include "spirv-tools/optimizer.hpp"
#include "clang/SPIRV/AstTypeProbe.h"
#include "clang/SPIRV/String.h"
#include "clang/Sema/Sema.h"
#include "llvm/ADT/StringExtras.h"

#include "InitListHandler.h"
#include "dxc/DXIL/DxilConstants.h"

Expand Down Expand Up @@ -2337,8 +2337,11 @@ SpirvInstruction *SpirvEmitter::doCallExpr(const CallExpr *callExpr) {
return doCXXMemberCallExpr(memberCall);

auto funcDecl = callExpr->getDirectCallee();
if (funcDecl && funcDecl->hasAttr<VKInstructionExtAttr>()) {
return processSpvIntrinsicCallExpr(callExpr);
if (funcDecl) {
if (funcDecl->hasAttr<VKInstructionExtAttr>())
return processSpvIntrinsicCallExpr(callExpr);
else if(funcDecl->hasAttr<VKTypeDefExtAttr>())
return processSpvIntrinsicTypeDef(callExpr);
}
// Intrinsic functions such as 'dot' or 'mul'
if (hlsl::IsIntrinsicOp(funcDecl)) {
Expand Down Expand Up @@ -12530,7 +12533,7 @@ SpirvEmitter::processSpvIntrinsicCallExpr(const CallExpr *expr) {
}
spvArgs.push_back(argInst);
} else if (param->hasAttr<VKLiteralExtAttr>()) {
auto constArg = dyn_cast<SpirvConstantInteger>(argInst);
auto constArg = dyn_cast<SpirvConstant>(argInst);
assert(constArg != nullptr);
constArg->setLiteral();
spvArgs.push_back(argInst);
Expand Down Expand Up @@ -12601,6 +12604,52 @@ SpirvEmitter::processIntrinsicExecutionMode(const CallExpr *expr) {
execModesParams, expr->getExprLoc());
}

SpirvInstruction *
SpirvEmitter::processSpvIntrinsicTypeDef(const CallExpr *expr) {
auto funcDecl = expr->getDirectCallee();
auto typeDefAttr = funcDecl->getAttr<VKTypeDefExtAttr>();
SpirvIntrinsicType *elementType = nullptr;
SmallVector<SpirvConstant *, 3> constants;
llvm::SmallVector<uint32_t, 2> capbilities;
llvm::SmallVector<llvm::StringRef, 2> extensions;

for (auto &attr : funcDecl->getAttrs()) {
if (auto capAttr = dyn_cast<VKCapabilityExtAttr>(attr)) {
capbilities.push_back(capAttr->getCapability());
} else if (auto extAttr = dyn_cast<VKExtensionExtAttr>(attr)) {
extensions.push_back(extAttr->getName());
}
}

const auto args = expr->getArgs();
for (uint32_t i = 0; i < expr->getNumArgs(); ++i) {
auto param = funcDecl->getParamDecl(i);
const Expr *arg = args[i]->IgnoreParenLValueCasts();
if (param->hasAttr<VKReferenceExtAttr>()) {
auto typeId = hlsl::GetHLSLResourceTemplateUInt(arg->getType());
elementType = spvContext.getCreatedSpirvIntrinsicType(typeId);
} else if (param->hasAttr<VKLiteralExtAttr>()) {
SpirvInstruction *argInst = doExpr(arg);
auto constArg = dyn_cast<SpirvConstant>(argInst);
assert(constArg != nullptr);
constArg->setLiteral();
constants.push_back(constArg);
}
}
spvContext.getSpirvIntrinsicType(
typeDefAttr->getId(), typeDefAttr->getOpcode(), constants, elementType);

// Emit dummy OpNop with no semantic meaning, with possible extension and
// capabilities

SpirvInstruction *retVal = spvBuilder.createSpirvIntrInstExt(
static_cast<unsigned>(spv::Op::OpNop), QualType(), {}, extensions, {},
capbilities, expr->getExprLoc());
retVal->setRValue();

return retVal;
}

bool SpirvEmitter::spirvToolsValidate(std::vector<uint32_t> *mod,
std::string *messages) {
spvtools::SpirvTools tools(featureManager.getTargetEnv());
Expand Down
3 changes: 3 additions & 0 deletions tools/clang/lib/SPIRV/SpirvEmitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -596,6 +596,9 @@ class SpirvEmitter : public ASTConsumer {
hlsl::IntrinsicOp opcode);
/// Process spirv intrinsic instruction
SpirvInstruction *processSpvIntrinsicCallExpr(const CallExpr *expr);

/// Process spirv intrinsic type definition
SpirvInstruction *processSpvIntrinsicTypeDef(const CallExpr *expr);

/// Custom intrinsic to support basic buffer_reference use case
SpirvInstruction *processRawBufferLoad(const CallExpr *callExpr);
Expand Down
16 changes: 10 additions & 6 deletions tools/clang/lib/SPIRV/SpirvInstruction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -476,15 +476,19 @@ SpirvCompositeConstruct::SpirvCompositeConstruct(
resultType, loc),
consituents(constituentsVec.begin(), constituentsVec.end()) {}

SpirvConstant::SpirvConstant(Kind kind, spv::Op op, const SpirvType *spvType)
SpirvConstant::SpirvConstant(Kind kind, spv::Op op, const SpirvType *spvType,
bool literal)
: SpirvInstruction(kind, op, QualType(),
/*SourceLocation*/ {}) {
/*SourceLocation*/ {}),
literalConstant(literal) {
setResultType(spvType);
}

SpirvConstant::SpirvConstant(Kind kind, spv::Op op, QualType resultType)
SpirvConstant::SpirvConstant(Kind kind, spv::Op op, QualType resultType,
bool literal)
: SpirvInstruction(kind, op, resultType,
/*SourceLocation*/ {}) {}
/*SourceLocation*/ {}),
literalConstant(literal) {}

bool SpirvConstant::isSpecConstant() const {
return opcode == spv::Op::OpSpecConstant ||
Expand All @@ -509,11 +513,11 @@ bool SpirvConstantBoolean::operator==(const SpirvConstantBoolean &that) const {
}

SpirvConstantInteger::SpirvConstantInteger(QualType type, llvm::APInt val,
bool isSpecConst, bool literal)
bool isSpecConst)
: SpirvConstant(IK_ConstantInteger,
isSpecConst ? spv::Op::OpSpecConstant : spv::Op::OpConstant,
type),
value(val), isLiteral(literal) {
value(val) {
assert(type->isIntegerType());
}

Expand Down
Loading

0 comments on commit 374b714

Please sign in to comment.