Skip to content

Commit

Permalink
Merge pull request #856 from Sergio0694/dev/shader-explicit-entry-points
Browse files Browse the repository at this point in the history
Support explicit interface implementations for shader entry points
  • Loading branch information
Sergio0694 authored Sep 21, 2024
2 parents 57f0a9b + f7b03ea commit f077308
Show file tree
Hide file tree
Showing 11 changed files with 249 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,17 @@ private static partial class HlslSource
/// <param name="diagnostics">The collection of produced <see cref="DiagnosticInfo"/> instances.</param>
/// <param name="compilation">The input <see cref="Compilation"/> object currently in use.</param>
/// <param name="structDeclarationSymbol">The <see cref="INamedTypeSymbol"/> for the shader type.</param>
/// <param name="shaderInterfaceType">The shader interface type implemented by the shader type.</param>
/// <param name="inputCount">The number of inputs for the shader.</param>
/// <param name="inputSimpleIndices">The indicess of the simple shader inputs.</param>
/// <param name="inputComplexIndices">The indicess of the complex shader inputs.</param>
/// <param name="inputSimpleIndices">The indices of the simple shader inputs.</param>
/// <param name="inputComplexIndices">The indices of the complex shader inputs.</param>
/// <param name="token">The <see cref="CancellationToken"/> used to cancel the operation, if needed.</param>
/// <returns>The HLSL source for the shader.</returns>
public static string GetHlslSource(
ImmutableArrayBuilder<DiagnosticInfo> diagnostics,
Compilation compilation,
INamedTypeSymbol structDeclarationSymbol,
INamedTypeSymbol shaderInterfaceType,
int inputCount,
ImmutableArray<int> inputSimpleIndices,
ImmutableArray<int> inputComplexIndices,
Expand All @@ -46,6 +48,8 @@ public static string GetHlslSource(
// Detect any invalid properties
HlslDefinitionsSyntaxProcessor.DetectAndReportInvalidPropertyDeclarations(diagnostics, structDeclarationSymbol);

token.ThrowIfCancellationRequested();

// We need to sets to track all discovered custom types and static methods
HashSet<INamedTypeSymbol> discoveredTypes = new(SymbolEqualityComparer.Default);
Dictionary<IMethodSymbol, MethodDeclarationSyntax> staticMethods = new(SymbolEqualityComparer.Default);
Expand All @@ -72,6 +76,7 @@ public static string GetHlslSource(
(string entryPoint, ImmutableArray<HlslMethod> processedMethods) = GetProcessedMethods(
diagnostics,
structDeclarationSymbol,
shaderInterfaceType,
semanticModelProvider,
discoveredTypes,
staticMethods,
Expand Down Expand Up @@ -302,6 +307,7 @@ private static ImmutableArray<HlslStaticField> GetStaticFields(
/// </summary>
/// <param name="diagnostics">The collection of produced <see cref="DiagnosticInfo"/> instances.</param>
/// <param name="structDeclarationSymbol">The type symbol for the shader type.</param>
/// <param name="shaderInterfaceType">The shader interface type implemented by the shader type.</param>
/// <param name="semanticModel">The <see cref="SemanticModelProvider"/> instance for the type to process.</param>
/// <param name="discoveredTypes">The collection of currently discovered types.</param>
/// <param name="staticMethods">The set of discovered and processed static methods.</param>
Expand All @@ -315,6 +321,7 @@ private static ImmutableArray<HlslStaticField> GetStaticFields(
private static (string EntryPoint, ImmutableArray<HlslMethod> Methods) GetProcessedMethods(
ImmutableArrayBuilder<DiagnosticInfo> diagnostics,
INamedTypeSymbol structDeclarationSymbol,
INamedTypeSymbol shaderInterfaceType,
SemanticModelProvider semanticModel,
ICollection<INamedTypeSymbol> discoveredTypes,
IDictionary<IMethodSymbol, MethodDeclarationSyntax> staticMethods,
Expand All @@ -327,6 +334,7 @@ private static (string EntryPoint, ImmutableArray<HlslMethod> Methods) GetProces
{
using ImmutableArrayBuilder<HlslMethod> methods = new();

IMethodSymbol entryPointInterfaceMethod = shaderInterfaceType.GetMethod("Execute")!;
string? entryPoint = null;

// By default, the scene position is not required. We will set this while
Expand All @@ -341,16 +349,17 @@ private static (string EntryPoint, ImmutableArray<HlslMethod> Methods) GetProces
continue;
}

// Ensure that we have accessible source information
if (!methodSymbol.TryGetSyntaxNode(token, out MethodDeclarationSyntax? methodDeclaration))
{
continue;
}

bool isShaderEntryPoint =
methodSymbol.Name == "Execute" &&
methodSymbol.ReturnType.HasFullyQualifiedMetadataName("ComputeSharp.Float4") &&
methodSymbol.TypeParameters.Length == 0 &&
methodSymbol.Parameters.Length == 0;
// Check whether the current method is the entry point (ie. it's implementing 'Execute').
// This is the same logic as in the DX12 generator for compute shaders and pixel shaders.
bool isShaderEntryPoint = SymbolEqualityComparer.Default.Equals(
structDeclarationSymbol.FindImplementationForInterfaceMember(entryPointInterfaceMethod),
methodSymbol);

// Except for the entry point, ignore explicit interface implementations
if (!isShaderEntryPoint && !methodSymbol.ExplicitInterfaceImplementations.IsDefaultOrEmpty)
Expand Down Expand Up @@ -460,8 +469,8 @@ private static bool GetD2DRequiresScenePositionInfo(INamedTypeSymbol structDecla
/// <param name="typeMethodDeclarations"><inheritdoc cref="HlslSourceSyntaxProcessor.WriteMethodDeclarations" path="/param[@name='typeMethodDeclarations']/node()"/></param>
/// <param name="executeMethod">The body of the entry point of the shader.</param>
/// <param name="inputCount">The number of shader inputs to declare.</param>
/// <param name="inputSimpleIndices">The indicess of the simple shader inputs.</param>
/// <param name="inputComplexIndices">The indicess of the complex shader inputs.</param>
/// <param name="inputSimpleIndices">The indices of the simple shader inputs.</param>
/// <param name="inputComplexIndices">The indices of the complex shader inputs.</param>
/// <param name="requiresScenePosition">Whether the shader requires the scene position.</param>
/// <returns>The series of statements to build the HLSL source to compile to execute the current shader.</returns>
private static string GetHlslSource(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,10 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
return default;
}

INamedTypeSymbol shaderInterfaceType = context.SemanticModel.Compilation.GetTypeByMetadataName("ComputeSharp.D2D1.ID2D1PixelShader")!;

// Check that the shader implements the ID2D1PixelShader interface
if (!typeSymbol.HasInterfaceWithType(context.SemanticModel.Compilation.GetTypeByMetadataName("ComputeSharp.D2D1.ID2D1PixelShader")!))
if (!typeSymbol.HasInterfaceWithType(shaderInterfaceType))
{
return default;
}
Expand Down Expand Up @@ -142,6 +144,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
diagnostics,
context.SemanticModel.Compilation,
typeSymbol,
shaderInterfaceType,
inputCount,
inputSimpleIndices,
inputComplexIndices,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,14 +154,23 @@ internal sealed partial class ShaderSourceRewriter(
Diagnostics.Add(UnsafeModifierOnMethodOrFunction, node);
}

if (updatedNode is not null)
// Add the tracked implicit declarations (at the start of the body).
// To optimize, we only do this if we do have any implicit variables.
if (this.implicitVariables.Count > 0)
{
BlockSyntax implicitBlock = Block(this.implicitVariables.Select(static v => LocalDeclarationStatement(v)).ToArray());

// Add the tracked implicit declarations (at the start of the body)
updatedNode = updatedNode.WithBody(implicitBlock).AddBodyStatements([.. updatedNode.Body!.Statements]);
}

// The entry point might be an explicit interface method implementation. In that case,
// the transpiled method will have the rewritten interface name as a prefix for the
// method name, which we don't want (it's invalid HLSL). So in that case, remove it.
if (this.isEntryPoint && updatedNode.ExplicitInterfaceSpecifier is not null)
{
updatedNode = updatedNode.WithExplicitInterfaceSpecifier(null);
}

return updatedNode;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,26 @@ namespace ComputeSharp.SourceGeneration.Extensions;
/// </summary>
internal static class ITypeSymbolExtensions
{
/// <summary>
/// Gets the method of this symbol that have a particular name.
/// </summary>
/// <param name="symbol">The input <see cref="ITypeSymbol"/> instance to check.</param>
/// <param name="name">The name of the method to find.</param>
/// <returns>The target method, if present.</returns>
public static IMethodSymbol? GetMethod(this ITypeSymbol symbol, string name)
{
foreach (ISymbol memberSymbol in symbol.GetMembers(name))
{
if (memberSymbol is IMethodSymbol methodSymbol &&
memberSymbol.Name == name)
{
return methodSymbol;
}
}

return null;
}

/// <summary>
/// Checks whether or not a given type symbol has a specified fully qualified metadata name.
/// </summary>
Expand All @@ -28,7 +48,7 @@ public static bool HasFullyQualifiedMetadataName(this ITypeSymbol symbol, string
/// Checks whether or not a given <see cref="ITypeSymbol"/> implements an interface of a specified type.
/// </summary>
/// <param name="typeSymbol">The target <see cref="ITypeSymbol"/> instance to check.</param>
/// <param name="interfaceSymbol">The <see cref="ITypeSymbol"/> instane to check for inheritance from.</param>
/// <param name="interfaceSymbol">The <see cref="ITypeSymbol"/> instance to check for inheritance from.</param>
/// <returns>Whether or not <paramref name="typeSymbol"/> has an interface of type <paramref name="interfaceSymbol"/>.</returns>
public static bool HasInterfaceWithType(this ITypeSymbol typeSymbol, ITypeSymbol interfaceSymbol)
{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using System.Diagnostics.CodeAnalysis;
using Microsoft.CodeAnalysis;

namespace ComputeSharp.SourceGenerators;
Expand All @@ -10,9 +11,14 @@ partial class ComputeShaderDescriptorGenerator
/// </summary>
/// <param name="typeSymbol">The input <see cref="INamedTypeSymbol"/> instance to check.</param>
/// <param name="compilation">The <see cref="Compilation"/> instance currently in use.</param>
/// <param name="shaderInterfaceType">The (constructed) shader interface type implemented by the shader type.</param>
/// <param name="isPixelShaderLike">Whether <paramref name="typeSymbol"/> is a "pixel shader like" type.</param>
/// <returns>Whether <paramref name="typeSymbol"/> is a compute shader type at all.</returns>
private static bool TryGetIsPixelShaderLike(INamedTypeSymbol typeSymbol, Compilation compilation, out bool isPixelShaderLike)
private static bool TryGetIsPixelShaderLike(
INamedTypeSymbol typeSymbol,
Compilation compilation,
[NotNullWhen(true)] out INamedTypeSymbol? shaderInterfaceType,
out bool isPixelShaderLike)
{
INamedTypeSymbol computeShaderSymbol = compilation.GetTypeByMetadataName("ComputeSharp.IComputeShader")!;
INamedTypeSymbol pixelShaderSymbol = compilation.GetTypeByMetadataName("ComputeSharp.IComputeShader`1")!;
Expand All @@ -21,18 +27,21 @@ private static bool TryGetIsPixelShaderLike(INamedTypeSymbol typeSymbol, Compila
{
if (SymbolEqualityComparer.Default.Equals(interfaceSymbol, computeShaderSymbol))
{
shaderInterfaceType = interfaceSymbol;
isPixelShaderLike = false;

return true;
}
else if (SymbolEqualityComparer.Default.Equals(interfaceSymbol.ConstructedFrom, pixelShaderSymbol))
{
shaderInterfaceType = interfaceSymbol;
isPixelShaderLike = true;

return true;
}
}

shaderInterfaceType = null;
isPixelShaderLike = false;

return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ internal static partial class HlslSource
/// <param name="diagnostics">The collection of produced <see cref="DiagnosticInfo"/> instances.</param>
/// <param name="compilation">The input <see cref="Compilation"/> object currently in use.</param>
/// <param name="structDeclarationSymbol">The <see cref="INamedTypeSymbol"/> for the shader type.</param>
/// <param name="shaderInterfaceType">The shader interface type implemented by the shader type.</param>
/// <param name="isPixelShaderLike">Whether <paramref name="structDeclarationSymbol"/> is a "pixel shader like" type.</param>
/// <param name="threadsX">The thread ids value for the X axis.</param>
/// <param name="threadsY">The thread ids value for the Y axis.</param>
/// <param name="threadsZ">The thread ids value for the Z axis.</param>
Expand All @@ -42,6 +44,8 @@ public static void GetInfo(
ImmutableArrayBuilder<DiagnosticInfo> diagnostics,
Compilation compilation,
INamedTypeSymbol structDeclarationSymbol,
INamedTypeSymbol shaderInterfaceType,
bool isPixelShaderLike,
int threadsX,
int threadsY,
int threadsZ,
Expand All @@ -53,6 +57,8 @@ public static void GetInfo(
// Detect any invalid properties
HlslDefinitionsSyntaxProcessor.DetectAndReportInvalidPropertyDeclarations(diagnostics, structDeclarationSymbol);

token.ThrowIfCancellationRequested();

// We need to sets to track all discovered custom types and static methods
HashSet<INamedTypeSymbol> discoveredTypes = new(SymbolEqualityComparer.Default);
Dictionary<IMethodSymbol, MethodDeclarationSyntax> staticMethods = new(SymbolEqualityComparer.Default);
Expand All @@ -62,9 +68,8 @@ public static void GetInfo(
Dictionary<IFieldSymbol, HlslStaticField> staticFieldDefinitions = new(SymbolEqualityComparer.Default);

// Setup the semantic model and basic properties
INamedTypeSymbol? pixelShaderSymbol = structDeclarationSymbol.AllInterfaces.FirstOrDefault(static interfaceSymbol => interfaceSymbol is { IsGenericType: true, Name: "IComputeShader" });
bool isComputeShader = pixelShaderSymbol is null;
string? implicitTextureType = isComputeShader ? null : HlslKnownTypes.GetMappedNameForPixelShaderType(pixelShaderSymbol!);
bool isComputeShader = !isPixelShaderLike;
string? implicitTextureType = HlslKnownTypes.GetMappedNameForPixelShaderType(shaderInterfaceType);

token.ThrowIfCancellationRequested();

Expand All @@ -90,6 +95,7 @@ public static void GetInfo(
(string entryPoint, ImmutableArray<HlslMethod> processedMethods, isSamplerUsed) = GetProcessedMethods(
diagnostics,
structDeclarationSymbol,
shaderInterfaceType,
semanticModelProvider,
discoveredTypes,
staticMethods,
Expand Down Expand Up @@ -360,6 +366,7 @@ private static ImmutableArray<HlslSharedBuffer> GetSharedBuffers(
/// </summary>
/// <param name="diagnostics">The collection of produced <see cref="DiagnosticInfo"/> instances.</param>
/// <param name="structDeclarationSymbol">The type symbol for the shader type.</param>
/// <param name="shaderInterfaceType">The shader interface type implemented by the shader type.</param>
/// <param name="semanticModel">The <see cref="SemanticModelProvider"/> instance for the type to process.</param>
/// <param name="discoveredTypes">The collection of currently discovered types.</param>
/// <param name="staticMethods">The set of discovered and processed static methods.</param>
Expand All @@ -373,6 +380,7 @@ private static ImmutableArray<HlslSharedBuffer> GetSharedBuffers(
private static (string EntryPoint, ImmutableArray<HlslMethod> Methods, bool IsSamplerUser) GetProcessedMethods(
ImmutableArrayBuilder<DiagnosticInfo> diagnostics,
INamedTypeSymbol structDeclarationSymbol,
INamedTypeSymbol shaderInterfaceType,
SemanticModelProvider semanticModel,
ICollection<INamedTypeSymbol> discoveredTypes,
IDictionary<IMethodSymbol, MethodDeclarationSyntax> staticMethods,
Expand All @@ -385,6 +393,7 @@ private static (string EntryPoint, ImmutableArray<HlslMethod> Methods, bool IsSa
{
using ImmutableArrayBuilder<HlslMethod> methods = new();

IMethodSymbol entryPointInterfaceMethod = shaderInterfaceType.GetMethod("Execute")!;
string? entryPoint = null;
bool isSamplerUsed = false;

Expand All @@ -396,22 +405,17 @@ private static (string EntryPoint, ImmutableArray<HlslMethod> Methods, bool IsSa
continue;
}

// Ensure that we have accessible source information
if (!methodSymbol.TryGetSyntaxNode(token, out MethodDeclarationSyntax? methodDeclaration))
{
continue;
}

bool isShaderEntryPoint =
(isComputeShader &&
methodSymbol.Name == "Execute" &&
methodSymbol.ReturnsVoid &&
methodSymbol.TypeParameters.Length == 0 &&
methodSymbol.Parameters.Length == 0) ||
(!isComputeShader &&
methodSymbol.Name == "Execute" &&
methodSymbol.ReturnType is not null && // TODO: match for pixel type
methodSymbol.TypeParameters.Length == 0 &&
methodSymbol.Parameters.Length == 0);
// Check whether the current method is the entry point (ie. it's implementing 'Execute'). We use
// 'FindImplementationForInterfaceMember' to handle explicit interface implementations as well.
bool isShaderEntryPoint = SymbolEqualityComparer.Default.Equals(
structDeclarationSymbol.FindImplementationForInterfaceMember(entryPointInterfaceMethod),
methodSymbol);

// Except for the entry point, ignore explicit interface implementations
if (!isShaderEntryPoint && !methodSymbol.ExplicitInterfaceImplementations.IsDefaultOrEmpty)
Expand Down
Loading

0 comments on commit f077308

Please sign in to comment.