Skip to content

Commit

Permalink
Fix | Replaced System.Runtime.Caching with Microsoft.Extensions.Cachi…
Browse files Browse the repository at this point in the history
…ng.Memory (#2493)
  • Loading branch information
arellegue authored and mdaigle committed Jun 20, 2024
1 parent 42424ef commit df2a2dc
Show file tree
Hide file tree
Showing 18 changed files with 124 additions and 72 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
<Compile Include="..\..\ref\Microsoft.Data.SqlClient.Batch.NetCoreApp.cs" />
</ItemGroup>
<ItemGroup>
<PackageReference Include="Microsoft.Extensions.Caching.Memory" Version="$(MicrosoftExtensionsCachingMemoryVersion)" />
<PackageReference Include="System.Security.Cryptography.Cng" Version="$(SystemSecurityCryptographyCngVersion)" />
<PackageReference Include="Microsoft.Identity.Client" Version="$(MicrosoftIdentityClientVersion)" />
</ItemGroup>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -952,7 +952,6 @@
<!-- Package References Etc -->
<ItemGroup Condition="'$(TargetGroup)' == 'netcoreapp'">
<PackageReference Include="System.Configuration.ConfigurationManager" Version="$(SystemConfigurationConfigurationManagerVersion)" />
<PackageReference Include="System.Runtime.Caching" Version="$(SystemRuntimeCachingVersion)" />
</ItemGroup>
<ItemGroup>
<PackageReference Include="Microsoft.Data.SqlClient.SNI.runtime" Version="$(MicrosoftDataSqlClientSNIRuntimeVersion)" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
<Reference Include="System.Transactions" />
</ItemGroup>
<ItemGroup>
<PackageReference Include="Microsoft.Extensions.Caching.Memory" Version="$(MicrosoftExtensionsCachingMemoryVersion)" />
<PackageReference Include="Microsoft.Identity.Client" Version="$(MicrosoftIdentityClientVersion)" />
</ItemGroup>
</Project>
Original file line number Diff line number Diff line change
Expand Up @@ -740,6 +740,7 @@
</COMReference>
</ItemGroup>
<ItemGroup>
<PackageReference Include="Microsoft.Extensions.Caching.Memory" Version="$(MicrosoftExtensionsCachingMemoryVersion)" />
<PackageReference Include="System.Text.Encodings.Web">
<Version>$(SystemTextEncodingsWebVersion)</Version>
</PackageReference>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
using System;
using System.Collections.Concurrent;
using System.Linq;
using System.Runtime.Caching;
using System.Security.Cryptography;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Azure.Core;
using Azure.Identity;
using Microsoft.Extensions.Caching.Memory;
using Microsoft.Identity.Client;
using Microsoft.Identity.Client.Extensibility;

Expand All @@ -27,7 +27,7 @@ public sealed class ActiveDirectoryAuthenticationProvider : SqlAuthenticationPro
/// </summary>
private static ConcurrentDictionary<PublicClientAppKey, IPublicClientApplication> s_pcaMap
= new ConcurrentDictionary<PublicClientAppKey, IPublicClientApplication>();
private static readonly MemoryCache s_accountPwCache = new(nameof(ActiveDirectoryAuthenticationProvider));
private static readonly MemoryCache s_accountPwCache = new MemoryCache(new MemoryCacheOptions());
private static readonly int s_accountPwCacheTtlInHours = 2;
private static readonly string s_nativeClientRedirectUri = "https://login.microsoftonline.com/common/oauth2/nativeclient";
private static readonly string s_defaultScopeSuffix = "/.default";
Expand Down Expand Up @@ -270,11 +270,11 @@ previousPw is byte[] previousPwBytes &&
// We cache the password hash to ensure future connection requests include a validated password
// when we check for a cached MSAL account. Otherwise, a connection request with the same username
// against the same tenant could succeed with an invalid password when we re-use the cached token.
if (!s_accountPwCache.Add(pwCacheKey, GetHash(parameters.Password), DateTime.UtcNow.AddHours(s_accountPwCacheTtlInHours)))
using (ICacheEntry entry = s_accountPwCache.CreateEntry(pwCacheKey))
{
s_accountPwCache.Remove(pwCacheKey);
s_accountPwCache.Add(pwCacheKey, GetHash(parameters.Password), DateTime.UtcNow.AddHours(s_accountPwCacheTtlInHours));
}
entry.Value = GetHash(parameters.Password);
entry.AbsoluteExpirationRelativeToNow = TimeSpan.FromHours(s_accountPwCacheTtlInHours);
};

SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Password auth mode. Expiry Time: {0}", result?.ExpiresOn);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
using System.Collections.Generic;
using System.Diagnostics;
using System.IdentityModel.Tokens.Jwt;
using System.Runtime.Caching;
using System.Security.Claims;
using System.Security.Cryptography;
using System.Text;
using System.Threading;
using Microsoft.Extensions.Caching.Memory;
using Microsoft.IdentityModel.JsonWebTokens;
using Microsoft.IdentityModel.Logging;
using Microsoft.IdentityModel.Protocols;
Expand Down Expand Up @@ -59,7 +59,7 @@ internal class AzureAttestationEnclaveProvider : EnclaveProviderBase
// such as https://sql.azure.attest.com/.well-known/openid-configuration
private const string AttestationUrlSuffix = @"/.well-known/openid-configuration";

private static readonly MemoryCache OpenIdConnectConfigurationCache = new MemoryCache("OpenIdConnectConfigurationCache");
private static readonly MemoryCache OpenIdConnectConfigurationCache = new MemoryCache(new MemoryCacheOptions());
#endregion

#region Internal methods
Expand Down Expand Up @@ -332,7 +332,7 @@ private static string GetInnerMostExceptionMessage(Exception exception)
// It also caches that information for 1 day to avoid DDOS attacks.
private OpenIdConnectConfiguration GetOpenIdConfigForSigningKeys(string url, bool forceUpdate)
{
OpenIdConnectConfiguration openIdConnectConfig = OpenIdConnectConfigurationCache[url] as OpenIdConnectConfiguration;
OpenIdConnectConfiguration openIdConnectConfig = OpenIdConnectConfigurationCache.Get<OpenIdConnectConfiguration>(url);
if (forceUpdate || openIdConnectConfig == null)
{
// Compute the meta data endpoint
Expand All @@ -348,7 +348,11 @@ private OpenIdConnectConfiguration GetOpenIdConfigForSigningKeys(string url, boo
throw SQL.AttestationFailed(string.Format(Strings.GetAttestationTokenSigningKeysFailed, GetInnerMostExceptionMessage(exception)), exception);
}

OpenIdConnectConfigurationCache.Add(url, openIdConnectConfig, DateTime.UtcNow.AddDays(1));
MemoryCacheEntryOptions options = new MemoryCacheEntryOptions
{
AbsoluteExpirationRelativeToNow = TimeSpan.FromDays(1)
};
OpenIdConnectConfigurationCache.Set<OpenIdConnectConfiguration>(url, openIdConnectConfig, options);
}

return openIdConnectConfig;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Runtime.Caching;
using System.Security.Cryptography;
using System.Threading;
using Microsoft.Extensions.Caching.Memory;

// Enclave session locking model
// 1. For doing the enclave attestation, driver makes either 1, 2 or 3 API calls(in order)
Expand Down Expand Up @@ -84,7 +84,7 @@ internal abstract class EnclaveProviderBase : SqlColumnEncryptionEnclaveProvider
private static readonly Object lockUpdateSessionLock = new Object();

// It is used to save the attestation url and nonce value across API calls
protected static readonly MemoryCache ThreadRetryCache = new MemoryCache("ThreadRetryCache");
protected static readonly MemoryCache ThreadRetryCache = new MemoryCache(new MemoryCacheOptions());
#endregion

#region protected methods
Expand All @@ -102,7 +102,7 @@ protected void GetEnclaveSessionHelper(EnclaveSessionParameters enclaveSessionPa

// In case if on some thread we are running SQL workload which don't require attestation, then in those cases we don't want same thread to wait for event to be signaled.
// hence skipping it
string retryThreadID = ThreadRetryCache[Thread.CurrentThread.ManagedThreadId.ToString()] as string;
string retryThreadID = ThreadRetryCache.Get<string>(Thread.CurrentThread.ManagedThreadId.ToString());
if (!string.IsNullOrEmpty(retryThreadID))
{
sameThreadRetry = true;
Expand Down Expand Up @@ -167,7 +167,11 @@ protected void GetEnclaveSessionHelper(EnclaveSessionParameters enclaveSessionPa
retryThreadID = Thread.CurrentThread.ManagedThreadId.ToString();
}

ThreadRetryCache.Set(Thread.CurrentThread.ManagedThreadId.ToString(), retryThreadID, DateTime.UtcNow.AddMinutes(ThreadRetryCacheTimeoutInMinutes));
MemoryCacheEntryOptions options = new MemoryCacheEntryOptions
{
AbsoluteExpirationRelativeToNow = TimeSpan.FromMinutes(ThreadRetryCacheTimeoutInMinutes)
};
ThreadRetryCache.Set<string>(Thread.CurrentThread.ManagedThreadId.ToString(), retryThreadID, options);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Runtime.Caching;
using Microsoft.Extensions.Caching.Memory;
using System.Threading;

namespace Microsoft.Data.SqlClient
{
// Maintains a cache of SqlEnclaveSession instances
internal class EnclaveSessionCache
{
private readonly MemoryCache enclaveMemoryCache = new MemoryCache("EnclaveMemoryCache");
private readonly MemoryCache enclaveMemoryCache = new MemoryCache(new MemoryCacheOptions());
private readonly object enclaveCacheLock = new object();

// Nonce for each message sent by the client to the server to prevent replay attacks by the server,
Expand All @@ -25,7 +25,7 @@ internal class EnclaveSessionCache
internal SqlEnclaveSession GetEnclaveSession(EnclaveSessionParameters enclaveSessionParameters, out long counter)
{
string cacheKey = GenerateCacheKey(enclaveSessionParameters);
SqlEnclaveSession enclaveSession = enclaveMemoryCache[cacheKey] as SqlEnclaveSession;
SqlEnclaveSession enclaveSession = enclaveMemoryCache.Get<SqlEnclaveSession>(cacheKey);
counter = Interlocked.Increment(ref _counter);
return enclaveSession;
}
Expand All @@ -41,8 +41,12 @@ internal void InvalidateSession(EnclaveSessionParameters enclaveSessionParameter

if (enclaveSession != null && enclaveSession.SessionId == enclaveSessionToInvalidate.SessionId)
{
SqlEnclaveSession enclaveSessionRemoved = enclaveMemoryCache.Remove(cacheKey) as SqlEnclaveSession;
if (enclaveSessionRemoved == null)
enclaveMemoryCache.TryGetValue<SqlEnclaveSession>(cacheKey, out SqlEnclaveSession enclaveSessionToRemove);
if (enclaveSessionToRemove != null)
{
enclaveMemoryCache.Remove(cacheKey);
}
else
{
throw new InvalidOperationException(Strings.EnclaveSessionInvalidationFailed);
}
Expand All @@ -58,7 +62,11 @@ internal SqlEnclaveSession CreateSession(EnclaveSessionParameters enclaveSession
lock (enclaveCacheLock)
{
enclaveSession = new SqlEnclaveSession(sharedSecret, sessionId);
enclaveMemoryCache.Add(cacheKey, enclaveSession, DateTime.UtcNow.AddHours(enclaveCacheTimeOutInHours));
MemoryCacheEntryOptions options = new MemoryCacheEntryOptions
{
AbsoluteExpirationRelativeToNow = TimeSpan.FromHours(enclaveCacheTimeOutInHours)
};
enclaveMemoryCache.Set<SqlEnclaveSession>(cacheKey, enclaveSession, options);
counter = Interlocked.Increment(ref _counter);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Runtime.Caching;
using System.Text;
using System.Threading;
using Microsoft.Extensions.Caching.Memory;

namespace Microsoft.Data.SqlClient
{
Expand Down Expand Up @@ -35,7 +35,7 @@ internal class ColumnMasterKeyMetadataSignatureVerificationCache

private ColumnMasterKeyMetadataSignatureVerificationCache()
{
_cache = new MemoryCache(_className);
_cache = new MemoryCache(new MemoryCacheOptions());
_inTrim = 0;
}

Expand All @@ -46,17 +46,15 @@ private ColumnMasterKeyMetadataSignatureVerificationCache()
/// <param name="masterKeyPath">Key Path for CMK</param>
/// <param name="allowEnclaveComputations">boolean indicating whether the key can be sent to enclave</param>
/// <param name="signature">Signature for the CMK metadata</param>
/// <returns>null if the data is not found in cache otherwise returns true/false indicating signature verification success/failure</returns>
internal bool? GetSignatureVerificationResult(string keyStoreName, string masterKeyPath, bool allowEnclaveComputations, byte[] signature)
internal bool GetSignatureVerificationResult(string keyStoreName, string masterKeyPath, bool allowEnclaveComputations, byte[] signature)
{

ValidateStringArgumentNotNullOrEmpty(masterKeyPath, _masterkeypathArgumentName, _getSignatureVerificationResultMethodName);
ValidateStringArgumentNotNullOrEmpty(keyStoreName, _keyStoreNameArgumentName, _getSignatureVerificationResultMethodName);
ValidateSignatureNotNullOrEmpty(signature, _getSignatureVerificationResultMethodName);

string cacheLookupKey = GetCacheLookupKey(masterKeyPath, allowEnclaveComputations, signature, keyStoreName);

return _cache.Get(cacheLookupKey) as bool?;
return _cache.TryGetValue<bool>(cacheLookupKey, out bool value);
}

/// <summary>
Expand All @@ -69,7 +67,6 @@ private ColumnMasterKeyMetadataSignatureVerificationCache()
/// <param name="result">result indicating signature verification success/failure</param>
internal void AddSignatureVerificationResult(string keyStoreName, string masterKeyPath, bool allowEnclaveComputations, byte[] signature, bool result)
{

ValidateStringArgumentNotNullOrEmpty(masterKeyPath, _masterkeypathArgumentName, _addSignatureVerificationResultMethodName);
ValidateStringArgumentNotNullOrEmpty(keyStoreName, _keyStoreNameArgumentName, _addSignatureVerificationResultMethodName);
ValidateSignatureNotNullOrEmpty(signature, _addSignatureVerificationResultMethodName);
Expand All @@ -79,7 +76,11 @@ internal void AddSignatureVerificationResult(string keyStoreName, string masterK
TrimCacheIfNeeded();

// By default evict after 10 days.
_cache.Set(cacheLookupKey, result, DateTimeOffset.UtcNow.AddDays(10));
MemoryCacheEntryOptions options = new MemoryCacheEntryOptions
{
AbsoluteExpirationRelativeToNow = TimeSpan.FromDays(10)
};
_cache.Set<bool>(cacheLookupKey, result, options);
}

private void ValidateSignatureNotNullOrEmpty(byte[] signature, string methodName)
Expand Down Expand Up @@ -115,15 +116,17 @@ private void ValidateStringArgumentNotNullOrEmpty(string stringArgValue, string
private void TrimCacheIfNeeded()
{
// If the size of the cache exceeds the threshold, set that we are in trimming and trim the cache accordingly.
long currentCacheSize = _cache.GetCount();
long currentCacheSize = _cache.Count;
if ((currentCacheSize > _cacheSize + _cacheTrimThreshold) && (0 == Interlocked.CompareExchange(ref _inTrim, 1, 0)))
{
try
{
_cache.Trim((int)(((double)(currentCacheSize - _cacheSize) / (double)currentCacheSize) * 100));
// Example: 2301 - 2000 = 301; 301 / 2301 = 0.1308 * 100 = 13% compacting
_cache.Compact((((double)(currentCacheSize - _cacheSize) / (double)currentCacheSize) * 100));
}
finally
{
// Reset _inTrim flag
Interlocked.CompareExchange(ref _inTrim, 0, 1);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
using System.Collections.Generic;
using System.Data;
using System.Diagnostics;
using System.Runtime.Caching;
using System.Text;
using System.Threading;
using Microsoft.Extensions.Caching.Memory;

namespace Microsoft.Data.SqlClient
{
Expand All @@ -34,7 +34,7 @@ sealed internal class SqlQueryMetadataCache

private SqlQueryMetadataCache()
{
_cache = new MemoryCache("SqlQueryMetadataCache");
_cache = new MemoryCache(new MemoryCacheOptions());
}

internal static SqlQueryMetadataCache GetInstance()
Expand All @@ -61,7 +61,7 @@ internal bool GetQueryMetadataIfExists(SqlCommand sqlCommand)
return false;
}

Dictionary<string, SqlCipherMetadata> cipherMetadataDictionary = _cache.Get(cacheLookupKey) as Dictionary<string, SqlCipherMetadata>;
Dictionary<string, SqlCipherMetadata> cipherMetadataDictionary = _cache.Get<Dictionary<string, SqlCipherMetadata>>(cacheLookupKey);

// If we had a cache miss just return false.
if (cipherMetadataDictionary is null)
Expand Down Expand Up @@ -144,7 +144,7 @@ internal bool GetQueryMetadataIfExists(SqlCommand sqlCommand)
}

ConcurrentDictionary<int, SqlTceCipherInfoEntry> enclaveKeys =
_cache.Get(enclaveLookupKey) as ConcurrentDictionary<int, SqlTceCipherInfoEntry>;
_cache.Get<ConcurrentDictionary<int, SqlTceCipherInfoEntry>>(enclaveLookupKey);
if (enclaveKeys is not null)
{
sqlCommand.keysToBeSentToEnclave = CreateCopyOfEnclaveKeys(enclaveKeys);
Expand Down Expand Up @@ -215,7 +215,7 @@ internal void AddQueryMetadata(SqlCommand sqlCommand, bool ignoreQueriesWithRetu
}

// If the size of the cache exceeds the threshold, set that we are in trimming and trim the cache accordingly.
long currentCacheSize = _cache.GetCount();
long currentCacheSize = _cache.Count;
if ((currentCacheSize > CacheSize + CacheTrimThreshold) && (0 == Interlocked.CompareExchange(ref _inTrim, 1, 0)))
{
try
Expand All @@ -226,7 +226,7 @@ internal void AddQueryMetadata(SqlCommand sqlCommand, bool ignoreQueriesWithRetu
Thread.Sleep(TimeSpan.FromSeconds(10));
}
#endif
_cache.Trim((int)(((double)(currentCacheSize - CacheSize) / (double)currentCacheSize) * 100));
_cache.Compact((int)(((double)(currentCacheSize - CacheSize) / (double)currentCacheSize) * 100));
}
finally
{
Expand All @@ -235,11 +235,15 @@ internal void AddQueryMetadata(SqlCommand sqlCommand, bool ignoreQueriesWithRetu
}

// By default evict after 10 hours.
_cache.Set(cacheLookupKey, cipherMetadataDictionary, DateTimeOffset.UtcNow.AddHours(10));
MemoryCacheEntryOptions options = new MemoryCacheEntryOptions
{
AbsoluteExpirationRelativeToNow = TimeSpan.FromHours(10)
};
_cache.Set<Dictionary<string, SqlCipherMetadata>>(cacheLookupKey, cipherMetadataDictionary, options);
if (sqlCommand.requiresEnclaveComputations)
{
ConcurrentDictionary<int, SqlTceCipherInfoEntry> keysToBeCached = CreateCopyOfEnclaveKeys(sqlCommand.keysToBeSentToEnclave);
_cache.Set(enclaveLookupKey, keysToBeCached, DateTimeOffset.UtcNow.AddHours(10));
_cache.Set<ConcurrentDictionary<int, SqlTceCipherInfoEntry>>(enclaveLookupKey, keysToBeCached, options);
}
}

Expand Down
Loading

0 comments on commit df2a2dc

Please sign in to comment.