From 86fa77e77e9f317e098a66f2862463d59a5f0e08 Mon Sep 17 00:00:00 2001 From: Sergio Pedri Date: Tue, 26 Sep 2023 16:13:31 +0200 Subject: [PATCH 1/5] Add DynamicCache type --- .../ComputeSharp.SourceGeneration.projitems | 1 + .../Helpers/DynamicCache{TKey,TValue}.cs | 234 ++++++++++++++++++ 2 files changed, 235 insertions(+) create mode 100644 src/ComputeSharp.SourceGeneration/Helpers/DynamicCache{TKey,TValue}.cs diff --git a/src/ComputeSharp.SourceGeneration/ComputeSharp.SourceGeneration.projitems b/src/ComputeSharp.SourceGeneration/ComputeSharp.SourceGeneration.projitems index 0e4410f29..289de1eaf 100644 --- a/src/ComputeSharp.SourceGeneration/ComputeSharp.SourceGeneration.projitems +++ b/src/ComputeSharp.SourceGeneration/ComputeSharp.SourceGeneration.projitems @@ -17,6 +17,7 @@ + diff --git a/src/ComputeSharp.SourceGeneration/Helpers/DynamicCache{TKey,TValue}.cs b/src/ComputeSharp.SourceGeneration/Helpers/DynamicCache{TKey,TValue}.cs new file mode 100644 index 000000000..4e05ecebf --- /dev/null +++ b/src/ComputeSharp.SourceGeneration/Helpers/DynamicCache{TKey,TValue}.cs @@ -0,0 +1,234 @@ +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; + +namespace ComputeSharp.SourceGeneration.Helpers; + +/// +/// A dynamic cache that can be used to cache computed values within incremental models. +/// This automatically trims excess items, and relies on the incremental state tables +/// keeping values alive for at least one incremental step, in order to work correctly. +/// +/// The type of keys to use for the cache. +/// The type of values to store in the cache. +public sealed class DynamicCache + where TKey : class +{ + /// + /// The backing instance for the cache. + /// + private readonly ConcurrentDictionary map = new(); + + /// + /// Gets or creates a new value for a given key, using a supplied callback if needed. + /// + /// The key to use as lookup. + /// The callback to use to create new values, if needed. + /// The resulting value. + /// + /// This method might replace with a new instance that has the same + /// value according to its equality comparison logic. Callers should always use the last value of + /// after this method returns and discard the previous one, if different. + /// + public TValue GetOrCreate(ref TKey key, GetOrCreateCallback callback) + { + // Create a new entry that we will use to perform the lookup. + // Each entry simply forwards equality logic to the wrapped object. + Entry entry = new(key); + + while (true) + { + // We're performing a lookup on this temporary entry. We need it to + // track the value it will potentially match against, so we can + // return it to the caller. This ensures the same object is used. + entry.SetIsPerformingLookup(true); + + // Try to check whether we already have a value in the cache for + // this object. That is, we want to check whether there is an entry + // with a value equal to the one we have now (not necessarily the same + // object). If we find it, we return it and throw away the new entry. + if (this.map.TryGetValue(entry, out TValue? value)) + { + // We have a match, so replace the object with the one that actually matched. + // This guarantees that it will remain alive, so the entry will not die. + key = entry.GetLastMatchedValue(); + + return value; + } + + // Execute the slow fallback path, invoking the callback and trying to add a new + // value to the cache. If this succeeds, it means that the current key object has + // been added to the cache, so there is nothing left to do. If this fails, it means + // another thread has beat us to adding the key, so we should perform the initial + // lookup again to make sure we can find the exact instance that is in the cache. + // This is needed to ensure valid cache entries remain alive over time. + if (TryGetOrCreate(entry, key, callback, out value)) + { + return value; + } + } + } + + /// + /// Tries to get or create a new value for a given key, using a supplied callback. + /// + /// The instance to try to insert into the cache. + /// The key to use as lookup. + /// The callback to use to create new values, if needed. + /// The resulting value (should be ignored if the method fails). + /// Whether was successfully inserted into the cache. + [MethodImpl(MethodImplOptions.NoInlining)] + private bool TryGetOrCreate(Entry entry, TKey key, GetOrCreateCallback callback, out TValue value) + { + // No value is present, so we can create it now + value = callback(key); + + // We're about to try to add the item to the cache, so we no longer need to + // track the last matched value. We already have a reference to the object. + entry.SetIsPerformingLookup(false); + + // Try to add the value (this only fails if someone raced with this thread). + // In this case, our key will also be added to the cache. Callers will have + // a strong reference to it, which ensures the weak reference remains alive. + if (!this.map.TryAdd(entry, value)) + { + return false; + } + + // As part of the fallback step, traverse all items and remove dead keys + foreach (Entry candidateKey in this.map.Keys) + { + if (!candidateKey.IsAlive) + { + _ = this.map.TryRemove(candidateKey, out _); + } + } + + return true; + } + + /// + /// A callback to create a new value from a given key. + /// + /// The resulting instance. + /// + public delegate TValue GetOrCreateCallback(TKey key); + + /// + /// An entry to use in . + /// + private sealed class Entry + { + /// + /// A weak reference to the actual entry instance. + /// + private readonly WeakReference reference; + + /// + /// The last key matched from , if available. + /// + private TKey? lastMatchedKey; + + /// + /// Indicates whether the entry is currently in lookup mode. + /// + private bool isPerformingLookup; + + /// + /// Creates a new instance with the specified parameters. + /// + /// The key to use for the entry. + public Entry(TKey key) + { + this.reference = new WeakReference(key); + } + + /// + /// Gets whether or not the current entry is alive. + /// + public bool IsAlive => this.reference.TryGetTarget(out _); + + /// + /// Gets the last matched key retrieved during a lookup operation. + /// + /// The last matched key retrieved during a lookup operation. + /// Thrown if no key is available. + public TKey GetLastMatchedValue() + { + TKey? lastMatchedKey = this.lastMatchedKey; + + if (lastMatchedKey is null) + { + EntryHelper.ThrowInvalidOperationExceptionForLastMatchedKey(); + } + + return lastMatchedKey; + } + + /// + /// Sets whether the current instance is performing a lookup. + /// + /// The new value for the configuration. + public void SetIsPerformingLookup(bool value) + { + this.lastMatchedKey = null; + this.isPerformingLookup = value; + } + + /// + public override bool Equals(object? obj) + { + if (obj is not Entry entry) + { + return false; + } + + _ = this.reference.TryGetTarget(out TKey? left); + _ = entry.reference.TryGetTarget(out TKey? right); + + bool isMatch = EqualityComparer.Default.Equals(left, right); + + // If we have a match and we're in lookup mode, store the last item. + // Otherwise, clear it to also make sure not to accidentally root + // keys that are not actually in the cache anymore. + if (isMatch && this.isPerformingLookup) + { + this.lastMatchedKey = right; + } + else + { + this.lastMatchedKey = null; + } + + return isMatch; + } + + /// + public override int GetHashCode() + { + if (this.reference.TryGetTarget(out TKey? value)) + { + return EqualityComparer.Default.GetHashCode(value); + } + + return 0; + } + } +} + +/// +/// Private helpers for the type. +/// +file static class EntryHelper +{ + /// + /// Throws an when there is no last matching key. + /// + [DoesNotReturn] + public static void ThrowInvalidOperationExceptionForLastMatchedKey() + { + throw new InvalidOperationException("No last matching key has been found."); + } +} \ No newline at end of file From 2122efcc25d7aa78976179cfb91d20c6f91e6407 Mon Sep 17 00:00:00 2001 From: Sergio Pedri Date: Wed, 27 Sep 2023 02:19:00 +0200 Subject: [PATCH 2/5] Implement auto-cleanup for DynamicCache --- .../Helpers/DynamicCache{TKey,TValue}.cs | 70 +++++++++++++++---- 1 file changed, 57 insertions(+), 13 deletions(-) diff --git a/src/ComputeSharp.SourceGeneration/Helpers/DynamicCache{TKey,TValue}.cs b/src/ComputeSharp.SourceGeneration/Helpers/DynamicCache{TKey,TValue}.cs index 4e05ecebf..ad99217e6 100644 --- a/src/ComputeSharp.SourceGeneration/Helpers/DynamicCache{TKey,TValue}.cs +++ b/src/ComputeSharp.SourceGeneration/Helpers/DynamicCache{TKey,TValue}.cs @@ -21,6 +21,11 @@ public sealed class DynamicCache /// private readonly ConcurrentDictionary map = new(); + /// + /// The tracking dead entries to remove. + /// + private readonly ConditionalWeakTable table = new(); + /// /// Gets or creates a new value for a given key, using a supplied callback if needed. /// @@ -97,14 +102,12 @@ private bool TryGetOrCreate(Entry entry, TKey key, GetOrCreateCallback callback, return false; } - // As part of the fallback step, traverse all items and remove dead keys - foreach (Entry candidateKey in this.map.Keys) - { - if (!candidateKey.IsAlive) - { - _ = this.map.TryRemove(candidateKey, out _); - } - } + // We need to setup the removal of this key-value pair when the key is no longer + // referenced. To do this, we add a new EntryRemover instance to the table. When + // the key has no active references and is collected (which means the entry will + // also become invalid), the finalizer of the EntryRemover instance will run on + // a following GC, and remove that dead Entry instance from the map automatically. + this.table.Add(key, new EntryRemover(this.map, entry)); return true; } @@ -126,6 +129,11 @@ private sealed class Entry /// private readonly WeakReference reference; + /// + /// The hashcode of the target key (so it's available even after the key is gone). + /// + private readonly int hashCode; + /// /// The last key matched from , if available. /// @@ -143,6 +151,7 @@ private sealed class Entry public Entry(TKey key) { this.reference = new WeakReference(key); + this.hashCode = EqualityComparer.Default.GetHashCode(key); } /// @@ -185,6 +194,13 @@ public override bool Equals(object? obj) return false; } + // Special case matching on the entry identity directly. This is used + // by the finalizer of EntryRemover to find the entry to remove. + if (this == entry) + { + return true; + } + _ = this.reference.TryGetTarget(out TKey? left); _ = entry.reference.TryGetTarget(out TKey? right); @@ -208,12 +224,40 @@ public override bool Equals(object? obj) /// public override int GetHashCode() { - if (this.reference.TryGetTarget(out TKey? value)) - { - return EqualityComparer.Default.GetHashCode(value); - } + return this.hashCode; + } + } + + /// + /// An object reponsible for removing dead instance from the table. + /// + private sealed class EntryRemover + { + /// + private readonly ConcurrentDictionary map; + + /// + /// The target instance to remove from . + /// + private readonly Entry entry; - return 0; + /// + /// Creates a new instance with the specified parameters. + /// + /// + /// + public EntryRemover(ConcurrentDictionary map, Entry entry) + { + this.map = map; + this.entry = entry; + } + + /// + /// Removes from when the current instance is finalized. + /// + ~EntryRemover() + { + _ = this.map.TryRemove(this.entry, out _); } } } From 90896d557440acf85084908e62b8192a41513bac Mon Sep 17 00:00:00 2001 From: Sergio Pedri Date: Wed, 27 Sep 2023 03:04:06 +0200 Subject: [PATCH 3/5] Handle Equals calls with arguments in any order --- .../Helpers/DynamicCache{TKey,TValue}.cs | 26 ++++++++++++------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/src/ComputeSharp.SourceGeneration/Helpers/DynamicCache{TKey,TValue}.cs b/src/ComputeSharp.SourceGeneration/Helpers/DynamicCache{TKey,TValue}.cs index ad99217e6..4ecb52dd1 100644 --- a/src/ComputeSharp.SourceGeneration/Helpers/DynamicCache{TKey,TValue}.cs +++ b/src/ComputeSharp.SourceGeneration/Helpers/DynamicCache{TKey,TValue}.cs @@ -206,16 +206,24 @@ public override bool Equals(object? obj) bool isMatch = EqualityComparer.Default.Equals(left, right); - // If we have a match and we're in lookup mode, store the last item. - // Otherwise, clear it to also make sure not to accidentally root - // keys that are not actually in the cache anymore. - if (isMatch && this.isPerformingLookup) + // If we have a match and we're in lookup mode, store the last item. Note that the dictionary + // does not guarantee the order of arguments to Equals calls, so we cannot rely on the Entry + // object being used for lookup the one whose Equals method is being called. So if there is + // a match, we check which of the two input Entry instances is the one being used for lookups, + // and set the last match key object on that one. If there is no match, there is no need to + // clear the last key, as even in case the entry ends up being inserted into the map in the + // falback path, before doing so SetIsPerformingLookup(false) will be called, which will clear + // any previous matches. So there is already no way for the key to be accidentally rooted here. + if (isMatch) { - this.lastMatchedKey = right; - } - else - { - this.lastMatchedKey = null; + if (this.isPerformingLookup) + { + this.lastMatchedKey = right; + } + else if (entry.isPerformingLookup) + { + entry.lastMatchedKey = left; + } } return isMatch; From e2ce0a60798fe51d9b0cdfb4025dabe4556f1f38 Mon Sep 17 00:00:00 2001 From: Sergio Pedri Date: Thu, 28 Sep 2023 22:42:00 +0200 Subject: [PATCH 4/5] Add cancellation support to DynamicCache<,> --- .../Helpers/DynamicCache{TKey,TValue}.cs | 36 ++++++++++++++++--- 1 file changed, 31 insertions(+), 5 deletions(-) diff --git a/src/ComputeSharp.SourceGeneration/Helpers/DynamicCache{TKey,TValue}.cs b/src/ComputeSharp.SourceGeneration/Helpers/DynamicCache{TKey,TValue}.cs index 4ecb52dd1..787750893 100644 --- a/src/ComputeSharp.SourceGeneration/Helpers/DynamicCache{TKey,TValue}.cs +++ b/src/ComputeSharp.SourceGeneration/Helpers/DynamicCache{TKey,TValue}.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; using System.Runtime.CompilerServices; +using System.Threading; namespace ComputeSharp.SourceGeneration.Helpers; @@ -31,20 +32,26 @@ public sealed class DynamicCache /// /// The key to use as lookup. /// The callback to use to create new values, if needed. + /// A cancellation token for the operation of creating a new value. /// The resulting value. + /// Thrown if is canceled. /// /// This method might replace with a new instance that has the same /// value according to its equality comparison logic. Callers should always use the last value of /// after this method returns and discard the previous one, if different. /// - public TValue GetOrCreate(ref TKey key, GetOrCreateCallback callback) + public TValue GetOrCreate(ref TKey key, GetOrCreateCallback callback, CancellationToken cancellationToken) { + cancellationToken.ThrowIfCancellationRequested(); + // Create a new entry that we will use to perform the lookup. // Each entry simply forwards equality logic to the wrapped object. Entry entry = new(key); while (true) { + cancellationToken.ThrowIfCancellationRequested(); + // We're performing a lookup on this temporary entry. We need it to // track the value it will potentially match against, so we can // return it to the caller. This ensures the same object is used. @@ -56,6 +63,8 @@ public TValue GetOrCreate(ref TKey key, GetOrCreateCallback callback) // object). If we find it, we return it and throw away the new entry. if (this.map.TryGetValue(entry, out TValue? value)) { + cancellationToken.ThrowIfCancellationRequested(); + // We have a match, so replace the object with the one that actually matched. // This guarantees that it will remain alive, so the entry will not die. key = entry.GetLastMatchedValue(); @@ -63,14 +72,21 @@ public TValue GetOrCreate(ref TKey key, GetOrCreateCallback callback) return value; } + cancellationToken.ThrowIfCancellationRequested(); + // Execute the slow fallback path, invoking the callback and trying to add a new // value to the cache. If this succeeds, it means that the current key object has // been added to the cache, so there is nothing left to do. If this fails, it means // another thread has beat us to adding the key, so we should perform the initial // lookup again to make sure we can find the exact instance that is in the cache. // This is needed to ensure valid cache entries remain alive over time. - if (TryGetOrCreate(entry, key, callback, out value)) + if (TryGetOrCreate(entry, key, callback, cancellationToken, out value)) { + // If the operation has been canceled after inserting a new key and value, + // the value will just be wasted, but there isn't really anything we can do + // about it at this point. We also ignore the unnecessary remover finalizer. + cancellationToken.ThrowIfCancellationRequested(); + return value; } } @@ -82,13 +98,22 @@ public TValue GetOrCreate(ref TKey key, GetOrCreateCallback callback) /// The instance to try to insert into the cache. /// The key to use as lookup. /// The callback to use to create new values, if needed. + /// A cancellation token for the operation of creating a new value. /// The resulting value (should be ignored if the method fails). /// Whether was successfully inserted into the cache. + /// Thrown if is canceled. [MethodImpl(MethodImplOptions.NoInlining)] - private bool TryGetOrCreate(Entry entry, TKey key, GetOrCreateCallback callback, out TValue value) + private bool TryGetOrCreate( + Entry entry, + TKey key, + GetOrCreateCallback callback, + CancellationToken cancellationToken, + out TValue value) { // No value is present, so we can create it now - value = callback(key); + value = callback(key, cancellationToken); + + cancellationToken.ThrowIfCancellationRequested(); // We're about to try to add the item to the cache, so we no longer need to // track the last matched value. We already have a reference to the object. @@ -116,8 +141,9 @@ private bool TryGetOrCreate(Entry entry, TKey key, GetOrCreateCallback callback, /// A callback to create a new value from a given key. /// /// The resulting instance. + /// A cancellation token for the operation of creating a new value. /// - public delegate TValue GetOrCreateCallback(TKey key); + public delegate TValue GetOrCreateCallback(TKey key, CancellationToken cancellationToken); /// /// An entry to use in . From a302521f9451c4cec8f215c120433625caa51cd2 Mon Sep 17 00:00:00 2001 From: Sergio Pedri Date: Fri, 29 Sep 2023 02:35:23 +0200 Subject: [PATCH 5/5] Adopt DynamicCache<,> for HlslBytecodeInfo values --- ...haderGenerator.CreateLoadBytecodeMethod.cs | 73 +++++++++++-------- .../ID2D1ShaderGenerator.cs | 2 +- 2 files changed, 44 insertions(+), 31 deletions(-) diff --git a/src/ComputeSharp.D2D1.SourceGenerators/ID2D1ShaderGenerator.CreateLoadBytecodeMethod.cs b/src/ComputeSharp.D2D1.SourceGenerators/ID2D1ShaderGenerator.CreateLoadBytecodeMethod.cs index cb8a39463..876be56c1 100644 --- a/src/ComputeSharp.D2D1.SourceGenerators/ID2D1ShaderGenerator.CreateLoadBytecodeMethod.cs +++ b/src/ComputeSharp.D2D1.SourceGenerators/ID2D1ShaderGenerator.CreateLoadBytecodeMethod.cs @@ -25,6 +25,11 @@ partial class ID2D1ShaderGenerator /// internal static partial class LoadBytecode { + /// + /// The shared cache of values. + /// + private static readonly DynamicCache HlslBytecodeCache = new(); + /// /// Extracts the requested shader profile for the current shader. /// @@ -141,46 +146,54 @@ public static bool IsSimpleInputShader(INamedTypeSymbol structDeclarationSymbol, /// The instance for the shader to compile. /// The used to cancel the operation, if needed. /// - public static unsafe HlslBytecodeInfo GetInfo(HlslBytecodeInfoKey key, CancellationToken token) + public static HlslBytecodeInfo GetInfo(ref HlslBytecodeInfoKey key, CancellationToken token) { - // No embedded shader was requested, or there were some errors earlier in the pipeline. - // In this case, skip the compilation, as diagnostic will be emitted for those anyway. - // Compiling would just add overhead and result in more errors, as the HLSL would be invalid. - // We also skip compilation if no shader profile has been requested (we never just assume one). - if (key.HasErrors || key.RequestedShaderProfile is null) + static unsafe HlslBytecodeInfo GetInfo(HlslBytecodeInfoKey key, CancellationToken token) { - return HlslBytecodeInfo.Missing.Instance; - } + // No embedded shader was requested, or there were some errors earlier in the pipeline. + // In this case, skip the compilation, as diagnostic will be emitted for those anyway. + // Compiling would just add overhead and result in more errors, as the HLSL would be invalid. + // We also skip compilation if no shader profile has been requested (we never just assume one). + if (key.HasErrors || key.RequestedShaderProfile is null) + { + return HlslBytecodeInfo.Missing.Instance; + } - try - { - token.ThrowIfCancellationRequested(); + try + { + token.ThrowIfCancellationRequested(); - // Compile the shader bytecode using the effective parameters - using ComPtr dxcBlobBytecode = D3DCompiler.Compile( - key.HlslSource.AsSpan(), - key.EffectiveShaderProfile, - key.EffectiveCompileOptions); + // Compile the shader bytecode using the effective parameters + using ComPtr dxcBlobBytecode = D3DCompiler.Compile( + key.HlslSource.AsSpan(), + key.EffectiveShaderProfile, + key.EffectiveCompileOptions); - token.ThrowIfCancellationRequested(); + token.ThrowIfCancellationRequested(); - byte* buffer = (byte*)dxcBlobBytecode.Get()->GetBufferPointer(); - int length = checked((int)dxcBlobBytecode.Get()->GetBufferSize()); + byte* buffer = (byte*)dxcBlobBytecode.Get()->GetBufferPointer(); + int length = checked((int)dxcBlobBytecode.Get()->GetBufferSize()); - byte[] array = new ReadOnlySpan(buffer, length).ToArray(); + byte[] array = new ReadOnlySpan(buffer, length).ToArray(); - ImmutableArray bytecode = Unsafe.As>(ref array); + ImmutableArray bytecode = Unsafe.As>(ref array); - return new HlslBytecodeInfo.Success(bytecode); - } - catch (Win32Exception e) - { - return new HlslBytecodeInfo.Win32Error(e.NativeErrorCode, D3DCompiler.PrettifyFxcErrorMessage(e.Message)); - } - catch (FxcCompilationException e) - { - return new HlslBytecodeInfo.FxcError(D3DCompiler.PrettifyFxcErrorMessage(e.Message)); + return new HlslBytecodeInfo.Success(bytecode); + } + catch (Win32Exception e) + { + return new HlslBytecodeInfo.Win32Error(e.NativeErrorCode, D3DCompiler.PrettifyFxcErrorMessage(e.Message)); + } + catch (FxcCompilationException e) + { + return new HlslBytecodeInfo.FxcError(D3DCompiler.PrettifyFxcErrorMessage(e.Message)); + } } + + // Get or create the HLSL bytecode compilation result for the input key. The dynamic cache + // will take care of retrieving an existing cached value if the same shader has been compiled + // already with the same parameters. After this call, callers must use the updated key value. + return HlslBytecodeCache.GetOrCreate(ref key, GetInfo, token); } /// diff --git a/src/ComputeSharp.D2D1.SourceGenerators/ID2D1ShaderGenerator.cs b/src/ComputeSharp.D2D1.SourceGenerators/ID2D1ShaderGenerator.cs index b3d85420b..044b2e32d 100644 --- a/src/ComputeSharp.D2D1.SourceGenerators/ID2D1ShaderGenerator.cs +++ b/src/ComputeSharp.D2D1.SourceGenerators/ID2D1ShaderGenerator.cs @@ -154,7 +154,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context) hasErrors); // TODO: cache this across transform runs - HlslBytecodeInfo hlslInfo = LoadBytecode.GetInfo(hlslInfoKey, token); + HlslBytecodeInfo hlslInfo = LoadBytecode.GetInfo(ref hlslInfoKey, token); token.ThrowIfCancellationRequested();