Skip to content

Commit

Permalink
[spirv] add vk::ext_result_id<T> type (#4192)
Browse files Browse the repository at this point in the history
As a part of HLSL version of GL_EXT_spirv_intrinsics, this commit adds
vk::ext_result_id<T> type. We must use it for a variable definition or
a function parameter. It means we do not consider it as a physical
storage. Instead, it will be a result id of the instruction.

Related to #3919
  • Loading branch information
jaebaek authored Jan 19, 2022
1 parent 46c9735 commit 19139d8
Show file tree
Hide file tree
Showing 13 changed files with 146 additions and 16 deletions.
8 changes: 8 additions & 0 deletions tools/clang/include/clang/AST/HlslTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,13 @@ clang::CXXRecordDecl* DeclareTemplateTypeWithHandle(
uint8_t templateArgCount,
_In_opt_ clang::TypeSourceInfo* defaultTypeArgValue);

clang::CXXRecordDecl* DeclareTemplateTypeWithHandleInDeclContext(
clang::ASTContext& context,
clang::DeclContext *declContext,
llvm::StringRef name,
uint8_t templateArgCount,
_In_opt_ clang::TypeSourceInfo* defaultTypeArgValue);

clang::CXXRecordDecl* DeclareUIntTemplatedTypeWithHandle(
clang::ASTContext& context, llvm::StringRef typeName, llvm::StringRef templateParamName);
clang::CXXRecordDecl *DeclareUIntTemplatedTypeWithHandleInDeclContext(
Expand Down Expand Up @@ -404,6 +411,7 @@ bool IsHLSLAggregateType(clang::QualType type);
clang::QualType GetHLSLResourceResultType(clang::QualType type);
unsigned GetHLSLResourceTemplateUInt(clang::QualType type);
bool IsIncompleteHLSLResourceArrayType(clang::ASTContext& context, clang::QualType type);
clang::QualType GetHLSLResourceTemplateParamType(clang::QualType type);
clang::QualType GetHLSLInputPatchElementType(clang::QualType type);
unsigned GetHLSLInputPatchCount(clang::QualType type);
clang::QualType GetHLSLOutputPatchElementType(clang::QualType type);
Expand Down
10 changes: 8 additions & 2 deletions tools/clang/include/clang/SPIRV/AstTypeProbe.h
Original file line number Diff line number Diff line change
Expand Up @@ -324,10 +324,16 @@ bool isOrContainsNonFpColMajorMatrix(const ASTContext &,
const SpirvCodeGenOptions &, QualType type,
const Decl *decl);

/// \bried Returns true if the given type is a String or StringLiteral type.
/// \brief Returns true if the given type is `vk::ext_result_id<T>`.
bool isExtResultIdType(QualType type);

/// \brief Returns true if the given type is defined in `vk` namespace.
bool isTypeInVkNamespace(const RecordType *type);

/// \brief Returns true if the given type is a String or StringLiteral type.
bool isStringType(QualType);

/// \bried Returns true if the given type is a bindless array of an opaque type.
/// \brief Returns true if the given type is a bindless array of an opaque type.
bool isBindlessOpaqueArray(QualType type);

/// \brief Generates the corresponding SPIR-V vector type for the given Clang
Expand Down
16 changes: 15 additions & 1 deletion tools/clang/lib/AST/ASTContextHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -665,14 +665,28 @@ CXXRecordDecl* hlsl::DeclareTemplateTypeWithHandle(
StringRef name,
uint8_t templateArgCount,
_In_opt_ TypeSourceInfo* defaultTypeArgValue)
{
return DeclareTemplateTypeWithHandleInDeclContext(context,
context.getTranslationUnitDecl(),
name,
templateArgCount,
defaultTypeArgValue);
}

CXXRecordDecl* hlsl::DeclareTemplateTypeWithHandleInDeclContext(
ASTContext& context,
DeclContext *declContext,
StringRef name,
uint8_t templateArgCount,
_In_opt_ TypeSourceInfo* defaultTypeArgValue)
{
DXASSERT(templateArgCount != 0, "otherwise caller should be creating a class or struct");
DXASSERT(templateArgCount <= 2, "otherwise the function needs to be updated for a different template pattern");

// Create an object template declaration in translation unit scope.
// templateArgCount=1: template<typename element> typeName { ... }
// templateArgCount=2: template<typename element, int count> typeName { ... }
BuiltinTypeDeclBuilder typeDeclBuilder(context.getTranslationUnitDecl(), name);
BuiltinTypeDeclBuilder typeDeclBuilder(declContext, name);
TemplateTypeParmDecl* elementTemplateParamDecl = typeDeclBuilder.addTypeTemplateParam("element", defaultTypeArgValue);
NonTypeTemplateParmDecl* countTemplateParamDecl = nullptr;
if (templateArgCount > 1)
Expand Down
14 changes: 7 additions & 7 deletions tools/clang/lib/AST/HlslTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -729,14 +729,19 @@ bool IsIncompleteHLSLResourceArrayType(clang::ASTContext &context,
}
return false;
}
QualType GetHLSLInputPatchElementType(QualType type) {

QualType GetHLSLResourceTemplateParamType(QualType type) {
type = type.getCanonicalType();
const RecordType *RT = cast<RecordType>(type);
const ClassTemplateSpecializationDecl *templateDecl =
cast<ClassTemplateSpecializationDecl>(RT->getAsCXXRecordDecl());
const TemplateArgumentList &argList = templateDecl->getTemplateArgs();
return argList[0].getAsType();
}

QualType GetHLSLInputPatchElementType(QualType type) {
return GetHLSLResourceTemplateParamType(type);
}
unsigned GetHLSLInputPatchCount(QualType type) {
type = type.getCanonicalType();
const RecordType *RT = cast<RecordType>(type);
Expand All @@ -746,12 +751,7 @@ unsigned GetHLSLInputPatchCount(QualType type) {
return argList[1].getAsIntegral().getLimitedValue();
}
clang::QualType GetHLSLOutputPatchElementType(QualType type) {
type = type.getCanonicalType();
const RecordType *RT = cast<RecordType>(type);
const ClassTemplateSpecializationDecl *templateDecl =
cast<ClassTemplateSpecializationDecl>(RT->getAsCXXRecordDecl());
const TemplateArgumentList &argList = templateDecl->getTemplateArgs();
return argList[0].getAsType();
return GetHLSLResourceTemplateParamType(type);
}
unsigned GetHLSLOutputPatchCount(QualType type) {
type = type.getCanonicalType();
Expand Down
19 changes: 19 additions & 0 deletions tools/clang/lib/SPIRV/AstTypeProbe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1254,6 +1254,25 @@ bool isOrContainsNonFpColMajorMatrix(const ASTContext &astContext,
return false;
}

bool isTypeInVkNamespace(const RecordType *type) {
if (const auto *nameSpaceDecl =
dyn_cast<NamespaceDecl>(type->getDecl()->getDeclContext())) {
return nameSpaceDecl->getName() == "vk";
}
return false;
}

bool isExtResultIdType(QualType type) {
if (const auto *elaboratedType = type->getAs<ElaboratedType>()) {
if (const auto *recordType = elaboratedType->getAs<RecordType>()) {
if (!isTypeInVkNamespace(recordType))
return false;
return recordType->getDecl()->getName() == "ext_result_id";
}
}
return false;
}

bool isStringType(QualType type) {
return hlsl::IsStringType(type) || hlsl::IsStringLiteralType(type);
}
Expand Down
15 changes: 15 additions & 0 deletions tools/clang/lib/SPIRV/DeclResultIdMapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1054,6 +1054,21 @@ SpirvVariable *DeclResultIdMapper::createExternVar(const VarDecl *var) {
return varInstr;
}

SpirvInstruction *DeclResultIdMapper::createResultId(const VarDecl *var) {
assert(isExtResultIdType(var->getType()));

// Without initialization, we cannot generate the result id.
if (!var->hasInit()) {
emitError("Found uninitialized variable for result id.",
var->getLocation());
return nullptr;
}

SpirvInstruction *init = theEmitter.doExpr(var->getInit());
astDecls[var] = createDeclSpirvInfo(init);
return init;
}

SpirvInstruction *
DeclResultIdMapper::createOrUpdateStringVar(const VarDecl *var) {
assert(hlsl::IsStringType(var->getType()) ||
Expand Down
8 changes: 8 additions & 0 deletions tools/clang/lib/SPIRV/DeclResultIdMapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,14 @@ class DeclResultIdMapper {
/// for it.
SpirvInstruction *createOrUpdateStringVar(const VarDecl *);

/// \brief Returns an instruction that represents the given VarDecl.
/// VarDecl must be a variable of vk::ext_result_id<Type> type.
///
/// This function inspects the VarDecl for an initialization expression. If
/// initialization expression is not found, it will emit an error because the
/// variable with result id requires an initialization.
SpirvInstruction *createResultId(const VarDecl *var);

/// \brief Creates an Enum constant.
void createEnumConstant(const EnumConstantDecl *decl);

Expand Down
25 changes: 20 additions & 5 deletions tools/clang/lib/SPIRV/LowerTypeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,22 @@ const SpirvType *LowerTypeVisitor::lowerType(QualType type,
return 0;
}

const SpirvType *
LowerTypeVisitor::lowerVkTypeInVkNamespace(QualType type, llvm::StringRef name,
SpirvLayoutRule rule,
SourceLocation srcLoc) {
if (name == "ext_type") {
auto typeId = hlsl::GetHLSLResourceTemplateUInt(type);
return spvContext.getCreatedSpirvIntrinsicType(typeId);
}
if (name == "ext_result_id") {
QualType realType = hlsl::GetHLSLResourceTemplateParamType(type);
return lowerType(realType, rule, llvm::None, srcLoc);
}
emitError("unknown type %0 in vk namespace", srcLoc) << type;
return nullptr;
}

const SpirvType *LowerTypeVisitor::lowerResourceType(QualType type,
SpirvLayoutRule rule,
SourceLocation srcLoc) {
Expand All @@ -535,6 +551,10 @@ const SpirvType *LowerTypeVisitor::lowerResourceType(QualType type,
assert(recordType);
const llvm::StringRef name = recordType->getDecl()->getName();

if (isTypeInVkNamespace(recordType)) {
return lowerVkTypeInVkNamespace(type, name, rule, srcLoc);
}

// TODO: avoid string comparison once hlsl::IsHLSLResouceType() does that.

{ // Texture types
Expand Down Expand Up @@ -593,11 +613,6 @@ const SpirvType *LowerTypeVisitor::lowerResourceType(QualType type,
if (name == "RayQuery")
return spvContext.getRayQueryTypeKHR();

if (name == "ext_type") {
auto typeId = hlsl::GetHLSLResourceTemplateUInt(type);
return spvContext.getCreatedSpirvIntrinsicType(typeId);
}

if (name == "StructuredBuffer" || name == "RWStructuredBuffer" ||
name == "AppendStructuredBuffer" || name == "ConsumeStructuredBuffer") {
// StructureBuffer<S> will be translated into an OpTypeStruct with one
Expand Down
5 changes: 5 additions & 0 deletions tools/clang/lib/SPIRV/LowerTypeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@ class LowerTypeVisitor : public Visitor {
const SpirvType *lowerResourceType(QualType type, SpirvLayoutRule rule,
SourceLocation);

/// Lowers the given type defined in vk namespace into its SPIR-V type.
const SpirvType *lowerVkTypeInVkNamespace(QualType type, llvm::StringRef name,
SpirvLayoutRule rule,
SourceLocation srcLoc);

/// For the given sampled type, returns the corresponding image format
/// that can be used to create an image object.
spv::ImageFormat translateSampledTypeToImageFormat(QualType sampledType,
Expand Down
5 changes: 5 additions & 0 deletions tools/clang/lib/SPIRV/SpirvEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1530,6 +1530,11 @@ void SpirvEmitter::doVarDecl(const VarDecl *decl) {
const auto loc = decl->getLocation();
const auto range = decl->getSourceRange();

if (isExtResultIdType(decl->getType())) {
declIdMapper.createResultId(decl);
return;
}

// HLSL has the 'string' type which can be used for rare purposes such as
// printf (SPIR-V's DebugPrintf). SPIR-V does not have a 'char' or 'string'
// type, and therefore any variable of such type should not be created.
Expand Down
16 changes: 15 additions & 1 deletion tools/clang/lib/Sema/SemaHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ enum ArBasicKind {
AR_OBJECT_VK_SUBPASS_INPUT,
AR_OBJECT_VK_SUBPASS_INPUT_MS,
AR_OBJECT_VK_SPV_INTRINSIC_TYPE,
AR_OBJECT_VK_SPV_INTRINSIC_RESULT_ID,
#endif // ENABLE_SPIRV_CODEGEN
// SPIRV change ends

Expand Down Expand Up @@ -476,6 +477,7 @@ const UINT g_uBasicKindProps[] =
BPROP_OBJECT | BPROP_RBUFFER, // AR_OBJECT_VK_SUBPASS_INPUT
BPROP_OBJECT | BPROP_RBUFFER, // AR_OBJECT_VK_SUBPASS_INPUT_MS
BPROP_OBJECT, // AR_OBJECT_VK_SPV_INTRINSIC_TYPE use recordType
BPROP_OBJECT, // AR_OBJECT_VK_SPV_INTRINSIC_RESULT_ID use recordType
#endif // ENABLE_SPIRV_CODEGEN
// SPIRV change ends

Expand Down Expand Up @@ -1400,6 +1402,7 @@ const ArBasicKind g_ArBasicKindsAsTypes[] =
AR_OBJECT_VK_SUBPASS_INPUT,
AR_OBJECT_VK_SUBPASS_INPUT_MS,
AR_OBJECT_VK_SPV_INTRINSIC_TYPE,
AR_OBJECT_VK_SPV_INTRINSIC_RESULT_ID,
#endif // ENABLE_SPIRV_CODEGEN
// SPIRV change ends

Expand Down Expand Up @@ -1493,6 +1496,7 @@ const uint8_t g_ArBasicKindsTemplateCount[] =
1, // AR_OBJECT_VK_SUBPASS_INPUT
1, // AR_OBJECT_VK_SUBPASS_INPUT_MS,
1, // AR_OBJECT_VK_SPV_INTRINSIC_TYPE
1, // AR_OBJECT_VK_SPV_INTRINSIC_RESULT_ID
#endif // ENABLE_SPIRV_CODEGEN
// SPIRV change ends

Expand Down Expand Up @@ -1594,6 +1598,7 @@ const SubscriptOperatorRecord g_ArBasicKindsSubscripts[] =
{ 0, MipsFalse, SampleFalse }, // AR_OBJECT_VK_SUBPASS_INPUT (SubpassInput)
{ 0, MipsFalse, SampleFalse }, // AR_OBJECT_VK_SUBPASS_INPUT_MS (SubpassInputMS)
{ 0, MipsFalse, SampleFalse }, // AR_OBJECT_VK_SPV_INTRINSIC_TYPE
{ 0, MipsFalse, SampleFalse }, // AR_OBJECT_VK_SPV_INTRINSIC_RESULT_ID
#endif // ENABLE_SPIRV_CODEGEN
// SPIRV change ends

Expand Down Expand Up @@ -1714,6 +1719,7 @@ const char* g_ArBasicTypeNames[] =
"SubpassInput",
"SubpassInputMS",
"ext_type",
"ext_result_id",
#endif // ENABLE_SPIRV_CODEGEN
// SPIRV change ends

Expand Down Expand Up @@ -3602,6 +3608,13 @@ class HLSLExternalSource : public ExternalSemaSource {
*m_context, m_vkNSDecl, typeName, "id");
recordDecl->setImplicit(true);
}
else if (kind == AR_OBJECT_VK_SPV_INTRINSIC_RESULT_ID && m_vkNSDecl) {
recordDecl = DeclareTemplateTypeWithHandleInDeclContext(*m_context,
m_vkNSDecl,
typeName, 1,
nullptr);
recordDecl->setImplicit(true);
}
#endif
else if (templateArgCount == 0) {
recordDecl = DeclareRecordTypeWithHandle(*m_context, typeName);
Expand Down Expand Up @@ -12924,7 +12937,8 @@ bool Sema::DiagnoseHLSLDecl(Declarator &D, DeclContext *DC, Expr *BitWidth,
if (!getLangOpts().SPIRV) {
if (basicKind == ArBasicKind::AR_OBJECT_VK_SUBPASS_INPUT ||
basicKind == ArBasicKind::AR_OBJECT_VK_SUBPASS_INPUT_MS ||
basicKind == ArBasicKind::AR_OBJECT_VK_SPV_INTRINSIC_TYPE) {
basicKind == ArBasicKind::AR_OBJECT_VK_SPV_INTRINSIC_TYPE ||
basicKind == ArBasicKind::AR_OBJECT_VK_SPV_INTRINSIC_RESULT_ID) {
Diag(D.getLocStart(), diag::err_hlsl_vulkan_specific_feature)
<< g_ArBasicTypeNames[basicKind];
result = false;
Expand Down
20 changes: 20 additions & 0 deletions tools/clang/test/CodeGenSPIRV/spv.intrinsic.result_id.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// RUN: %dxc -T cs_6_0 -E main

[[vk::ext_instruction(/* OpLoad */ 61)]]
vk::ext_result_id<float> load([[vk::ext_reference]] float pointer,
[[vk::ext_literal]] int memoryOperands);

[[vk::ext_instruction(/* OpStore */ 62)]]
void store([[vk::ext_reference]] float pointer,
vk::ext_result_id<float> value,
[[vk::ext_literal]] int memoryOperands);

[numthreads(1,1,1)]
void main() {
float foo, bar;

//CHECK: [[foo_value:%\w+]] = OpLoad %float %foo None
//CHECK: OpStore %bar [[foo_value]] Volatile
vk::ext_result_id<float> foo_value = load(foo, /* None */ 0x0);
store(bar, foo_value, /* Volatile */ 0x1);
}
1 change: 1 addition & 0 deletions tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1341,6 +1341,7 @@ TEST_F(FileTest, IntrinsicsVkQueueFamilyScope) {
}
TEST_F(FileTest, IntrinsicsSpirv) {
runFileTest("spv.intrinsicInstruction.hlsl");
runFileTest("spv.intrinsic.result_id.hlsl");
runFileTest("spv.intrinsicLiteral.hlsl");
runFileTest("spv.intrinsicDecorate.hlsl", Expect::Success, false);
runFileTest("spv.intrinsicExecutionMode.hlsl", Expect::Success, false);
Expand Down

0 comments on commit 19139d8

Please sign in to comment.