Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix | Replaced System.Runtime.Caching with Microsoft.Extensions.Caching.Memory #2493

Merged
merged 18 commits into from
Jun 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
<Compile Include="..\..\ref\Microsoft.Data.SqlClient.Batch.NetCoreApp.cs" />
</ItemGroup>
<ItemGroup>
<PackageReference Include="Microsoft.Extensions.Caching.Memory" Version="$(MicrosoftExtensionsCachingMemoryVersion)" />
<PackageReference Include="System.Security.Cryptography.Cng" Version="$(SystemSecurityCryptographyCngVersion)" />
<PackageReference Include="Microsoft.Identity.Client" Version="$(MicrosoftIdentityClientVersion)" />
</ItemGroup>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -936,10 +936,10 @@
<!-- Package References Etc -->
<ItemGroup Condition="'$(TargetGroup)' == 'netcoreapp'">
<PackageReference Include="System.Configuration.ConfigurationManager" Version="$(SystemConfigurationConfigurationManagerVersion)" />
<PackageReference Include="System.Runtime.Caching" Version="$(SystemRuntimeCachingVersion)" />
</ItemGroup>
<ItemGroup>
<PackageReference Include="Microsoft.Data.SqlClient.SNI.runtime" Version="$(MicrosoftDataSqlClientSNIRuntimeVersion)" />
<PackageReference Include="Microsoft.Extensions.Caching.Memory" Version="$(MicrosoftExtensionsCachingMemoryVersion)" />
<!-- Enable the project reference for debugging purposes. -->
<!-- <ProjectReference Include="$(SqlServerSourceCode)\Microsoft.SqlServer.Server.csproj" /> -->
<PackageReference Include="Microsoft.SqlServer.Server" Version="$(MicrosoftSqlServerServerVersion)" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
<Reference Include="System.Transactions" />
</ItemGroup>
<ItemGroup>
<PackageReference Include="Microsoft.Extensions.Caching.Memory" Version="$(MicrosoftExtensionsCachingMemoryVersion)" />
<PackageReference Include="Microsoft.Identity.Client" Version="$(MicrosoftIdentityClientVersion)" />
</ItemGroup>
</Project>
Original file line number Diff line number Diff line change
Expand Up @@ -738,6 +738,7 @@
</COMReference>
</ItemGroup>
<ItemGroup>
<PackageReference Include="Microsoft.Extensions.Caching.Memory" Version="$(MicrosoftExtensionsCachingMemoryVersion)" />
<PackageReference Include="System.Text.Encodings.Web">
<Version>$(SystemTextEncodingsWebVersion)</Version>
</PackageReference>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
using System;
using System.Collections.Concurrent;
using System.Linq;
using System.Runtime.Caching;
using System.Security.Cryptography;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Azure.Core;
using Azure.Identity;
using Microsoft.Extensions.Caching.Memory;
using Microsoft.Identity.Client;
using Microsoft.Identity.Client.Extensibility;

Expand All @@ -27,7 +27,7 @@ public sealed class ActiveDirectoryAuthenticationProvider : SqlAuthenticationPro
/// </summary>
private static ConcurrentDictionary<PublicClientAppKey, IPublicClientApplication> s_pcaMap
= new ConcurrentDictionary<PublicClientAppKey, IPublicClientApplication>();
private static readonly MemoryCache s_accountPwCache = new(nameof(ActiveDirectoryAuthenticationProvider));
private static readonly MemoryCache s_accountPwCache = new MemoryCache(new MemoryCacheOptions());
private static readonly int s_accountPwCacheTtlInHours = 2;
private static readonly string s_nativeClientRedirectUri = "https://login.microsoftonline.com/common/oauth2/nativeclient";
private static readonly string s_defaultScopeSuffix = "/.default";
Expand Down Expand Up @@ -270,11 +270,11 @@ previousPw is byte[] previousPwBytes &&
// We cache the password hash to ensure future connection requests include a validated password
// when we check for a cached MSAL account. Otherwise, a connection request with the same username
// against the same tenant could succeed with an invalid password when we re-use the cached token.
if (!s_accountPwCache.Add(pwCacheKey, GetHash(parameters.Password), DateTime.UtcNow.AddHours(s_accountPwCacheTtlInHours)))
using (ICacheEntry entry = s_accountPwCache.CreateEntry(pwCacheKey))
{
s_accountPwCache.Remove(pwCacheKey);
s_accountPwCache.Add(pwCacheKey, GetHash(parameters.Password), DateTime.UtcNow.AddHours(s_accountPwCacheTtlInHours));
}
entry.Value = GetHash(parameters.Password);
entry.AbsoluteExpirationRelativeToNow = TimeSpan.FromHours(s_accountPwCacheTtlInHours);
};

SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Password auth mode. Expiry Time: {0}", result?.ExpiresOn);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
using System.Collections.Generic;
using System.Diagnostics;
using System.IdentityModel.Tokens.Jwt;
using System.Runtime.Caching;
using System.Security.Claims;
using System.Security.Cryptography;
using System.Text;
using System.Threading;
using Microsoft.Extensions.Caching.Memory;
using Microsoft.IdentityModel.JsonWebTokens;
using Microsoft.IdentityModel.Logging;
using Microsoft.IdentityModel.Protocols;
Expand Down Expand Up @@ -59,7 +59,7 @@ internal class AzureAttestationEnclaveProvider : EnclaveProviderBase
// such as https://sql.azure.attest.com/.well-known/openid-configuration
private const string AttestationUrlSuffix = @"/.well-known/openid-configuration";

private static readonly MemoryCache OpenIdConnectConfigurationCache = new MemoryCache("OpenIdConnectConfigurationCache");
private static readonly MemoryCache OpenIdConnectConfigurationCache = new MemoryCache(new MemoryCacheOptions());
#endregion

#region Internal methods
Expand Down Expand Up @@ -332,7 +332,7 @@ private static string GetInnerMostExceptionMessage(Exception exception)
// It also caches that information for 1 day to avoid DDOS attacks.
private OpenIdConnectConfiguration GetOpenIdConfigForSigningKeys(string url, bool forceUpdate)
{
OpenIdConnectConfiguration openIdConnectConfig = OpenIdConnectConfigurationCache[url] as OpenIdConnectConfiguration;
OpenIdConnectConfiguration openIdConnectConfig = OpenIdConnectConfigurationCache.Get<OpenIdConnectConfiguration>(url);
JRahnama marked this conversation as resolved.
Show resolved Hide resolved
if (forceUpdate || openIdConnectConfig == null)
{
// Compute the meta data endpoint
Expand All @@ -348,7 +348,11 @@ private OpenIdConnectConfiguration GetOpenIdConfigForSigningKeys(string url, boo
throw SQL.AttestationFailed(string.Format(Strings.GetAttestationTokenSigningKeysFailed, GetInnerMostExceptionMessage(exception)), exception);
}

OpenIdConnectConfigurationCache.Add(url, openIdConnectConfig, DateTime.UtcNow.AddDays(1));
MemoryCacheEntryOptions options = new MemoryCacheEntryOptions
{
AbsoluteExpirationRelativeToNow = TimeSpan.FromDays(1)
};
OpenIdConnectConfigurationCache.Set<OpenIdConnectConfiguration>(url, openIdConnectConfig, options);
}

return openIdConnectConfig;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Runtime.Caching;
using System.Security.Cryptography;
using System.Threading;
using Microsoft.Extensions.Caching.Memory;

// Enclave session locking model
// 1. For doing the enclave attestation, driver makes either 1, 2 or 3 API calls(in order)
Expand Down Expand Up @@ -84,7 +84,7 @@ internal abstract class EnclaveProviderBase : SqlColumnEncryptionEnclaveProvider
private static readonly Object lockUpdateSessionLock = new Object();

// It is used to save the attestation url and nonce value across API calls
protected static readonly MemoryCache ThreadRetryCache = new MemoryCache("ThreadRetryCache");
protected static readonly MemoryCache ThreadRetryCache = new MemoryCache(new MemoryCacheOptions());
#endregion

#region protected methods
Expand All @@ -102,7 +102,7 @@ protected void GetEnclaveSessionHelper(EnclaveSessionParameters enclaveSessionPa

// In case if on some thread we are running SQL workload which don't require attestation, then in those cases we don't want same thread to wait for event to be signaled.
// hence skipping it
string retryThreadID = ThreadRetryCache[Thread.CurrentThread.ManagedThreadId.ToString()] as string;
string retryThreadID = ThreadRetryCache.Get<string>(Thread.CurrentThread.ManagedThreadId.ToString());
if (!string.IsNullOrEmpty(retryThreadID))
{
sameThreadRetry = true;
Expand Down Expand Up @@ -167,7 +167,11 @@ protected void GetEnclaveSessionHelper(EnclaveSessionParameters enclaveSessionPa
retryThreadID = Thread.CurrentThread.ManagedThreadId.ToString();
}

ThreadRetryCache.Set(Thread.CurrentThread.ManagedThreadId.ToString(), retryThreadID, DateTime.UtcNow.AddMinutes(ThreadRetryCacheTimeoutInMinutes));
MemoryCacheEntryOptions options = new MemoryCacheEntryOptions
{
AbsoluteExpirationRelativeToNow = TimeSpan.FromMinutes(ThreadRetryCacheTimeoutInMinutes)
};
ThreadRetryCache.Set<string>(Thread.CurrentThread.ManagedThreadId.ToString(), retryThreadID, options);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Runtime.Caching;
using Microsoft.Extensions.Caching.Memory;
using System.Threading;

namespace Microsoft.Data.SqlClient
{
// Maintains a cache of SqlEnclaveSession instances
internal class EnclaveSessionCache
{
private readonly MemoryCache enclaveMemoryCache = new MemoryCache("EnclaveMemoryCache");
private readonly MemoryCache enclaveMemoryCache = new MemoryCache(new MemoryCacheOptions());
private readonly object enclaveCacheLock = new object();

// Nonce for each message sent by the client to the server to prevent replay attacks by the server,
Expand All @@ -25,7 +25,7 @@ internal class EnclaveSessionCache
internal SqlEnclaveSession GetEnclaveSession(EnclaveSessionParameters enclaveSessionParameters, out long counter)
{
string cacheKey = GenerateCacheKey(enclaveSessionParameters);
SqlEnclaveSession enclaveSession = enclaveMemoryCache[cacheKey] as SqlEnclaveSession;
SqlEnclaveSession enclaveSession = enclaveMemoryCache.Get<SqlEnclaveSession>(cacheKey);
counter = Interlocked.Increment(ref _counter);
return enclaveSession;
}
Expand All @@ -41,8 +41,12 @@ internal void InvalidateSession(EnclaveSessionParameters enclaveSessionParameter

if (enclaveSession != null && enclaveSession.SessionId == enclaveSessionToInvalidate.SessionId)
{
SqlEnclaveSession enclaveSessionRemoved = enclaveMemoryCache.Remove(cacheKey) as SqlEnclaveSession;
if (enclaveSessionRemoved == null)
enclaveMemoryCache.TryGetValue<SqlEnclaveSession>(cacheKey, out SqlEnclaveSession enclaveSessionToRemove);
if (enclaveSessionToRemove != null)
{
enclaveMemoryCache.Remove(cacheKey);
}
else
{
throw new InvalidOperationException(Strings.EnclaveSessionInvalidationFailed);
}
Expand All @@ -58,7 +62,11 @@ internal SqlEnclaveSession CreateSession(EnclaveSessionParameters enclaveSession
lock (enclaveCacheLock)
{
enclaveSession = new SqlEnclaveSession(sharedSecret, sessionId);
enclaveMemoryCache.Add(cacheKey, enclaveSession, DateTime.UtcNow.AddHours(enclaveCacheTimeOutInHours));
MemoryCacheEntryOptions options = new MemoryCacheEntryOptions
{
AbsoluteExpirationRelativeToNow = TimeSpan.FromHours(enclaveCacheTimeOutInHours)
};
enclaveMemoryCache.Set<SqlEnclaveSession>(cacheKey, enclaveSession, options);
counter = Interlocked.Increment(ref _counter);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Runtime.Caching;
using System.Text;
using System.Threading;
using Microsoft.Extensions.Caching.Memory;

namespace Microsoft.Data.SqlClient
{
Expand Down Expand Up @@ -35,7 +35,7 @@ internal class ColumnMasterKeyMetadataSignatureVerificationCache

private ColumnMasterKeyMetadataSignatureVerificationCache()
{
_cache = new MemoryCache(_className);
_cache = new MemoryCache(new MemoryCacheOptions());
_inTrim = 0;
}

Expand All @@ -46,17 +46,15 @@ private ColumnMasterKeyMetadataSignatureVerificationCache()
/// <param name="masterKeyPath">Key Path for CMK</param>
/// <param name="allowEnclaveComputations">boolean indicating whether the key can be sent to enclave</param>
/// <param name="signature">Signature for the CMK metadata</param>
/// <returns>null if the data is not found in cache otherwise returns true/false indicating signature verification success/failure</returns>
internal bool? GetSignatureVerificationResult(string keyStoreName, string masterKeyPath, bool allowEnclaveComputations, byte[] signature)
internal bool GetSignatureVerificationResult(string keyStoreName, string masterKeyPath, bool allowEnclaveComputations, byte[] signature)
{

ValidateStringArgumentNotNullOrEmpty(masterKeyPath, _masterkeypathArgumentName, _getSignatureVerificationResultMethodName);
ValidateStringArgumentNotNullOrEmpty(keyStoreName, _keyStoreNameArgumentName, _getSignatureVerificationResultMethodName);
ValidateSignatureNotNullOrEmpty(signature, _getSignatureVerificationResultMethodName);

string cacheLookupKey = GetCacheLookupKey(masterKeyPath, allowEnclaveComputations, signature, keyStoreName);

return _cache.Get(cacheLookupKey) as bool?;
return _cache.TryGetValue<bool>(cacheLookupKey, out bool value);
}

/// <summary>
Expand All @@ -69,7 +67,6 @@ private ColumnMasterKeyMetadataSignatureVerificationCache()
/// <param name="result">result indicating signature verification success/failure</param>
internal void AddSignatureVerificationResult(string keyStoreName, string masterKeyPath, bool allowEnclaveComputations, byte[] signature, bool result)
{

ValidateStringArgumentNotNullOrEmpty(masterKeyPath, _masterkeypathArgumentName, _addSignatureVerificationResultMethodName);
ValidateStringArgumentNotNullOrEmpty(keyStoreName, _keyStoreNameArgumentName, _addSignatureVerificationResultMethodName);
ValidateSignatureNotNullOrEmpty(signature, _addSignatureVerificationResultMethodName);
Expand All @@ -79,7 +76,11 @@ internal void AddSignatureVerificationResult(string keyStoreName, string masterK
TrimCacheIfNeeded();

// By default evict after 10 days.
_cache.Set(cacheLookupKey, result, DateTimeOffset.UtcNow.AddDays(10));
MemoryCacheEntryOptions options = new MemoryCacheEntryOptions
{
AbsoluteExpirationRelativeToNow = TimeSpan.FromDays(10)
};
_cache.Set<bool>(cacheLookupKey, result, options);
}

private void ValidateSignatureNotNullOrEmpty(byte[] signature, string methodName)
Expand Down Expand Up @@ -115,15 +116,17 @@ private void ValidateStringArgumentNotNullOrEmpty(string stringArgValue, string
private void TrimCacheIfNeeded()
{
// If the size of the cache exceeds the threshold, set that we are in trimming and trim the cache accordingly.
long currentCacheSize = _cache.GetCount();
long currentCacheSize = _cache.Count;
if ((currentCacheSize > _cacheSize + _cacheTrimThreshold) && (0 == Interlocked.CompareExchange(ref _inTrim, 1, 0)))
{
try
{
_cache.Trim((int)(((double)(currentCacheSize - _cacheSize) / (double)currentCacheSize) * 100));
// Example: 2301 - 2000 = 301; 301 / 2301 = 0.1308 * 100 = 13% compacting
_cache.Compact((((double)(currentCacheSize - _cacheSize) / (double)currentCacheSize) * 100));
}
finally
{
// Reset _inTrim flag
Interlocked.CompareExchange(ref _inTrim, 0, 1);
}
}
Expand Down
Loading
Loading