diff --git a/src/libraries/Microsoft.Extensions.Caching.Memory/ref/Microsoft.Extensions.Caching.Memory.cs b/src/libraries/Microsoft.Extensions.Caching.Memory/ref/Microsoft.Extensions.Caching.Memory.cs index ee0d00700c54b..9a4eef4661d42 100644 --- a/src/libraries/Microsoft.Extensions.Caching.Memory/ref/Microsoft.Extensions.Caching.Memory.cs +++ b/src/libraries/Microsoft.Extensions.Caching.Memory/ref/Microsoft.Extensions.Caching.Memory.cs @@ -34,6 +34,7 @@ protected virtual void Dispose(bool disposing) { } ~MemoryCache() { } public void Remove(object key) { } public bool TryGetValue(object key, out object result) { throw null; } + public void Clear() { } } public partial class MemoryCacheOptions : Microsoft.Extensions.Options.IOptions { diff --git a/src/libraries/Microsoft.Extensions.Caching.Memory/src/MemoryCache.cs b/src/libraries/Microsoft.Extensions.Caching.Memory/src/MemoryCache.cs index 7c426014c4f5f..b0d6c71eb6008 100644 --- a/src/libraries/Microsoft.Extensions.Caching.Memory/src/MemoryCache.cs +++ b/src/libraries/Microsoft.Extensions.Caching.Memory/src/MemoryCache.cs @@ -23,9 +23,8 @@ public class MemoryCache : IMemoryCache internal readonly ILogger _logger; private readonly MemoryCacheOptions _options; - private readonly ConcurrentDictionary _entries; - private long _cacheSize; + private CoherentState _coherentState; private bool _disposed; private DateTimeOffset _lastExpirationScan; @@ -56,7 +55,7 @@ public MemoryCache(IOptions optionsAccessor, ILoggerFactory _options = optionsAccessor.Value; _logger = loggerFactory.CreateLogger(); - _entries = new ConcurrentDictionary(); + _coherentState = new(); if (_options.Clock == null) { @@ -75,15 +74,13 @@ public MemoryCache(IOptions optionsAccessor, ILoggerFactory /// /// Gets the count of the current entries for diagnostic purposes. /// - public int Count => _entries.Count; + public int Count => _coherentState.Count; // internal for testing - internal long Size { get => Interlocked.Read(ref _cacheSize); } + internal long Size => _coherentState.Size; internal bool TrackLinkedCacheEntries { get; } - private ICollection> EntriesCollection => _entries; - /// public ICacheEntry CreateEntry(object key) { @@ -123,7 +120,8 @@ internal void SetEntry(CacheEntry entry) // Initialize the last access timestamp at the time the entry is added entry.LastAccessed = utcNow; - if (_entries.TryGetValue(entry.Key, out CacheEntry priorEntry)) + CoherentState coherentState = _coherentState; // Clear() can update the reference in the meantime + if (coherentState._entries.TryGetValue(entry.Key, out CacheEntry priorEntry)) { priorEntry.SetExpired(EvictionReason.Replaced); } @@ -133,13 +131,13 @@ internal void SetEntry(CacheEntry entry) entry.InvokeEvictionCallbacks(); if (priorEntry != null) { - RemoveEntry(priorEntry); + coherentState.RemoveEntry(priorEntry, _options); } StartScanForExpiredItemsIfNeeded(utcNow); return; } - bool exceedsCapacity = UpdateCacheSizeExceedsCapacity(entry); + bool exceedsCapacity = UpdateCacheSizeExceedsCapacity(entry, coherentState); if (!exceedsCapacity) { bool entryAdded = false; @@ -147,19 +145,19 @@ internal void SetEntry(CacheEntry entry) if (priorEntry == null) { // Try to add the new entry if no previous entries exist. - entryAdded = _entries.TryAdd(entry.Key, entry); + entryAdded = coherentState._entries.TryAdd(entry.Key, entry); } else { // Try to update with the new entry if a previous entries exist. - entryAdded = _entries.TryUpdate(entry.Key, entry, priorEntry); + entryAdded = coherentState._entries.TryUpdate(entry.Key, entry, priorEntry); if (entryAdded) { if (_options.SizeLimit.HasValue) { // The prior entry was removed, decrease the by the prior entry's size - Interlocked.Add(ref _cacheSize, -priorEntry.Size.Value); + Interlocked.Add(ref coherentState._cacheSize, -priorEntry.Size.Value); } } else @@ -167,7 +165,7 @@ internal void SetEntry(CacheEntry entry) // The update will fail if the previous entry was removed after retrival. // Adding the new entry will succeed only if no entry has been added since. // This guarantees removing an old entry does not prevent adding a new entry. - entryAdded = _entries.TryAdd(entry.Key, entry); + entryAdded = coherentState._entries.TryAdd(entry.Key, entry); } } @@ -180,7 +178,7 @@ internal void SetEntry(CacheEntry entry) if (_options.SizeLimit.HasValue) { // Entry could not be added, reset cache size - Interlocked.Add(ref _cacheSize, -entry.Size.Value); + Interlocked.Add(ref coherentState._cacheSize, -entry.Size.Value); } entry.SetExpired(EvictionReason.Replaced); entry.InvokeEvictionCallbacks(); @@ -198,7 +196,7 @@ internal void SetEntry(CacheEntry entry) entry.InvokeEvictionCallbacks(); if (priorEntry != null) { - RemoveEntry(priorEntry); + coherentState.RemoveEntry(priorEntry, _options); } } @@ -213,7 +211,8 @@ public bool TryGetValue(object key, out object result) DateTimeOffset utcNow = _options.Clock.UtcNow; - if (_entries.TryGetValue(key, out CacheEntry entry)) + CoherentState coherentState = _coherentState; // Clear() can update the reference in the meantime + if (coherentState._entries.TryGetValue(key, out CacheEntry entry)) { // Check if expired due to expiration tokens, timers, etc. and if so, remove it. // Allow a stale Replaced value to be returned due to concurrent calls to SetExpired during SetEntry. @@ -236,7 +235,7 @@ public bool TryGetValue(object key, out object result) else { // TODO: For efficiency queue this up for batch removal - RemoveEntry(entry); + coherentState.RemoveEntry(entry, _options); } } @@ -250,13 +249,14 @@ public bool TryGetValue(object key, out object result) public void Remove(object key) { ValidateCacheKey(key); - CheckDisposed(); - if (_entries.TryRemove(key, out CacheEntry entry)) + + CoherentState coherentState = _coherentState; // Clear() can update the reference in the meantime + if (coherentState._entries.TryRemove(key, out CacheEntry entry)) { if (_options.SizeLimit.HasValue) { - Interlocked.Add(ref _cacheSize, -entry.Size.Value); + Interlocked.Add(ref coherentState._cacheSize, -entry.Size.Value); } entry.SetExpired(EvictionReason.Removed); @@ -266,22 +266,25 @@ public void Remove(object key) StartScanForExpiredItemsIfNeeded(_options.Clock.UtcNow); } - private void RemoveEntry(CacheEntry entry) + /// + /// Removes all keys and values from the cache. + /// + public void Clear() { - if (EntriesCollection.Remove(new KeyValuePair(entry.Key, entry))) + CheckDisposed(); + + CoherentState oldState = Interlocked.Exchange(ref _coherentState, new CoherentState()); + foreach (var entry in oldState._entries) { - if (_options.SizeLimit.HasValue) - { - Interlocked.Add(ref _cacheSize, -entry.Size.Value); - } - entry.InvokeEvictionCallbacks(); + entry.Value.SetExpired(EvictionReason.Removed); + entry.Value.InvokeEvictionCallbacks(); } } internal void EntryExpired(CacheEntry entry) { // TODO: For efficiency consider processing these expirations in batches. - RemoveEntry(entry); + _coherentState.RemoveEntry(entry, _options); StartScanForExpiredItemsIfNeeded(_options.Clock.UtcNow); } @@ -307,18 +310,19 @@ private static void ScanForExpiredItems(MemoryCache cache) { DateTimeOffset now = cache._lastExpirationScan = cache._options.Clock.UtcNow; - foreach (KeyValuePair item in cache._entries) + CoherentState coherentState = cache._coherentState; // Clear() can update the reference in the meantime + foreach (KeyValuePair item in coherentState._entries) { CacheEntry entry = item.Value; if (entry.CheckExpired(now)) { - cache.RemoveEntry(entry); + coherentState.RemoveEntry(entry, cache._options); } } } - private bool UpdateCacheSizeExceedsCapacity(CacheEntry entry) + private bool UpdateCacheSizeExceedsCapacity(CacheEntry entry, CoherentState coherentState) { if (!_options.SizeLimit.HasValue) { @@ -328,7 +332,7 @@ private bool UpdateCacheSizeExceedsCapacity(CacheEntry entry) long newSize = 0L; for (int i = 0; i < 100; i++) { - long sizeRead = Interlocked.Read(ref _cacheSize); + long sizeRead = coherentState.Size; newSize = sizeRead + entry.Size.Value; if (newSize < 0 || newSize > _options.SizeLimit) @@ -337,7 +341,7 @@ private bool UpdateCacheSizeExceedsCapacity(CacheEntry entry) return true; } - if (sizeRead == Interlocked.CompareExchange(ref _cacheSize, newSize, sizeRead)) + if (sizeRead == Interlocked.CompareExchange(ref coherentState._cacheSize, newSize, sizeRead)) { return false; } @@ -356,17 +360,18 @@ private void TriggerOvercapacityCompaction() private static void OvercapacityCompaction(MemoryCache cache) { - long currentSize = Interlocked.Read(ref cache._cacheSize); + CoherentState coherentState = cache._coherentState; // Clear() can update the reference in the meantime + long currentSize = coherentState.Size; cache._logger.LogDebug($"Overcapacity compaction executing. Current size {currentSize}"); double? lowWatermark = cache._options.SizeLimit * (1 - cache._options.CompactionPercentage); if (currentSize > lowWatermark) { - cache.Compact(currentSize - (long)lowWatermark, entry => entry.Size.Value); + cache.Compact(currentSize - (long)lowWatermark, entry => entry.Size.Value, coherentState); } - cache._logger.LogDebug($"Overcapacity compaction executed. New size {Interlocked.Read(ref cache._cacheSize)}"); + cache._logger.LogDebug($"Overcapacity compaction executed. New size {coherentState.Size}"); } /// Remove at least the given percentage (0.10 for 10%) of the total entries (or estimated memory?), according to the following policy: @@ -378,11 +383,12 @@ private static void OvercapacityCompaction(MemoryCache cache) /// ?. Larger objects - estimated by object graph size, inaccurate. public void Compact(double percentage) { - int removalCountTarget = (int)(_entries.Count * percentage); - Compact(removalCountTarget, _ => 1); + CoherentState coherentState = _coherentState; // Clear() can update the reference in the meantime + int removalCountTarget = (int)(coherentState.Count * percentage); + Compact(removalCountTarget, _ => 1, coherentState); } - private void Compact(long removalSizeTarget, Func computeEntrySize) + private void Compact(long removalSizeTarget, Func computeEntrySize, CoherentState coherentState) { var entriesToRemove = new List(); var lowPriEntries = new List(); @@ -392,7 +398,7 @@ private void Compact(long removalSizeTarget, Func computeEntry // Sort items by expired & priority status DateTimeOffset now = _options.Clock.UtcNow; - foreach (KeyValuePair item in _entries) + foreach (KeyValuePair item in coherentState._entries) { CacheEntry entry = item.Value; if (entry.CheckExpired(now)) @@ -427,7 +433,7 @@ private void Compact(long removalSizeTarget, Func computeEntry foreach (CacheEntry entry in entriesToRemove) { - RemoveEntry(entry); + coherentState.RemoveEntry(entry, _options); } // Policy: @@ -500,5 +506,29 @@ private static void ValidateCacheKey(object key) static void Throw() => throw new ArgumentNullException(nameof(key)); } + + private sealed class CoherentState + { + internal ConcurrentDictionary _entries = new ConcurrentDictionary(); + internal long _cacheSize; + + private ICollection> EntriesCollection => _entries; + + internal int Count => _entries.Count; + + internal long Size => Interlocked.Read(ref _cacheSize); + + internal void RemoveEntry(CacheEntry entry, MemoryCacheOptions options) + { + if (EntriesCollection.Remove(new KeyValuePair(entry.Key, entry))) + { + if (options.SizeLimit.HasValue) + { + Interlocked.Add(ref _cacheSize, -entry.Size.Value); + } + entry.InvokeEvictionCallbacks(); + } + } + } } } diff --git a/src/libraries/Microsoft.Extensions.Caching.Memory/tests/CapacityTests.cs b/src/libraries/Microsoft.Extensions.Caching.Memory/tests/CapacityTests.cs index e4d5ef4893568..79e09aaf54449 100644 --- a/src/libraries/Microsoft.Extensions.Caching.Memory/tests/CapacityTests.cs +++ b/src/libraries/Microsoft.Extensions.Caching.Memory/tests/CapacityTests.cs @@ -444,5 +444,19 @@ public void NoCompactionWhenNoMaximumEntriesCountSpecified() // There should be 6 items in the cache Assert.Equal(6, cache.Count); } + + [Fact] + public void ClearZeroesTheSize() + { + var cache = new MemoryCache(new MemoryCacheOptions { SizeLimit = 10 }); + Assert.Equal(0, cache.Size); + + cache.Set("key", "value", new MemoryCacheEntryOptions { Size = 5 }); + Assert.Equal(5, cache.Size); + + cache.Clear(); + Assert.Equal(0, cache.Size); + Assert.Equal(0, cache.Count); + } } } diff --git a/src/libraries/Microsoft.Extensions.Caching.Memory/tests/MemoryCacheSetAndRemoveTests.cs b/src/libraries/Microsoft.Extensions.Caching.Memory/tests/MemoryCacheSetAndRemoveTests.cs index 65810e352e927..28ad9ebebf795 100644 --- a/src/libraries/Microsoft.Extensions.Caching.Memory/tests/MemoryCacheSetAndRemoveTests.cs +++ b/src/libraries/Microsoft.Extensions.Caching.Memory/tests/MemoryCacheSetAndRemoveTests.cs @@ -380,6 +380,29 @@ public void RemoveRemoves() Assert.Null(result); } + [Fact] + public void ClearClears() + { + var cache = (MemoryCache)CreateCache(); + var obj = new object(); + string[] keys = new string[] { "key1", "key2", "key3", "key4" }; + + foreach (string key in keys) + { + var result = cache.Set(key, obj); + Assert.Same(obj, result); + Assert.Same(obj, cache.Get(key)); + } + + cache.Clear(); + + Assert.Equal(0, cache.Count); + foreach (string key in keys) + { + Assert.Null(cache.Get(key)); + } + } + [Fact] public void RemoveRemovesAndInvokesCallback() { @@ -411,6 +434,38 @@ public void RemoveRemovesAndInvokesCallback() Assert.Null(result); } + [Fact] + public void ClearClearsAndInvokesCallback() + { + var cache = (MemoryCache)CreateCache(); + var value = new object(); + string key = "myKey"; + var callbackInvoked = new ManualResetEvent(false); + + var options = new MemoryCacheEntryOptions(); + options.PostEvictionCallbacks.Add(new PostEvictionCallbackRegistration() + { + EvictionCallback = (subkey, subValue, reason, state) => + { + Assert.Equal(key, subkey); + Assert.Same(value, subValue); + Assert.Equal(EvictionReason.Removed, reason); + var localCallbackInvoked = (ManualResetEvent)state; + localCallbackInvoked.Set(); + }, + State = callbackInvoked + }); + var result = cache.Set(key, value, options); + Assert.Same(value, result); + + cache.Clear(); + Assert.Equal(0, cache.Count); + Assert.True(callbackInvoked.WaitOne(TimeSpan.FromSeconds(30)), "Callback"); + + result = cache.Get(key); + Assert.Null(result); + } + [Fact] public void RemoveAndReAddFromCallbackWorks() { diff --git a/src/libraries/Microsoft.Extensions.Caching.Memory/tests/TokenExpirationTests.cs b/src/libraries/Microsoft.Extensions.Caching.Memory/tests/TokenExpirationTests.cs index 182d53cf2e8b9..d55be43c63364 100644 --- a/src/libraries/Microsoft.Extensions.Caching.Memory/tests/TokenExpirationTests.cs +++ b/src/libraries/Microsoft.Extensions.Caching.Memory/tests/TokenExpirationTests.cs @@ -162,6 +162,29 @@ public void RemoveItemDisposesTokenRegistration() Assert.True(callbackInvoked.WaitOne(TimeSpan.FromSeconds(30)), "Callback"); } + [Fact] + public void ClearingCacheDisposesTokenRegistration() + { + var cache = (MemoryCache)CreateCache(); + string key = "myKey"; + var value = new object(); + var callbackInvoked = new ManualResetEvent(false); + var expirationToken = new TestExpirationToken() { ActiveChangeCallbacks = true }; + cache.Set(key, value, new MemoryCacheEntryOptions() + .AddExpirationToken(expirationToken) + .RegisterPostEvictionCallback((subkey, subValue, reason, state) => + { + var localCallbackInvoked = (ManualResetEvent)state; + localCallbackInvoked.Set(); + }, state: callbackInvoked)); + cache.Clear(); + + Assert.Equal(0, cache.Count); + Assert.NotNull(expirationToken.Registration); + Assert.True(expirationToken.Registration.Disposed); + Assert.True(callbackInvoked.WaitOne(TimeSpan.FromSeconds(30)), "Callback"); + } + [Fact] public void AddExpiredTokenPreventsCaching() {