Skip to content

Commit

Permalink
[HLSL] Use hlsl vector template in type printer (#95489)
Browse files Browse the repository at this point in the history
In HLSL we really want to be using the HLSL vector template and other
built-in sugared spellings for some builtin types. This updates the type
printer to take an option to use HLSL type spellings.

This changes printing vector type names from:

```
T __attribute__((ext_vector_type(N)))
```
To:
```
vector<T, N>
```
  • Loading branch information
llvm-beanz committed Jun 14, 2024
1 parent 3ecba1a commit b6fd6d4
Show file tree
Hide file tree
Showing 16 changed files with 149 additions and 126 deletions.
7 changes: 6 additions & 1 deletion clang/include/clang/AST/PrettyPrinter.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ struct PrintingPolicy {
PrintCanonicalTypes(false), PrintInjectedClassNameWithArguments(true),
UsePreferredNames(true), AlwaysIncludeTypeForTemplateArgument(false),
CleanUglifiedParameters(false), EntireContentsOfLargeArray(true),
UseEnumerators(true) {}
UseEnumerators(true), UseHLSLTypes(LO.HLSL) {}

/// Adjust this printing policy for cases where it's known that we're
/// printing C++ code (for instance, if AST dumping reaches a C++-only
Expand Down Expand Up @@ -342,6 +342,11 @@ struct PrintingPolicy {
LLVM_PREFERRED_TYPE(bool)
unsigned UseEnumerators : 1;

/// Whether or not we're printing known HLSL code and should print HLSL
/// sugared types when possible.
LLVM_PREFERRED_TYPE(bool)
unsigned UseHLSLTypes : 1;

/// Callbacks to use to allow the behavior of printing to be customized.
const PrintingCallbacks *Callbacks = nullptr;
};
Expand Down
32 changes: 25 additions & 7 deletions clang/lib/AST/TypePrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -644,16 +644,25 @@ void TypePrinter::printDependentAddressSpaceAfter(
void TypePrinter::printDependentSizedExtVectorBefore(
const DependentSizedExtVectorType *T,
raw_ostream &OS) {
if (Policy.UseHLSLTypes)
OS << "vector<";
printBefore(T->getElementType(), OS);
}

void TypePrinter::printDependentSizedExtVectorAfter(
const DependentSizedExtVectorType *T,
raw_ostream &OS) {
OS << " __attribute__((ext_vector_type(";
if (T->getSizeExpr())
T->getSizeExpr()->printPretty(OS, nullptr, Policy);
OS << ")))";
if (Policy.UseHLSLTypes) {
OS << ", ";
if (T->getSizeExpr())
T->getSizeExpr()->printPretty(OS, nullptr, Policy);
OS << ">";
} else {
OS << " __attribute__((ext_vector_type(";
if (T->getSizeExpr())
T->getSizeExpr()->printPretty(OS, nullptr, Policy);
OS << ")))";
}
printAfter(T->getElementType(), OS);
}

Expand Down Expand Up @@ -815,14 +824,23 @@ void TypePrinter::printDependentVectorAfter(

void TypePrinter::printExtVectorBefore(const ExtVectorType *T,
raw_ostream &OS) {
if (Policy.UseHLSLTypes)
OS << "vector<";
printBefore(T->getElementType(), OS);
}

void TypePrinter::printExtVectorAfter(const ExtVectorType *T, raw_ostream &OS) {
printAfter(T->getElementType(), OS);
OS << " __attribute__((ext_vector_type(";
OS << T->getNumElements();
OS << ")))";

if (Policy.UseHLSLTypes) {
OS << ", ";
OS << T->getNumElements();
OS << ">";
} else {
OS << " __attribute__((ext_vector_type(";
OS << T->getNumElements();
OS << ")))";
}
}

void TypePrinter::printConstantMatrixBefore(const ConstantMatrixType *T,
Expand Down
2 changes: 1 addition & 1 deletion clang/test/AST/HLSL/pch.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
hlsl::RWBuffer<float> Buffer;

float2 bar(float2 a, float2 b) {
// CHECK:CallExpr 0x{{[0-9a-f]+}} <col:10, col:18> 'float2':'float __attribute__((ext_vector_type(2)))'
// CHECK:CallExpr 0x{{[0-9a-f]+}} <col:10, col:18> 'float2':'vector<float, 2>'
// CHECK-NEXT:ImplicitCastExpr 0x{{[0-9a-f]+}} <col:10> 'float2 (*)(float2, float2)' <FunctionToPointerDecay>
// CHECK-NEXT:`-DeclRefExpr 0x{{[0-9a-f]+}} <col:10> 'float2 (float2, float2)' lvalue Function 0x[[FOO]] 'foo' 'float2 (float2, float2)'
return foo(a, b);
Expand Down
2 changes: 1 addition & 1 deletion clang/test/AST/HLSL/pch_with_buf.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
hlsl::RWBuffer<float> Buf2;

float2 bar(float2 a, float2 b) {
// CHECK:CallExpr 0x{{[0-9a-f]+}} <col:10, col:18> 'float2':'float __attribute__((ext_vector_type(2)))'
// CHECK:CallExpr 0x{{[0-9a-f]+}} <col:10, col:18> 'float2':'vector<float, 2>'
// CHECK-NEXT:ImplicitCastExpr 0x{{[0-9a-f]+}} <col:10> 'float2 (*)(float2, float2)' <FunctionToPointerDecay>
// CHECK-NEXT:`-DeclRefExpr 0x{{[0-9a-f]+}} <col:10> 'float2 (float2, float2)' lvalue Function 0x[[FOO]] 'foo' 'float2 (float2, float2)'
return foo(a, b);
Expand Down
16 changes: 8 additions & 8 deletions clang/test/AST/HLSL/vector-alias.hlsl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-compute -x hlsl -ast-dump -o - %s | FileCheck %s
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-compute -x hlsl -ast-dump -o - %s | FileCheck %s

// CHECK: NamespaceDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> implicit hlsl
// CHECK-NEXT: TypeAliasTemplateDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> implicit vector
Expand All @@ -8,8 +8,8 @@
// CHECK-NEXT: NonTypeTemplateParmDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> 'int' depth 0 index 1 element_count
// CHECK-NEXT: TemplateArgument expr
// CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <<invalid sloc>> 'int' 4
// CHECK-NEXT: TypeAliasDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> implicit vector 'element __attribute__((ext_vector_type(element_count)))'
// CHECK-NEXT: DependentSizedExtVectorType 0x{{[0-9a-fA-F]+}} 'element __attribute__((ext_vector_type(element_count)))' dependent <invalid sloc>
// CHECK-NEXT: TypeAliasDecl 0x{{[0-9a-fA-F]+}} <<invalid sloc>> <invalid sloc> implicit vector 'vector<element, element_count>'
// CHECK-NEXT: DependentSizedExtVectorType 0x{{[0-9a-fA-F]+}} 'vector<element, element_count>' dependent <invalid sloc>
// CHECK-NEXT: TemplateTypeParmType 0x{{[0-9a-fA-F]+}} 'element' dependent depth 0 index 0
// CHECK-NEXT: TemplateTypeParm 0x{{[0-9a-fA-F]+}} 'element'
// CHECK-NEXT: DeclRefExpr 0x{{[0-9a-fA-F]+}} <<invalid sloc>> 'int' lvalue
Expand All @@ -24,30 +24,30 @@ int entry() {
hlsl::vector<float, 2> Vec2 = {1.0, 2.0};

// CHECK: DeclStmt 0x{{[0-9a-fA-F]+}} <line:24:3, col:43>
// CHECK-NEXT: VarDecl 0x{{[0-9a-fA-F]+}} <col:3, col:42> col:26 Vec2 'hlsl::vector<float, 2>':'float __attribute__((ext_vector_type(2)))' cinit
// CHECK-NEXT: VarDecl 0x{{[0-9a-fA-F]+}} <col:3, col:42> col:26 Vec2 'hlsl::vector<float, 2>':'vector<float, 2>' cinit

// Verify that you don't need to specify the namespace.
vector<int, 2> Vec2a = {1, 2};

// CHECK: DeclStmt 0x{{[0-9a-fA-F]+}} <line:30:3, col:32>
// CHECK-NEXT: VarDecl 0x{{[0-9a-fA-F]+}} <col:3, col:31> col:18 Vec2a 'vector<int, 2>':'int __attribute__((ext_vector_type(2)))' cinit
// CHECK-NEXT: VarDecl 0x{{[0-9a-fA-F]+}} <col:3, col:31> col:18 Vec2a 'vector<int, 2>' cinit

// Build a bigger vector.
vector<double, 4> Vec4 = {1.0, 2.0, 3.0, 4.0};

// CHECK: DeclStmt 0x{{[0-9a-fA-F]+}} <line:36:3, col:48>
// CHECK-NEXT: VarDecl 0x{{[0-9a-fA-F]+}} <col:3, col:47> col:21 used Vec4 'vector<double, 4>':'double __attribute__((ext_vector_type(4)))' cinit
// CHECK-NEXT: VarDecl 0x{{[0-9a-fA-F]+}} <col:3, col:47> col:21 used Vec4 'vector<double, 4>' cinit

// Verify that swizzles still work.
vector<double, 3> Vec3 = Vec4.xyz;

// CHECK: DeclStmt 0x{{[0-9a-fA-F]+}} <line:42:3, col:36>
// CHECK-NEXT: VarDecl 0x{{[0-9a-fA-F]+}} <col:3, col:33> col:21 Vec3 'vector<double, 3>':'double __attribute__((ext_vector_type(3)))' cinit
// CHECK-NEXT: VarDecl 0x{{[0-9a-fA-F]+}} <col:3, col:33> col:21 Vec3 'vector<double, 3>' cinit

// Verify that the implicit arguments generate the correct type.
vector<> ImpVec4 = {1.0, 2.0, 3.0, 4.0};

// CHECK: DeclStmt 0x{{[0-9a-fA-F]+}} <line:48:3, col:42>
// CHECK-NEXT: VarDecl 0x{{[0-9a-fA-F]+}} <col:3, col:41> col:12 ImpVec4 'vector<>':'float __attribute__((ext_vector_type(4)))' cinit
// CHECK-NEXT: VarDecl 0x{{[0-9a-fA-F]+}} <col:3, col:41> col:12 ImpVec4 'vector<>':'vector<float, 4>' cinit
return 1;
}
22 changes: 11 additions & 11 deletions clang/test/AST/HLSL/vector-constructors.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -11,32 +11,32 @@ void entry() {

// For the float2 vector, we just expect a conversion from constructor
// parameters to an initialization list
// CHECK-LABEL: VarDecl 0x{{[0-9a-fA-F]+}} {{.*}} used Vec2 'float2':'float __attribute__((ext_vector_type(2)))' cinit
// CHECK-NEXT: CXXFunctionalCastExpr 0x{{[0-9a-fA-F]+}} {{.*}} 'float2':'float __attribute__((ext_vector_type(2)))' functional cast to float2 <NoOp>
// CHECK-NEXT: InitListExpr 0x{{[0-9a-fA-F]+}} {{.*}} 'float2':'float __attribute__((ext_vector_type(2)))'
// CHECK-LABEL: VarDecl 0x{{[0-9a-fA-F]+}} {{.*}} used Vec2 'float2':'vector<float, 2>' cinit
// CHECK-NEXT: CXXFunctionalCastExpr 0x{{[0-9a-fA-F]+}} {{.*}} 'float2':'vector<float, 2>' functional cast to float2 <NoOp>
// CHECK-NEXT: InitListExpr 0x{{[0-9a-fA-F]+}} {{.*}} 'float2':'vector<float, 2>'
// CHECK-NEXT: FloatingLiteral 0x{{[0-9a-fA-F]+}} {{.*}} 'float' 1.000000e+00
// CHECK-NEXT: FloatingLiteral 0x{{[0-9a-fA-F]+}} {{.*}} 'float' 2.000000e+00


// For the float 3 things get fun...
// Here we expect accesses to the vec2 to provide the first and second
// components using ArraySubscriptExpr
// CHECK-LABEL: VarDecl 0x{{[0-9a-fA-F]+}} {{.*}} col:10 Vec3 'float3':'float __attribute__((ext_vector_type(3)))' cinit
// CHECK-NEXT: CXXFunctionalCastExpr 0x{{[0-9a-fA-F]+}} {{.*}} 'float3':'float __attribute__((ext_vector_type(3)))' functional cast to float3 <NoOp>
// CHECK-NEXT: InitListExpr 0x{{[0-9a-fA-F]+}} {{.*}} 'float3':'float __attribute__((ext_vector_type(3)))'
// CHECK-LABEL: VarDecl 0x{{[0-9a-fA-F]+}} {{.*}} col:10 Vec3 'float3':'vector<float, 3>' cinit
// CHECK-NEXT: CXXFunctionalCastExpr 0x{{[0-9a-fA-F]+}} {{.*}} 'float3':'vector<float, 3>' functional cast to float3 <NoOp>
// CHECK-NEXT: InitListExpr 0x{{[0-9a-fA-F]+}} {{.*}} 'float3':'vector<float, 3>'
// CHECK-NEXT: ImplicitCastExpr 0x{{[0-9a-fA-F]+}} <col:24, <invalid sloc>> 'float' <LValueToRValue>
// CHECK-NEXT: ArraySubscriptExpr 0x{{[0-9a-fA-F]+}} <col:24, <invalid sloc>> 'float' lvalue
// CHECK-NEXT: DeclRefExpr 0x{{[0-9a-fA-F]+}} {{.*}} 'float2':'float __attribute__((ext_vector_type(2)))' lvalue Var 0x{{[0-9a-fA-F]+}} 'Vec2' 'float2':'float __attribute__((ext_vector_type(2)))'
// CHECK-NEXT: DeclRefExpr 0x{{[0-9a-fA-F]+}} {{.*}} 'float2':'vector<float, 2>' lvalue Var 0x{{[0-9a-fA-F]+}} 'Vec2' 'float2':'vector<float, 2>'
// CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <<invalid sloc>> 'int' 0
// CHECK-NEXT: ImplicitCastExpr 0x{{[0-9a-fA-F]+}} <col:24, <invalid sloc>> 'float' <LValueToRValue>
// CHECK-NEXT: ArraySubscriptExpr 0x{{[0-9a-fA-F]+}} <col:24, <invalid sloc>> 'float' lvalue
// CHECK-NEXT: DeclRefExpr 0x{{[0-9a-fA-F]+}} {{.*}} 'float2':'float __attribute__((ext_vector_type(2)))' lvalue Var 0x{{[0-9a-fA-F]+}} 'Vec2' 'float2':'float __attribute__((ext_vector_type(2)))'
// CHECK-NEXT: DeclRefExpr 0x{{[0-9a-fA-F]+}} {{.*}} 'float2':'vector<float, 2>' lvalue Var 0x{{[0-9a-fA-F]+}} 'Vec2' 'float2':'vector<float, 2>'
// CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <<invalid sloc>> 'int' 1
// CHECK-NEXT: FloatingLiteral 0x{{[0-9a-fA-F]+}} {{.*}} 'float' 3.000000e+00

// CHECK: VarDecl 0x{{[0-9a-fA-F]+}} {{.*}} col:10 Vec3b 'float3':'float __attribute__((ext_vector_type(3)))' cinit
// CHECK-NEXT: CXXFunctionalCastExpr 0x{{[0-9a-fA-F]+}} {{.*}} 'float3':'float __attribute__((ext_vector_type(3)))' functional cast to float3 <NoOp>
// CHECK-NEXT: InitListExpr 0x{{[0-9a-fA-F]+}} {{.*}} 'float3':'float __attribute__((ext_vector_type(3)))'
// CHECK: VarDecl 0x{{[0-9a-fA-F]+}} {{.*}} col:10 Vec3b 'float3':'vector<float, 3>' cinit
// CHECK-NEXT: CXXFunctionalCastExpr 0x{{[0-9a-fA-F]+}} {{.*}} 'float3':'vector<float, 3>' functional cast to float3 <NoOp>
// CHECK-NEXT: InitListExpr 0x{{[0-9a-fA-F]+}} {{.*}} 'float3':'vector<float, 3>'

// CHECK-NEXT: FloatingLiteral 0x{{[0-9a-fA-F]+}} {{.*}} 'float' 1.000000e+00
// CHECK-NEXT: FloatingLiteral 0x{{[0-9a-fA-F]+}} {{.*}} 'float' 2.000000e+00
Expand Down
2 changes: 1 addition & 1 deletion clang/test/SemaHLSL/BuiltIns/RWBuffers.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,6 @@ RWBuffer<> BufferErr2;

[numthreads(1,1,1)]
void main() {
(void)Buffer.h; // expected-error {{'h' is a private member of 'hlsl::RWBuffer<float __attribute__((ext_vector_type(3)))>'}}
(void)Buffer.h; // expected-error {{'h' is a private member of 'hlsl::RWBuffer<vector<float, 3> >'}}
// expected-note@* {{implicitly declared private here}}
}
2 changes: 1 addition & 1 deletion clang/test/SemaHLSL/BuiltIns/clamp-errors.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ float2 test_clamp_no_second_arg(float2 p0) {

float2 test_clamp_vector_size_mismatch(float3 p0, float2 p1) {
return clamp(p0, p0, p1);
// expected-warning@-1 {{implicit conversion truncates vector: 'float3' (aka 'vector<float, 3>') to 'float __attribute__((ext_vector_type(2)))' (vector of 2 'float' values)}}
// expected-warning@-1 {{implicit conversion truncates vector: 'float3' (aka 'vector<float, 3>') to 'vector<float, 2>' (vector of 2 'float' values)}}
}

float2 test_clamp_builtin_vector_size_mismatch(float3 p0, float2 p1) {
Expand Down
2 changes: 1 addition & 1 deletion clang/test/SemaHLSL/BuiltIns/dot-errors.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ float test_dot_no_second_arg(float2 p0) {

float test_dot_vector_size_mismatch(float3 p0, float2 p1) {
return dot(p0, p1);
// expected-warning@-1 {{implicit conversion truncates vector: 'float3' (aka 'vector<float, 3>') to 'float __attribute__((ext_vector_type(2)))' (vector of 2 'float' values)}}
// expected-warning@-1 {{implicit conversion truncates vector: 'float3' (aka 'vector<float, 3>') to 'vector<float, 2>' (vector of 2 'float' values)}}
}

float test_dot_builtin_vector_size_mismatch(float3 p0, float2 p1) {
Expand Down
2 changes: 1 addition & 1 deletion clang/test/SemaHLSL/BuiltIns/lerp-errors.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ float2 test_lerp_no_second_arg(float2 p0) {

float2 test_lerp_vector_size_mismatch(float3 p0, float2 p1) {
return lerp(p0, p0, p1);
// expected-warning@-1 {{implicit conversion truncates vector: 'float3' (aka 'vector<float, 3>') to 'float __attribute__((ext_vector_type(2)))' (vector of 2 'float' values)}}
// expected-warning@-1 {{implicit conversion truncates vector: 'float3' (aka 'vector<float, 3>') to 'vector<float, 2>' (vector of 2 'float' values)}}
}

float2 test_lerp_builtin_vector_size_mismatch(float3 p0, float2 p1) {
Expand Down
2 changes: 1 addition & 1 deletion clang/test/SemaHLSL/BuiltIns/mad-errors.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ float2 test_mad_no_second_arg(float2 p0) {

float2 test_mad_vector_size_mismatch(float3 p0, float2 p1) {
return mad(p0, p0, p1);
// expected-warning@-1 {{implicit conversion truncates vector: 'float3' (aka 'vector<float, 3>') to 'float __attribute__((ext_vector_type(2)))' (vector of 2 'float' values)}}
// expected-warning@-1 {{implicit conversion truncates vector: 'float3' (aka 'vector<float, 3>') to 'vector<float, 2>' (vector of 2 'float' values)}}
}

float2 test_mad_builtin_vector_size_mismatch(float3 p0, float2 p1) {
Expand Down
4 changes: 2 additions & 2 deletions clang/test/SemaHLSL/BuiltIns/vector-errors.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

// Some bad declarations
hlsl::vector ShouldWorkSomeday; // expected-error{{use of alias template 'hlsl::vector' requires template arguments}}
// expected-note@*:* {{template declaration from hidden source: template <class element = float, int element_count = 4> using vector = element __attribute__((ext_vector_type(element_count)))}}
// expected-note@*:* {{template declaration from hidden source: template <class element = float, int element_count = 4> using vector = vector<element, element_count>}}

hlsl::vector<1> BadVec; // expected-error{{template argument for template type parameter must be a type}}
// expected-note@*:* {{template parameter from hidden source: class element = float}}
Expand All @@ -11,7 +11,7 @@ hlsl::vector<int, float> AnotherBadVec; // expected-error{{template argument for
// expected-note@*:* {{template parameter from hidden source: int element_count = 4}}

hlsl::vector<int, 2, 3> YABV; // expected-error{{too many template arguments for alias template 'vector'}}
// expected-note@*:* {{template declaration from hidden source: template <class element = float, int element_count = 4> using vector = element __attribute__((ext_vector_type(element_count)))}}
// expected-note@*:* {{template declaration from hidden source: template <class element = float, int element_count = 4> using vector = vector<element, element_count>}}

// This code is rejected by clang because clang puts the HLSL built-in types
// into the HLSL namespace.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.6-library -x hlsl -finclude-default-header -verify %s

int2 ToTwoInts(int V) {
return V.xy; // expected-error{{vector component access exceeds type 'int __attribute__((ext_vector_type(1)))' (vector of 1 'int' value)}}
return V.xy; // expected-error{{vector component access exceeds type 'vector<int, 1>' (vector of 1 'int' value)}}
}

float2 ToTwoFloats(float V) {
return V.rg; // expected-error{{vector component access exceeds type 'float __attribute__((ext_vector_type(1)))' (vector of 1 'float' value)}}
return V.rg; // expected-error{{vector component access exceeds type 'vector<float, 1>' (vector of 1 'float' value)}}
}

int4 SomeNonsense(int V) {
Expand Down
Loading

0 comments on commit b6fd6d4

Please sign in to comment.