Skip to content

Commit

Permalink
[SPIRV] Add support of the GL_EXT_spirv_intrinsics (#3949)
Browse files Browse the repository at this point in the history
[SPIRV] Add support of the GL_EXT_spirv_intrinsics

Related to the issue: #3919
Add these attributes

vk::ext_capability
vk::ext_extension
vk::ext_instruction
vk::ext_reference
vk::ext_literal

Note this commit allows the redeclaration of a HLSL intrinsic function using a function declaration with `vk::ext_instruction`.

Co-authored-by: Jaebaek Seo <jaebaek@google.com>
  • Loading branch information
jiaolu and jaebaek authored Sep 24, 2021
1 parent f008085 commit de6a8ed
Show file tree
Hide file tree
Showing 16 changed files with 342 additions and 5 deletions.
38 changes: 38 additions & 0 deletions tools/clang/include/clang/Basic/Attr.td
Original file line number Diff line number Diff line change
Expand Up @@ -1010,6 +1010,14 @@ def VKBinding : InheritableAttr {
let Documentation = [Undocumented];
}

def VKCapabilityExt : InheritableAttr {
let Spellings = [CXX11<"vk", "ext_capability">];
let Subjects = SubjectList<[Function], ErrorDiag>;
let Args = [IntArgument<"capability">];
let LangOpts = [SPIRV];
let Documentation = [Undocumented];
}

def VKCounterBinding : InheritableAttr {
let Spellings = [CXX11<"vk", "counter_binding">];
let Subjects = SubjectList<[CounterStructuredBuffer], ErrorDiag, "ExpectedCounterStructuredBuffer">;
Expand All @@ -1018,6 +1026,14 @@ def VKCounterBinding : InheritableAttr {
let Documentation = [Undocumented];
}

def VKExtensionExt : InheritableAttr {
let Spellings = [CXX11<"vk", "ext_extension">];
let Subjects = SubjectList<[Function], ErrorDiag>;
let Args = [StringArgument<"name">];
let LangOpts = [SPIRV];
let Documentation = [Undocumented];
}

// Global variables that are of struct type
def StructGlobalVar : SubsetSubject<Var, [{S->hasGlobalStorage() && S->getType()->isStructureType()}]>;

Expand Down Expand Up @@ -1083,6 +1099,28 @@ def VKInputAttachmentIndex : InheritableAttr {
let Documentation = [Undocumented];
}

def VKInstructionExt : InheritableAttr {
let Spellings = [CXX11<"vk", "ext_instruction">];
let Subjects = SubjectList<[Function], ErrorDiag>;
let Args = [IntArgument<"opcode">, StringArgument<"instruction_set", 1>];
let LangOpts = [SPIRV];
let Documentation = [Undocumented];
}

def VKLiteralExt : InheritableAttr {
let Spellings = [CXX11<"vk", "ext_literal">];
let Subjects = SubjectList<[ParmVar], ErrorDiag>;
let LangOpts = [SPIRV];
let Documentation = [Undocumented];
}

def VKReferenceExt : InheritableAttr {
let Spellings = [CXX11<"vk", "ext_reference">];
let Subjects = SubjectList<[ParmVar], ErrorDiag>;
let LangOpts = [SPIRV];
let Documentation = [Undocumented];
}

// Global variables that are of scalar type
def ScalarGlobalVar : SubsetSubject<Var, [{S->hasGlobalStorage() && S->getType()->isScalarType()}]>;

Expand Down
7 changes: 7 additions & 0 deletions tools/clang/include/clang/SPIRV/SpirvBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,13 @@ class SpirvBuilder {
/// OpIgnoreIntersectionKHR/OpTerminateIntersectionKHR
void createRaytracingTerminateKHR(spv::Op opcode, SourceLocation loc);

/// \brief Create spirv intrinsic instructions
SpirvInstruction *createSpirvIntrInstExt(
uint32_t opcode, QualType retType,
llvm::ArrayRef<SpirvInstruction *> operands,
llvm::ArrayRef<llvm::StringRef> extensions, llvm::StringRef instSet,
llvm::ArrayRef<uint32_t> capablities, SourceLocation loc);

/// \brief Returns a clone SPIR-V variable for CTBuffer with FXC memory layout
/// and creates copy instructions from the CTBuffer to the clone variable in
/// module.init if it contains HLSL matrix 1xN. Otherwise, returns nullptr.
Expand Down
44 changes: 43 additions & 1 deletion tools/clang/include/clang/SPIRV/SpirvInstruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ class SpirvInstruction {
IK_Store, // OpStore
IK_UnaryOp, // Unary operations
IK_VectorShuffle, // OpVectorShuffle
IK_SpirvIntrinsicInstruction, // Spirv Intrinsic Instructions

// For DebugInfo instructions defined in OpenCL.DebugInfo.100
IK_DebugInfoNone,
Expand Down Expand Up @@ -1125,7 +1126,7 @@ class SpirvConstantBoolean : public SpirvConstant {
class SpirvConstantInteger : public SpirvConstant {
public:
SpirvConstantInteger(QualType type, llvm::APInt value,
bool isSpecConst = false);
bool isSpecConst = false, bool literal = false);

DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvConstantInteger)

Expand All @@ -1139,9 +1140,12 @@ class SpirvConstantInteger : public SpirvConstant {
bool invokeVisitor(Visitor *v) override;

llvm::APInt getValue() const { return value; }
void setLiteral(bool l = true) { isLiteral = l; }
bool getLiteral() { return isLiteral; }

private:
llvm::APInt value;
bool isLiteral;
};

class SpirvConstantFloat : public SpirvConstant {
Expand Down Expand Up @@ -1999,6 +2003,44 @@ class SpirvDemoteToHelperInvocationEXT : public SpirvInstruction {
bool invokeVisitor(Visitor *v) override;
};

// A class keeping information of [[vk::ext_instruction(uint opcode,
// string extended_instruction_set)]] attribute. The attribute allows users to
// emit an arbitrary SPIR-V instruction by adding it to a function declaration.
// Note that this class does not represent an actual specific SPIR-V
// instruction. It is used to keep the information of the arbitrary SPIR-V
// instruction.
class SpirvIntrinsicInstruction : public SpirvInstruction {
public:
SpirvIntrinsicInstruction(QualType resultType, uint32_t opcode,
llvm::ArrayRef<SpirvInstruction *> operands,
llvm::ArrayRef<llvm::StringRef> extensions,
SpirvExtInstImport *set,
llvm::ArrayRef<uint32_t> capabilities,
SourceLocation loc);

DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvIntrinsicInstruction)

// For LLVM-style RTTI
static bool classof(const SpirvInstruction *inst) {
return inst->getKind() == IK_SpirvIntrinsicInstruction;
}

bool invokeVisitor(Visitor *v) override;

llvm::ArrayRef<SpirvInstruction *> getOperands() const { return operands; }
llvm::ArrayRef<uint32_t> getCapabilities() const { return capabilities; }
llvm::ArrayRef<std::string> getExtensions() const { return extensions; }
SpirvExtInstImport *getInstructionSet() const { return instructionSet; }
uint32_t getInstruction() const { return instruction; }

private:
uint32_t instruction;
llvm::SmallVector<SpirvInstruction *, 4> operands;
llvm::SmallVector<uint32_t, 4> capabilities;
llvm::SmallVector<std::string, 4> extensions;
SpirvExtInstImport *instructionSet;
};

/// \breif Base class for all OpenCL.DebugInfo.100 extension instructions.
/// Note that all of these instructions should be added to the SPIR-V module as
/// an OpExtInst instructions. So, all of these instructions must:
Expand Down
1 change: 1 addition & 0 deletions tools/clang/include/clang/SPIRV/SpirvVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ class Visitor {
DEFINE_VISIT_METHOD(SpirvRayQueryOpKHR)
DEFINE_VISIT_METHOD(SpirvReadClock)
DEFINE_VISIT_METHOD(SpirvRayTracingTerminateOpKHR)
DEFINE_VISIT_METHOD(SpirvIntrinsicInstruction)
#undef DEFINE_VISIT_METHOD

protected:
Expand Down
11 changes: 11 additions & 0 deletions tools/clang/lib/SPIRV/CapabilityVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,17 @@ bool CapabilityVisitor::visitInstruction(SpirvInstruction *instr) {
addCapability(getNonUniformCapability(resultType));
}

if (instr->getKind() == SpirvInstruction::IK_SpirvIntrinsicInstruction) {
SpirvIntrinsicInstruction *pSpvInst =
dyn_cast<SpirvIntrinsicInstruction>(instr);
for (auto &cap : pSpvInst->getCapabilities()) {
addCapability(static_cast<spv::Capability>(cap));
}
for (const auto &ext : pSpvInst->getExtensions()) {
spvBuilder.requireExtension(ext, loc);
}
}

// Add opcode-specific capabilities
switch (opcode) {
case spv::Op::OpDPdxCoarse:
Expand Down
26 changes: 26 additions & 0 deletions tools/clang/lib/SPIRV/EmitVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1659,6 +1659,32 @@ bool EmitVisitor::visit(SpirvRayTracingTerminateOpKHR *inst) {
return true;
}

bool EmitVisitor::visit(SpirvIntrinsicInstruction *inst) {
initInstruction(inst);
if (inst->hasResultType()) {
curInst.push_back(inst->getResultTypeId());
curInst.push_back(getOrAssignResultId<SpirvInstruction>(inst));
}
if (inst->getInstructionSet()) {
curInst.push_back(
getOrAssignResultId<SpirvInstruction>(inst->getInstructionSet()));
curInst.push_back(inst->getInstruction());
}

for (const auto operand : inst->getOperands()) {
// TODO: Handle Literals with other types.
auto literalOperand = dyn_cast<SpirvConstantInteger>(operand);
if (literalOperand && literalOperand->getLiteral()) {
curInst.push_back(literalOperand->getValue().getZExtValue());
} else {
curInst.push_back(getOrAssignResultId<SpirvInstruction>(operand));
}
}

finalizeInstruction(&mainBinary);
return true;
}

// EmitTypeHandler ------

void EmitTypeHandler::initTypeInstruction(spv::Op op) {
Expand Down
1 change: 1 addition & 0 deletions tools/clang/lib/SPIRV/EmitVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,7 @@ class EmitVisitor : public Visitor {
bool visit(SpirvDebugTypeMember *) override;
bool visit(SpirvDebugTypeTemplate *) override;
bool visit(SpirvDebugTypeTemplateParameter *) override;
bool visit(SpirvIntrinsicInstruction *) override;

using Visitor::visit;

Expand Down
17 changes: 17 additions & 0 deletions tools/clang/lib/SPIRV/SpirvBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -962,6 +962,23 @@ SpirvInstruction *SpirvBuilder::createReadClock(SpirvInstruction *scope,
return inst;
}

SpirvInstruction *SpirvBuilder::createSpirvIntrInstExt(
uint32_t opcode, QualType retType,
llvm::ArrayRef<SpirvInstruction *> operands,
llvm::ArrayRef<llvm::StringRef> extensions, llvm::StringRef instSet,
llvm::ArrayRef<uint32_t> capablities, SourceLocation loc) {
assert(insertPoint && "null insert point");

SpirvExtInstImport *set =
(instSet.size() == 0) ? nullptr : getExtInstSet(instSet);

auto *inst = new (context) SpirvIntrinsicInstruction(
retType->isVoidType() ? QualType() : retType, opcode, operands,
extensions, set, capablities, loc);
insertPoint->addInstruction(inst);
return inst;
}

void SpirvBuilder::createRaytracingTerminateKHR(spv::Op opcode,
SourceLocation loc) {
assert(insertPoint && "null insert point");
Expand Down
62 changes: 61 additions & 1 deletion tools/clang/lib/SPIRV/SpirvEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2316,8 +2316,12 @@ SpirvInstruction *SpirvEmitter::doCallExpr(const CallExpr *callExpr) {
if (const auto *memberCall = dyn_cast<CXXMemberCallExpr>(callExpr))
return doCXXMemberCallExpr(memberCall);

auto funcDecl = callExpr->getDirectCallee();
if (funcDecl && funcDecl->hasAttr<VKInstructionExtAttr>()) {
return processSpvIntrinsicCallExpr(callExpr);
}
// Intrinsic functions such as 'dot' or 'mul'
if (hlsl::IsIntrinsicOp(callExpr->getDirectCallee())) {
if (hlsl::IsIntrinsicOp(funcDecl)) {
return processIntrinsicCallExpr(callExpr);
}

Expand Down Expand Up @@ -12455,6 +12459,62 @@ SpirvEmitter::processRayQueryIntrinsics(const CXXMemberCallExpr *expr,
return retVal;
}

SpirvInstruction *
SpirvEmitter::processSpvIntrinsicCallExpr(const CallExpr *expr) {
auto funcDecl = expr->getDirectCallee();
auto &attrs = funcDecl->getAttrs();
QualType retType = funcDecl->getReturnType();

llvm::SmallVector<uint32_t, 2> capbilities;
llvm::SmallVector<llvm::StringRef, 2> extensions;
llvm::StringRef instSet = "";
uint32_t op = 0;
for (auto &attr : attrs) {
if (auto capAttr = dyn_cast<VKCapabilityExtAttr>(attr)) {
capbilities.push_back(capAttr->getCapability());
} else if (auto extAttr = dyn_cast<VKExtensionExtAttr>(attr)) {
extensions.push_back(extAttr->getName());
} else if (auto instAttr = dyn_cast<VKInstructionExtAttr>(attr)) {
op = instAttr->getOpcode();
instSet = instAttr->getInstruction_set();
}
}

llvm::SmallVector<SpirvInstruction *, 8> spvArgs;

const auto args = expr->getArgs();
for (uint32_t i = 0; i < expr->getNumArgs(); ++i) {
auto param = funcDecl->getParamDecl(i);
const Expr *arg = args[i]->IgnoreParenLValueCasts();
SpirvInstruction *argInst = doExpr(arg);
if (param->hasAttr<VKReferenceExtAttr>()) {
if (argInst->isRValue()) {
emitError("argument for a parameter with vk::ext_reference attribute "
"must be a reference",
arg->getExprLoc());
return nullptr;
}
spvArgs.push_back(argInst);
} else if (param->hasAttr<VKLiteralExtAttr>()) {
auto constArg = dyn_cast<SpirvConstantInteger>(argInst);
assert(constArg != nullptr);
constArg->setLiteral();
spvArgs.push_back(argInst);
} else {
spvArgs.push_back(loadIfGLValue(arg, argInst));
}
}

const auto loc = expr->getExprLoc();

SpirvInstruction *retVal = spvBuilder.createSpirvIntrInstExt(
op, retType, spvArgs, extensions, instSet, capbilities, loc);

// TODO: Revisit this r-value setting when handling vk::ext_result_id<T> ?
retVal->setRValue();
return retVal;
}

bool SpirvEmitter::spirvToolsValidate(std::vector<uint32_t> *mod,
std::string *messages) {
spvtools::SpirvTools tools(featureManager.getTargetEnv());
Expand Down
2 changes: 2 additions & 0 deletions tools/clang/lib/SPIRV/SpirvEmitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -591,6 +591,8 @@ class SpirvEmitter : public ASTConsumer {
/// Process ray query intrinsics
SpirvInstruction *processRayQueryIntrinsics(const CXXMemberCallExpr *expr,
hlsl::IntrinsicOp opcode);
/// Process spirv intrinsic instruction
SpirvInstruction *processSpvIntrinsicCallExpr(const CallExpr *expr);

private:
/// Returns the <result-id> for constant value 0 of the given type.
Expand Down
18 changes: 16 additions & 2 deletions tools/clang/lib/SPIRV/SpirvInstruction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvDebugTypeTemplateParameter)
DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvRayQueryOpKHR)
DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvReadClock)
DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvRayTracingTerminateOpKHR)
DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvIntrinsicInstruction)

#undef DEFINE_INVOKE_VISITOR_FOR_CLASS

Expand Down Expand Up @@ -499,11 +500,11 @@ bool SpirvConstantBoolean::operator==(const SpirvConstantBoolean &that) const {
}

SpirvConstantInteger::SpirvConstantInteger(QualType type, llvm::APInt val,
bool isSpecConst)
bool isSpecConst, bool literal)
: SpirvConstant(IK_ConstantInteger,
isSpecConst ? spv::Op::OpSpecConstant : spv::Op::OpConstant,
type),
value(val) {
value(val), isLiteral(literal) {
assert(type->isIntegerType());
}

Expand Down Expand Up @@ -1004,5 +1005,18 @@ SpirvRayTracingTerminateOpKHR::SpirvRayTracingTerminateOpKHR(spv::Op opcode,
opcode == spv::Op::OpIgnoreIntersectionKHR);
}

SpirvIntrinsicInstruction::SpirvIntrinsicInstruction(
QualType resultType, uint32_t opcode,
llvm::ArrayRef<SpirvInstruction *> vecOperands,
llvm::ArrayRef<llvm::StringRef> exts, SpirvExtInstImport *set,
llvm::ArrayRef<uint32_t> capts, SourceLocation loc)
: SpirvInstruction(IK_SpirvIntrinsicInstruction,
set != nullptr ? spv::Op::OpExtInst
: static_cast<spv::Op>(opcode),
resultType, loc),
instruction(opcode), operands(vecOperands.begin(), vecOperands.end()),
capabilities(capts.begin(), capts.end()),
extensions(exts.begin(), exts.end()), instructionSet(set) {}

} // namespace spirv
} // namespace clang
24 changes: 24 additions & 0 deletions tools/clang/lib/Sema/SemaHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12038,6 +12038,30 @@ void hlsl::HandleDeclAttributeForHLSL(Sema &S, Decl *D, const AttributeList &A,
case AttributeList::AT_VKShaderRecordEXT:
declAttr = ::new (S.Context) VKShaderRecordEXTAttr(A.getRange(), S.Context, A.getAttributeSpellingListIndex());
break;
case AttributeList::AT_VKCapabilityExt:
declAttr = ::new (S.Context) VKCapabilityExtAttr(
A.getRange(), S.Context, ValidateAttributeIntArg(S, A),
A.getAttributeSpellingListIndex());
break;
case AttributeList::AT_VKExtensionExt:
declAttr = ::new (S.Context) VKExtensionExtAttr(
A.getRange(), S.Context, ValidateAttributeStringArg(S, A, nullptr),
A.getAttributeSpellingListIndex());
break;
case AttributeList::AT_VKInstructionExt:
declAttr = ::new (S.Context) VKInstructionExtAttr(
A.getRange(), S.Context, ValidateAttributeIntArg(S, A),
ValidateAttributeStringArg(S, A, nullptr, 1),
A.getAttributeSpellingListIndex());
break;
case AttributeList::AT_VKLiteralExt:
declAttr = ::new (S.Context) VKLiteralExtAttr(
A.getRange(), S.Context, A.getAttributeSpellingListIndex());
break;
case AttributeList::AT_VKReferenceExt:
declAttr = ::new (S.Context) VKReferenceExtAttr(
A.getRange(), S.Context, A.getAttributeSpellingListIndex());
break;
default:
Handled = false;
return;
Expand Down
Loading

0 comments on commit de6a8ed

Please sign in to comment.