diff --git a/src/ComputeSharp.D2D1.SourceGenerators/ID2D1ShaderGenerator.CreateLoadBytecodeMethod.cs b/src/ComputeSharp.D2D1.SourceGenerators/ID2D1ShaderGenerator.CreateLoadBytecodeMethod.cs index cb8a39463..876be56c1 100644 --- a/src/ComputeSharp.D2D1.SourceGenerators/ID2D1ShaderGenerator.CreateLoadBytecodeMethod.cs +++ b/src/ComputeSharp.D2D1.SourceGenerators/ID2D1ShaderGenerator.CreateLoadBytecodeMethod.cs @@ -25,6 +25,11 @@ partial class ID2D1ShaderGenerator /// internal static partial class LoadBytecode { + /// + /// The shared cache of values. + /// + private static readonly DynamicCache HlslBytecodeCache = new(); + /// /// Extracts the requested shader profile for the current shader. /// @@ -141,46 +146,54 @@ public static bool IsSimpleInputShader(INamedTypeSymbol structDeclarationSymbol, /// The instance for the shader to compile. /// The used to cancel the operation, if needed. /// - public static unsafe HlslBytecodeInfo GetInfo(HlslBytecodeInfoKey key, CancellationToken token) + public static HlslBytecodeInfo GetInfo(ref HlslBytecodeInfoKey key, CancellationToken token) { - // No embedded shader was requested, or there were some errors earlier in the pipeline. - // In this case, skip the compilation, as diagnostic will be emitted for those anyway. - // Compiling would just add overhead and result in more errors, as the HLSL would be invalid. - // We also skip compilation if no shader profile has been requested (we never just assume one). - if (key.HasErrors || key.RequestedShaderProfile is null) + static unsafe HlslBytecodeInfo GetInfo(HlslBytecodeInfoKey key, CancellationToken token) { - return HlslBytecodeInfo.Missing.Instance; - } + // No embedded shader was requested, or there were some errors earlier in the pipeline. + // In this case, skip the compilation, as diagnostic will be emitted for those anyway. + // Compiling would just add overhead and result in more errors, as the HLSL would be invalid. + // We also skip compilation if no shader profile has been requested (we never just assume one). + if (key.HasErrors || key.RequestedShaderProfile is null) + { + return HlslBytecodeInfo.Missing.Instance; + } - try - { - token.ThrowIfCancellationRequested(); + try + { + token.ThrowIfCancellationRequested(); - // Compile the shader bytecode using the effective parameters - using ComPtr dxcBlobBytecode = D3DCompiler.Compile( - key.HlslSource.AsSpan(), - key.EffectiveShaderProfile, - key.EffectiveCompileOptions); + // Compile the shader bytecode using the effective parameters + using ComPtr dxcBlobBytecode = D3DCompiler.Compile( + key.HlslSource.AsSpan(), + key.EffectiveShaderProfile, + key.EffectiveCompileOptions); - token.ThrowIfCancellationRequested(); + token.ThrowIfCancellationRequested(); - byte* buffer = (byte*)dxcBlobBytecode.Get()->GetBufferPointer(); - int length = checked((int)dxcBlobBytecode.Get()->GetBufferSize()); + byte* buffer = (byte*)dxcBlobBytecode.Get()->GetBufferPointer(); + int length = checked((int)dxcBlobBytecode.Get()->GetBufferSize()); - byte[] array = new ReadOnlySpan(buffer, length).ToArray(); + byte[] array = new ReadOnlySpan(buffer, length).ToArray(); - ImmutableArray bytecode = Unsafe.As>(ref array); + ImmutableArray bytecode = Unsafe.As>(ref array); - return new HlslBytecodeInfo.Success(bytecode); - } - catch (Win32Exception e) - { - return new HlslBytecodeInfo.Win32Error(e.NativeErrorCode, D3DCompiler.PrettifyFxcErrorMessage(e.Message)); - } - catch (FxcCompilationException e) - { - return new HlslBytecodeInfo.FxcError(D3DCompiler.PrettifyFxcErrorMessage(e.Message)); + return new HlslBytecodeInfo.Success(bytecode); + } + catch (Win32Exception e) + { + return new HlslBytecodeInfo.Win32Error(e.NativeErrorCode, D3DCompiler.PrettifyFxcErrorMessage(e.Message)); + } + catch (FxcCompilationException e) + { + return new HlslBytecodeInfo.FxcError(D3DCompiler.PrettifyFxcErrorMessage(e.Message)); + } } + + // Get or create the HLSL bytecode compilation result for the input key. The dynamic cache + // will take care of retrieving an existing cached value if the same shader has been compiled + // already with the same parameters. After this call, callers must use the updated key value. + return HlslBytecodeCache.GetOrCreate(ref key, GetInfo, token); } /// diff --git a/src/ComputeSharp.D2D1.SourceGenerators/ID2D1ShaderGenerator.cs b/src/ComputeSharp.D2D1.SourceGenerators/ID2D1ShaderGenerator.cs index b3d85420b..044b2e32d 100644 --- a/src/ComputeSharp.D2D1.SourceGenerators/ID2D1ShaderGenerator.cs +++ b/src/ComputeSharp.D2D1.SourceGenerators/ID2D1ShaderGenerator.cs @@ -154,7 +154,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context) hasErrors); // TODO: cache this across transform runs - HlslBytecodeInfo hlslInfo = LoadBytecode.GetInfo(hlslInfoKey, token); + HlslBytecodeInfo hlslInfo = LoadBytecode.GetInfo(ref hlslInfoKey, token); token.ThrowIfCancellationRequested(); diff --git a/src/ComputeSharp.SourceGeneration/ComputeSharp.SourceGeneration.projitems b/src/ComputeSharp.SourceGeneration/ComputeSharp.SourceGeneration.projitems index 0e4410f29..289de1eaf 100644 --- a/src/ComputeSharp.SourceGeneration/ComputeSharp.SourceGeneration.projitems +++ b/src/ComputeSharp.SourceGeneration/ComputeSharp.SourceGeneration.projitems @@ -17,6 +17,7 @@ + diff --git a/src/ComputeSharp.SourceGeneration/Helpers/DynamicCache{TKey,TValue}.cs b/src/ComputeSharp.SourceGeneration/Helpers/DynamicCache{TKey,TValue}.cs new file mode 100644 index 000000000..787750893 --- /dev/null +++ b/src/ComputeSharp.SourceGeneration/Helpers/DynamicCache{TKey,TValue}.cs @@ -0,0 +1,312 @@ +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; +using System.Threading; + +namespace ComputeSharp.SourceGeneration.Helpers; + +/// +/// A dynamic cache that can be used to cache computed values within incremental models. +/// This automatically trims excess items, and relies on the incremental state tables +/// keeping values alive for at least one incremental step, in order to work correctly. +/// +/// The type of keys to use for the cache. +/// The type of values to store in the cache. +public sealed class DynamicCache + where TKey : class +{ + /// + /// The backing instance for the cache. + /// + private readonly ConcurrentDictionary map = new(); + + /// + /// The tracking dead entries to remove. + /// + private readonly ConditionalWeakTable table = new(); + + /// + /// Gets or creates a new value for a given key, using a supplied callback if needed. + /// + /// The key to use as lookup. + /// The callback to use to create new values, if needed. + /// A cancellation token for the operation of creating a new value. + /// The resulting value. + /// Thrown if is canceled. + /// + /// This method might replace with a new instance that has the same + /// value according to its equality comparison logic. Callers should always use the last value of + /// after this method returns and discard the previous one, if different. + /// + public TValue GetOrCreate(ref TKey key, GetOrCreateCallback callback, CancellationToken cancellationToken) + { + cancellationToken.ThrowIfCancellationRequested(); + + // Create a new entry that we will use to perform the lookup. + // Each entry simply forwards equality logic to the wrapped object. + Entry entry = new(key); + + while (true) + { + cancellationToken.ThrowIfCancellationRequested(); + + // We're performing a lookup on this temporary entry. We need it to + // track the value it will potentially match against, so we can + // return it to the caller. This ensures the same object is used. + entry.SetIsPerformingLookup(true); + + // Try to check whether we already have a value in the cache for + // this object. That is, we want to check whether there is an entry + // with a value equal to the one we have now (not necessarily the same + // object). If we find it, we return it and throw away the new entry. + if (this.map.TryGetValue(entry, out TValue? value)) + { + cancellationToken.ThrowIfCancellationRequested(); + + // We have a match, so replace the object with the one that actually matched. + // This guarantees that it will remain alive, so the entry will not die. + key = entry.GetLastMatchedValue(); + + return value; + } + + cancellationToken.ThrowIfCancellationRequested(); + + // Execute the slow fallback path, invoking the callback and trying to add a new + // value to the cache. If this succeeds, it means that the current key object has + // been added to the cache, so there is nothing left to do. If this fails, it means + // another thread has beat us to adding the key, so we should perform the initial + // lookup again to make sure we can find the exact instance that is in the cache. + // This is needed to ensure valid cache entries remain alive over time. + if (TryGetOrCreate(entry, key, callback, cancellationToken, out value)) + { + // If the operation has been canceled after inserting a new key and value, + // the value will just be wasted, but there isn't really anything we can do + // about it at this point. We also ignore the unnecessary remover finalizer. + cancellationToken.ThrowIfCancellationRequested(); + + return value; + } + } + } + + /// + /// Tries to get or create a new value for a given key, using a supplied callback. + /// + /// The instance to try to insert into the cache. + /// The key to use as lookup. + /// The callback to use to create new values, if needed. + /// A cancellation token for the operation of creating a new value. + /// The resulting value (should be ignored if the method fails). + /// Whether was successfully inserted into the cache. + /// Thrown if is canceled. + [MethodImpl(MethodImplOptions.NoInlining)] + private bool TryGetOrCreate( + Entry entry, + TKey key, + GetOrCreateCallback callback, + CancellationToken cancellationToken, + out TValue value) + { + // No value is present, so we can create it now + value = callback(key, cancellationToken); + + cancellationToken.ThrowIfCancellationRequested(); + + // We're about to try to add the item to the cache, so we no longer need to + // track the last matched value. We already have a reference to the object. + entry.SetIsPerformingLookup(false); + + // Try to add the value (this only fails if someone raced with this thread). + // In this case, our key will also be added to the cache. Callers will have + // a strong reference to it, which ensures the weak reference remains alive. + if (!this.map.TryAdd(entry, value)) + { + return false; + } + + // We need to setup the removal of this key-value pair when the key is no longer + // referenced. To do this, we add a new EntryRemover instance to the table. When + // the key has no active references and is collected (which means the entry will + // also become invalid), the finalizer of the EntryRemover instance will run on + // a following GC, and remove that dead Entry instance from the map automatically. + this.table.Add(key, new EntryRemover(this.map, entry)); + + return true; + } + + /// + /// A callback to create a new value from a given key. + /// + /// The resulting instance. + /// A cancellation token for the operation of creating a new value. + /// + public delegate TValue GetOrCreateCallback(TKey key, CancellationToken cancellationToken); + + /// + /// An entry to use in . + /// + private sealed class Entry + { + /// + /// A weak reference to the actual entry instance. + /// + private readonly WeakReference reference; + + /// + /// The hashcode of the target key (so it's available even after the key is gone). + /// + private readonly int hashCode; + + /// + /// The last key matched from , if available. + /// + private TKey? lastMatchedKey; + + /// + /// Indicates whether the entry is currently in lookup mode. + /// + private bool isPerformingLookup; + + /// + /// Creates a new instance with the specified parameters. + /// + /// The key to use for the entry. + public Entry(TKey key) + { + this.reference = new WeakReference(key); + this.hashCode = EqualityComparer.Default.GetHashCode(key); + } + + /// + /// Gets whether or not the current entry is alive. + /// + public bool IsAlive => this.reference.TryGetTarget(out _); + + /// + /// Gets the last matched key retrieved during a lookup operation. + /// + /// The last matched key retrieved during a lookup operation. + /// Thrown if no key is available. + public TKey GetLastMatchedValue() + { + TKey? lastMatchedKey = this.lastMatchedKey; + + if (lastMatchedKey is null) + { + EntryHelper.ThrowInvalidOperationExceptionForLastMatchedKey(); + } + + return lastMatchedKey; + } + + /// + /// Sets whether the current instance is performing a lookup. + /// + /// The new value for the configuration. + public void SetIsPerformingLookup(bool value) + { + this.lastMatchedKey = null; + this.isPerformingLookup = value; + } + + /// + public override bool Equals(object? obj) + { + if (obj is not Entry entry) + { + return false; + } + + // Special case matching on the entry identity directly. This is used + // by the finalizer of EntryRemover to find the entry to remove. + if (this == entry) + { + return true; + } + + _ = this.reference.TryGetTarget(out TKey? left); + _ = entry.reference.TryGetTarget(out TKey? right); + + bool isMatch = EqualityComparer.Default.Equals(left, right); + + // If we have a match and we're in lookup mode, store the last item. Note that the dictionary + // does not guarantee the order of arguments to Equals calls, so we cannot rely on the Entry + // object being used for lookup the one whose Equals method is being called. So if there is + // a match, we check which of the two input Entry instances is the one being used for lookups, + // and set the last match key object on that one. If there is no match, there is no need to + // clear the last key, as even in case the entry ends up being inserted into the map in the + // falback path, before doing so SetIsPerformingLookup(false) will be called, which will clear + // any previous matches. So there is already no way for the key to be accidentally rooted here. + if (isMatch) + { + if (this.isPerformingLookup) + { + this.lastMatchedKey = right; + } + else if (entry.isPerformingLookup) + { + entry.lastMatchedKey = left; + } + } + + return isMatch; + } + + /// + public override int GetHashCode() + { + return this.hashCode; + } + } + + /// + /// An object reponsible for removing dead instance from the table. + /// + private sealed class EntryRemover + { + /// + private readonly ConcurrentDictionary map; + + /// + /// The target instance to remove from . + /// + private readonly Entry entry; + + /// + /// Creates a new instance with the specified parameters. + /// + /// + /// + public EntryRemover(ConcurrentDictionary map, Entry entry) + { + this.map = map; + this.entry = entry; + } + + /// + /// Removes from when the current instance is finalized. + /// + ~EntryRemover() + { + _ = this.map.TryRemove(this.entry, out _); + } + } +} + +/// +/// Private helpers for the type. +/// +file static class EntryHelper +{ + /// + /// Throws an when there is no last matching key. + /// + [DoesNotReturn] + public static void ThrowInvalidOperationExceptionForLastMatchedKey() + { + throw new InvalidOperationException("No last matching key has been found."); + } +} \ No newline at end of file