diff --git a/src/ComputeSharp.D2D1.SourceGenerators/ID2D1ShaderGenerator.CreateLoadBytecodeMethod.cs b/src/ComputeSharp.D2D1.SourceGenerators/ID2D1ShaderGenerator.CreateLoadBytecodeMethod.cs
index cb8a39463..876be56c1 100644
--- a/src/ComputeSharp.D2D1.SourceGenerators/ID2D1ShaderGenerator.CreateLoadBytecodeMethod.cs
+++ b/src/ComputeSharp.D2D1.SourceGenerators/ID2D1ShaderGenerator.CreateLoadBytecodeMethod.cs
@@ -25,6 +25,11 @@ partial class ID2D1ShaderGenerator
///
internal static partial class LoadBytecode
{
+ ///
+ /// The shared cache of values.
+ ///
+ private static readonly DynamicCache HlslBytecodeCache = new();
+
///
/// Extracts the requested shader profile for the current shader.
///
@@ -141,46 +146,54 @@ public static bool IsSimpleInputShader(INamedTypeSymbol structDeclarationSymbol,
/// The instance for the shader to compile.
/// The used to cancel the operation, if needed.
///
- public static unsafe HlslBytecodeInfo GetInfo(HlslBytecodeInfoKey key, CancellationToken token)
+ public static HlslBytecodeInfo GetInfo(ref HlslBytecodeInfoKey key, CancellationToken token)
{
- // No embedded shader was requested, or there were some errors earlier in the pipeline.
- // In this case, skip the compilation, as diagnostic will be emitted for those anyway.
- // Compiling would just add overhead and result in more errors, as the HLSL would be invalid.
- // We also skip compilation if no shader profile has been requested (we never just assume one).
- if (key.HasErrors || key.RequestedShaderProfile is null)
+ static unsafe HlslBytecodeInfo GetInfo(HlslBytecodeInfoKey key, CancellationToken token)
{
- return HlslBytecodeInfo.Missing.Instance;
- }
+ // No embedded shader was requested, or there were some errors earlier in the pipeline.
+ // In this case, skip the compilation, as diagnostic will be emitted for those anyway.
+ // Compiling would just add overhead and result in more errors, as the HLSL would be invalid.
+ // We also skip compilation if no shader profile has been requested (we never just assume one).
+ if (key.HasErrors || key.RequestedShaderProfile is null)
+ {
+ return HlslBytecodeInfo.Missing.Instance;
+ }
- try
- {
- token.ThrowIfCancellationRequested();
+ try
+ {
+ token.ThrowIfCancellationRequested();
- // Compile the shader bytecode using the effective parameters
- using ComPtr dxcBlobBytecode = D3DCompiler.Compile(
- key.HlslSource.AsSpan(),
- key.EffectiveShaderProfile,
- key.EffectiveCompileOptions);
+ // Compile the shader bytecode using the effective parameters
+ using ComPtr dxcBlobBytecode = D3DCompiler.Compile(
+ key.HlslSource.AsSpan(),
+ key.EffectiveShaderProfile,
+ key.EffectiveCompileOptions);
- token.ThrowIfCancellationRequested();
+ token.ThrowIfCancellationRequested();
- byte* buffer = (byte*)dxcBlobBytecode.Get()->GetBufferPointer();
- int length = checked((int)dxcBlobBytecode.Get()->GetBufferSize());
+ byte* buffer = (byte*)dxcBlobBytecode.Get()->GetBufferPointer();
+ int length = checked((int)dxcBlobBytecode.Get()->GetBufferSize());
- byte[] array = new ReadOnlySpan(buffer, length).ToArray();
+ byte[] array = new ReadOnlySpan(buffer, length).ToArray();
- ImmutableArray bytecode = Unsafe.As>(ref array);
+ ImmutableArray bytecode = Unsafe.As>(ref array);
- return new HlslBytecodeInfo.Success(bytecode);
- }
- catch (Win32Exception e)
- {
- return new HlslBytecodeInfo.Win32Error(e.NativeErrorCode, D3DCompiler.PrettifyFxcErrorMessage(e.Message));
- }
- catch (FxcCompilationException e)
- {
- return new HlslBytecodeInfo.FxcError(D3DCompiler.PrettifyFxcErrorMessage(e.Message));
+ return new HlslBytecodeInfo.Success(bytecode);
+ }
+ catch (Win32Exception e)
+ {
+ return new HlslBytecodeInfo.Win32Error(e.NativeErrorCode, D3DCompiler.PrettifyFxcErrorMessage(e.Message));
+ }
+ catch (FxcCompilationException e)
+ {
+ return new HlslBytecodeInfo.FxcError(D3DCompiler.PrettifyFxcErrorMessage(e.Message));
+ }
}
+
+ // Get or create the HLSL bytecode compilation result for the input key. The dynamic cache
+ // will take care of retrieving an existing cached value if the same shader has been compiled
+ // already with the same parameters. After this call, callers must use the updated key value.
+ return HlslBytecodeCache.GetOrCreate(ref key, GetInfo, token);
}
///
diff --git a/src/ComputeSharp.D2D1.SourceGenerators/ID2D1ShaderGenerator.cs b/src/ComputeSharp.D2D1.SourceGenerators/ID2D1ShaderGenerator.cs
index b3d85420b..044b2e32d 100644
--- a/src/ComputeSharp.D2D1.SourceGenerators/ID2D1ShaderGenerator.cs
+++ b/src/ComputeSharp.D2D1.SourceGenerators/ID2D1ShaderGenerator.cs
@@ -154,7 +154,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
hasErrors);
// TODO: cache this across transform runs
- HlslBytecodeInfo hlslInfo = LoadBytecode.GetInfo(hlslInfoKey, token);
+ HlslBytecodeInfo hlslInfo = LoadBytecode.GetInfo(ref hlslInfoKey, token);
token.ThrowIfCancellationRequested();
diff --git a/src/ComputeSharp.SourceGeneration/ComputeSharp.SourceGeneration.projitems b/src/ComputeSharp.SourceGeneration/ComputeSharp.SourceGeneration.projitems
index 0e4410f29..289de1eaf 100644
--- a/src/ComputeSharp.SourceGeneration/ComputeSharp.SourceGeneration.projitems
+++ b/src/ComputeSharp.SourceGeneration/ComputeSharp.SourceGeneration.projitems
@@ -17,6 +17,7 @@
+
diff --git a/src/ComputeSharp.SourceGeneration/Helpers/DynamicCache{TKey,TValue}.cs b/src/ComputeSharp.SourceGeneration/Helpers/DynamicCache{TKey,TValue}.cs
new file mode 100644
index 000000000..787750893
--- /dev/null
+++ b/src/ComputeSharp.SourceGeneration/Helpers/DynamicCache{TKey,TValue}.cs
@@ -0,0 +1,312 @@
+using System;
+using System.Collections.Concurrent;
+using System.Collections.Generic;
+using System.Diagnostics.CodeAnalysis;
+using System.Runtime.CompilerServices;
+using System.Threading;
+
+namespace ComputeSharp.SourceGeneration.Helpers;
+
+///
+/// A dynamic cache that can be used to cache computed values within incremental models.
+/// This automatically trims excess items, and relies on the incremental state tables
+/// keeping values alive for at least one incremental step, in order to work correctly.
+///
+/// The type of keys to use for the cache.
+/// The type of values to store in the cache.
+public sealed class DynamicCache
+ where TKey : class
+{
+ ///
+ /// The backing instance for the cache.
+ ///
+ private readonly ConcurrentDictionary map = new();
+
+ ///
+ /// The tracking dead entries to remove.
+ ///
+ private readonly ConditionalWeakTable table = new();
+
+ ///
+ /// Gets or creates a new value for a given key, using a supplied callback if needed.
+ ///
+ /// The key to use as lookup.
+ /// The callback to use to create new values, if needed.
+ /// A cancellation token for the operation of creating a new value.
+ /// The resulting value.
+ /// Thrown if is canceled.
+ ///
+ /// This method might replace with a new instance that has the same
+ /// value according to its equality comparison logic. Callers should always use the last value of
+ /// after this method returns and discard the previous one, if different.
+ ///
+ public TValue GetOrCreate(ref TKey key, GetOrCreateCallback callback, CancellationToken cancellationToken)
+ {
+ cancellationToken.ThrowIfCancellationRequested();
+
+ // Create a new entry that we will use to perform the lookup.
+ // Each entry simply forwards equality logic to the wrapped object.
+ Entry entry = new(key);
+
+ while (true)
+ {
+ cancellationToken.ThrowIfCancellationRequested();
+
+ // We're performing a lookup on this temporary entry. We need it to
+ // track the value it will potentially match against, so we can
+ // return it to the caller. This ensures the same object is used.
+ entry.SetIsPerformingLookup(true);
+
+ // Try to check whether we already have a value in the cache for
+ // this object. That is, we want to check whether there is an entry
+ // with a value equal to the one we have now (not necessarily the same
+ // object). If we find it, we return it and throw away the new entry.
+ if (this.map.TryGetValue(entry, out TValue? value))
+ {
+ cancellationToken.ThrowIfCancellationRequested();
+
+ // We have a match, so replace the object with the one that actually matched.
+ // This guarantees that it will remain alive, so the entry will not die.
+ key = entry.GetLastMatchedValue();
+
+ return value;
+ }
+
+ cancellationToken.ThrowIfCancellationRequested();
+
+ // Execute the slow fallback path, invoking the callback and trying to add a new
+ // value to the cache. If this succeeds, it means that the current key object has
+ // been added to the cache, so there is nothing left to do. If this fails, it means
+ // another thread has beat us to adding the key, so we should perform the initial
+ // lookup again to make sure we can find the exact instance that is in the cache.
+ // This is needed to ensure valid cache entries remain alive over time.
+ if (TryGetOrCreate(entry, key, callback, cancellationToken, out value))
+ {
+ // If the operation has been canceled after inserting a new key and value,
+ // the value will just be wasted, but there isn't really anything we can do
+ // about it at this point. We also ignore the unnecessary remover finalizer.
+ cancellationToken.ThrowIfCancellationRequested();
+
+ return value;
+ }
+ }
+ }
+
+ ///
+ /// Tries to get or create a new value for a given key, using a supplied callback.
+ ///
+ /// The instance to try to insert into the cache.
+ /// The key to use as lookup.
+ /// The callback to use to create new values, if needed.
+ /// A cancellation token for the operation of creating a new value.
+ /// The resulting value (should be ignored if the method fails).
+ /// Whether was successfully inserted into the cache.
+ /// Thrown if is canceled.
+ [MethodImpl(MethodImplOptions.NoInlining)]
+ private bool TryGetOrCreate(
+ Entry entry,
+ TKey key,
+ GetOrCreateCallback callback,
+ CancellationToken cancellationToken,
+ out TValue value)
+ {
+ // No value is present, so we can create it now
+ value = callback(key, cancellationToken);
+
+ cancellationToken.ThrowIfCancellationRequested();
+
+ // We're about to try to add the item to the cache, so we no longer need to
+ // track the last matched value. We already have a reference to the object.
+ entry.SetIsPerformingLookup(false);
+
+ // Try to add the value (this only fails if someone raced with this thread).
+ // In this case, our key will also be added to the cache. Callers will have
+ // a strong reference to it, which ensures the weak reference remains alive.
+ if (!this.map.TryAdd(entry, value))
+ {
+ return false;
+ }
+
+ // We need to setup the removal of this key-value pair when the key is no longer
+ // referenced. To do this, we add a new EntryRemover instance to the table. When
+ // the key has no active references and is collected (which means the entry will
+ // also become invalid), the finalizer of the EntryRemover instance will run on
+ // a following GC, and remove that dead Entry instance from the map automatically.
+ this.table.Add(key, new EntryRemover(this.map, entry));
+
+ return true;
+ }
+
+ ///
+ /// A callback to create a new value from a given key.
+ ///
+ /// The resulting instance.
+ /// A cancellation token for the operation of creating a new value.
+ ///
+ public delegate TValue GetOrCreateCallback(TKey key, CancellationToken cancellationToken);
+
+ ///
+ /// An entry to use in .
+ ///
+ private sealed class Entry
+ {
+ ///
+ /// A weak reference to the actual entry instance.
+ ///
+ private readonly WeakReference reference;
+
+ ///
+ /// The hashcode of the target key (so it's available even after the key is gone).
+ ///
+ private readonly int hashCode;
+
+ ///
+ /// The last key matched from , if available.
+ ///
+ private TKey? lastMatchedKey;
+
+ ///
+ /// Indicates whether the entry is currently in lookup mode.
+ ///
+ private bool isPerformingLookup;
+
+ ///
+ /// Creates a new instance with the specified parameters.
+ ///
+ /// The key to use for the entry.
+ public Entry(TKey key)
+ {
+ this.reference = new WeakReference(key);
+ this.hashCode = EqualityComparer.Default.GetHashCode(key);
+ }
+
+ ///
+ /// Gets whether or not the current entry is alive.
+ ///
+ public bool IsAlive => this.reference.TryGetTarget(out _);
+
+ ///
+ /// Gets the last matched key retrieved during a lookup operation.
+ ///
+ /// The last matched key retrieved during a lookup operation.
+ /// Thrown if no key is available.
+ public TKey GetLastMatchedValue()
+ {
+ TKey? lastMatchedKey = this.lastMatchedKey;
+
+ if (lastMatchedKey is null)
+ {
+ EntryHelper.ThrowInvalidOperationExceptionForLastMatchedKey();
+ }
+
+ return lastMatchedKey;
+ }
+
+ ///
+ /// Sets whether the current instance is performing a lookup.
+ ///
+ /// The new value for the configuration.
+ public void SetIsPerformingLookup(bool value)
+ {
+ this.lastMatchedKey = null;
+ this.isPerformingLookup = value;
+ }
+
+ ///
+ public override bool Equals(object? obj)
+ {
+ if (obj is not Entry entry)
+ {
+ return false;
+ }
+
+ // Special case matching on the entry identity directly. This is used
+ // by the finalizer of EntryRemover to find the entry to remove.
+ if (this == entry)
+ {
+ return true;
+ }
+
+ _ = this.reference.TryGetTarget(out TKey? left);
+ _ = entry.reference.TryGetTarget(out TKey? right);
+
+ bool isMatch = EqualityComparer.Default.Equals(left, right);
+
+ // If we have a match and we're in lookup mode, store the last item. Note that the dictionary
+ // does not guarantee the order of arguments to Equals calls, so we cannot rely on the Entry
+ // object being used for lookup the one whose Equals method is being called. So if there is
+ // a match, we check which of the two input Entry instances is the one being used for lookups,
+ // and set the last match key object on that one. If there is no match, there is no need to
+ // clear the last key, as even in case the entry ends up being inserted into the map in the
+ // falback path, before doing so SetIsPerformingLookup(false) will be called, which will clear
+ // any previous matches. So there is already no way for the key to be accidentally rooted here.
+ if (isMatch)
+ {
+ if (this.isPerformingLookup)
+ {
+ this.lastMatchedKey = right;
+ }
+ else if (entry.isPerformingLookup)
+ {
+ entry.lastMatchedKey = left;
+ }
+ }
+
+ return isMatch;
+ }
+
+ ///
+ public override int GetHashCode()
+ {
+ return this.hashCode;
+ }
+ }
+
+ ///
+ /// An object reponsible for removing dead instance from the table.
+ ///
+ private sealed class EntryRemover
+ {
+ ///
+ private readonly ConcurrentDictionary map;
+
+ ///
+ /// The target instance to remove from .
+ ///
+ private readonly Entry entry;
+
+ ///
+ /// Creates a new instance with the specified parameters.
+ ///
+ ///
+ ///
+ public EntryRemover(ConcurrentDictionary map, Entry entry)
+ {
+ this.map = map;
+ this.entry = entry;
+ }
+
+ ///
+ /// Removes from when the current instance is finalized.
+ ///
+ ~EntryRemover()
+ {
+ _ = this.map.TryRemove(this.entry, out _);
+ }
+ }
+}
+
+///
+/// Private helpers for the type.
+///
+file static class EntryHelper
+{
+ ///
+ /// Throws an when there is no last matching key.
+ ///
+ [DoesNotReturn]
+ public static void ThrowInvalidOperationExceptionForLastMatchedKey()
+ {
+ throw new InvalidOperationException("No last matching key has been found.");
+ }
+}
\ No newline at end of file