diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs index 29500ef5b4..5606ebf9c1 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs @@ -6,6 +6,9 @@ using System.Collections.Concurrent; using System.Linq; using System.Security; +using System.Runtime.Caching; +using System.Security.Cryptography; +using System.Text; using System.Threading; using System.Threading.Tasks; using Microsoft.Identity.Client; @@ -23,6 +26,8 @@ public sealed class ActiveDirectoryAuthenticationProvider : SqlAuthenticationPro /// private static ConcurrentDictionary s_pcaMap = new ConcurrentDictionary(); + private static readonly MemoryCache s_accountPwCache = new(nameof(ActiveDirectoryAuthenticationProvider)); + private static readonly int s_accountPwCacheTtlInHours = 2; private static readonly string s_nativeClientRedirectUri = "https://login.microsoftonline.com/common/oauth2/nativeclient"; private static readonly string s_defaultScopeSuffix = "/.default"; private readonly string _type = typeof(ActiveDirectoryAuthenticationProvider).Name; @@ -101,7 +106,9 @@ public override void BeforeUnload(SqlAuthenticationMethod authentication) /// public override Task AcquireTokenAsync(SqlAuthenticationParameters parameters) => Task.Run(async () => { - AuthenticationResult result; + CancellationTokenSource cts = new(); + + AuthenticationResult result = null; string scope = parameters.Resource.EndsWith(s_defaultScopeSuffix) ? parameters.Resource : parameters.Resource + s_defaultScopeSuffix; string[] scopes = new string[] { scope }; @@ -147,69 +154,84 @@ public override Task AcquireTokenAsync(SqlAuthentication if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryIntegrated) { - if (!string.IsNullOrEmpty(parameters.UserId)) - { - result = app.AcquireTokenByIntegratedWindowsAuth(scopes) - .WithCorrelationId(parameters.ConnectionId) - .WithUsername(parameters.UserId) - .ExecuteAsync().Result; - } - else + result = await TryAcquireTokenSilent(app, parameters, scopes, cts).ConfigureAwait(false); + + if (result == null) { - result = app.AcquireTokenByIntegratedWindowsAuth(scopes) - .WithCorrelationId(parameters.ConnectionId) - .ExecuteAsync().Result; + if (!string.IsNullOrEmpty(parameters.UserId)) + { + result = app.AcquireTokenByIntegratedWindowsAuth(scopes) + .WithCorrelationId(parameters.ConnectionId) + .WithUsername(parameters.UserId) + .ExecuteAsync(cancellationToken: cts.Token).Result; + } + else + { + result = app.AcquireTokenByIntegratedWindowsAuth(scopes) + .WithCorrelationId(parameters.ConnectionId) + .ExecuteAsync(cancellationToken: cts.Token).Result; + } + SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Integrated auth mode. Expiry Time: {0}", result.ExpiresOn); } - SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Integrated auth mode. Expiry Time: {0}", result.ExpiresOn); } else if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryPassword) { - SecureString password = new SecureString(); - foreach (char c in parameters.Password) - password.AppendChar(c); - password.MakeReadOnly(); - result = app.AcquireTokenByUsernamePassword(scopes, parameters.UserId, password) - .WithCorrelationId(parameters.ConnectionId) - .ExecuteAsync().Result; - SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Password auth mode. Expiry Time: {0}", result.ExpiresOn); + string pwCacheKey = GetAccountPwCacheKey(parameters); + object previousPw = s_accountPwCache.Get(pwCacheKey); + byte[] currPwHash = GetHash(parameters.Password); + + if (null != previousPw && + previousPw is byte[] previousPwBytes && + // Only get the cached token if the current password hash matches the previously used password hash + currPwHash.SequenceEqual(previousPwBytes)) + { + result = await TryAcquireTokenSilent(app, parameters, scopes, cts).ConfigureAwait(false); + } + + if (result == null) + { + SecureString password = new SecureString(); + foreach (char c in parameters.Password) + password.AppendChar(c); + password.MakeReadOnly(); + result = await app.AcquireTokenByUsernamePassword(scopes, parameters.UserId, password) + .WithCorrelationId(parameters.ConnectionId) + .ExecuteAsync() + .ConfigureAwait(false); + + // We cache the password hash to ensure future connection requests include a validated password + // when we check for a cached MSAL account. Otherwise, a connection request with the same username + // against the same tenant could succeed with an invalid password when we re-use the cached token. + if (!s_accountPwCache.Add(pwCacheKey, GetHash(parameters.Password), DateTime.UtcNow.AddHours(s_accountPwCacheTtlInHours))) + { + s_accountPwCache.Remove(pwCacheKey); + s_accountPwCache.Add(pwCacheKey, GetHash(parameters.Password), DateTime.UtcNow.AddHours(s_accountPwCacheTtlInHours)); + } + + SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Password auth mode. Expiry Time: {0}", result.ExpiresOn); + } } else if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryInteractive || parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow) { - // Fetch available accounts from 'app' instance - System.Collections.Generic.IEnumerable accounts = await app.GetAccountsAsync(); - IAccount account; - if (!string.IsNullOrEmpty(parameters.UserId)) + try { - account = accounts.FirstOrDefault(a => parameters.UserId.Equals(a.Username, System.StringComparison.InvariantCultureIgnoreCase)); + result = await TryAcquireTokenSilent(app, parameters, scopes, cts).ConfigureAwait(false); + SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (silent) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result.ExpiresOn); } - else + catch (MsalUiRequiredException) { - account = accounts.FirstOrDefault(); + // An 'MsalUiRequiredException' is thrown in the case where an interaction is required with the end user of the application, + // for instance, if no refresh token was in the cache, or the user needs to consent, or re-sign-in (for instance if the password expired), + // or the user needs to perform two factor authentication. + result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts, _customWebUI, _deviceCodeFlowCallback); + SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result.ExpiresOn); } - if (null != account) - { - try - { - // If 'account' is available in 'app', we use the same to acquire token silently. - // Read More on API docs: https://docs.microsoft.com/dotnet/api/microsoft.identity.client.clientapplicationbase.acquiretokensilent - result = await app.AcquireTokenSilent(scopes, account).ExecuteAsync(); - SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (silent) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result.ExpiresOn); - } - catch (MsalUiRequiredException) - { - // An 'MsalUiRequiredException' is thrown in the case where an interaction is required with the end user of the application, - // for instance, if no refresh token was in the cache, or the user needs to consent, or re-sign-in (for instance if the password expired), - // or the user needs to perform two factor authentication. - result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod); - SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result.ExpiresOn); - } - } - else + if (result == null) { // If no existing 'account' is found, we request user to sign in interactively. - result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod); + result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts, _customWebUI, _deviceCodeFlowCallback); SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result.ExpiresOn); } } @@ -222,11 +244,58 @@ public override Task AcquireTokenAsync(SqlAuthentication return new SqlAuthenticationToken(result.AccessToken, result.ExpiresOn); }); + private static async Task TryAcquireTokenSilent(IPublicClientApplication app, + SqlAuthenticationParameters parameters, + string[] scopes, + CancellationTokenSource cts) + { + AuthenticationResult result = null; + + // Fetch available accounts from 'app' instance + System.Collections.Generic.IEnumerator accounts = (await app.GetAccountsAsync().ConfigureAwait(false)).GetEnumerator(); + + IAccount account = default; + if (accounts.MoveNext()) + { + if (!string.IsNullOrEmpty(parameters.UserId)) + { + do + { + IAccount currentVal = accounts.Current; + if (string.Compare(parameters.UserId, currentVal.Username, StringComparison.InvariantCultureIgnoreCase) == 0) + { + account = currentVal; + break; + } + } + while (accounts.MoveNext()); + } + else + { + account = accounts.Current; + } + } + + if (null != account) + { + // If 'account' is available in 'app', we use the same to acquire token silently. + // Read More on API docs: https://docs.microsoft.com/dotnet/api/microsoft.identity.client.clientapplicationbase.acquiretokensilent + result = await app.AcquireTokenSilent(scopes, account).ExecuteAsync(cancellationToken: cts.Token).ConfigureAwait(false); + SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (silent) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn); + } + + return result; + } - private async Task AcquireTokenInteractiveDeviceFlowAsync(IPublicClientApplication app, string[] scopes, Guid connectionId, string userId, - SqlAuthenticationMethod authenticationMethod) + private static async Task AcquireTokenInteractiveDeviceFlowAsync(IPublicClientApplication app, + string[] scopes, + Guid connectionId, + string userId, + SqlAuthenticationMethod authenticationMethod, + CancellationTokenSource cts, + ICustomWebUi customWebUI, + Func deviceCodeFlowCallback) { - CancellationTokenSource cts = new CancellationTokenSource(); #if NETCOREAPP /* * On .NET Core, MSAL will start the system browser as a separate process. MSAL does not have control over this browser, @@ -243,11 +312,11 @@ private async Task AcquireTokenInteractiveDeviceFlowAsync( { if (authenticationMethod == SqlAuthenticationMethod.ActiveDirectoryInteractive) { - if (_customWebUI != null) + if (customWebUI != null) { return await app.AcquireTokenInteractive(scopes) .WithCorrelationId(connectionId) - .WithCustomWebUi(_customWebUI) + .WithCustomWebUi(customWebUI) .WithLoginHint(userId) .ExecuteAsync(cts.Token); } @@ -279,7 +348,7 @@ private async Task AcquireTokenInteractiveDeviceFlowAsync( else { AuthenticationResult result = await app.AcquireTokenWithDeviceCode(scopes, - deviceCodeResult => _deviceCodeFlowCallback(deviceCodeResult)).ExecuteAsync(); + deviceCodeResult => deviceCodeFlowCallback(deviceCodeResult)).ExecuteAsync(); return result; } } @@ -329,6 +398,19 @@ private IPublicClientApplication GetPublicClientAppInstance(PublicClientAppKey p return clientApplicationInstance; } + private static string GetAccountPwCacheKey(SqlAuthenticationParameters parameters) + { + return parameters.Authority + "+" + parameters.UserId; + } + + private static byte[] GetHash(string input) + { + byte[] unhashedBytes = Encoding.Unicode.GetBytes(input); + SHA256 sha256 = SHA256.Create(); + byte[] hashedBytes = sha256.ComputeHash(unhashedBytes); + return hashedBytes; + } + private IPublicClientApplication CreateClientAppInstance(PublicClientAppKey publicClientAppKey) { IPublicClientApplication publicClientApplication;