Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix | Fix driver to not send expired token and refresh token first before sending it. #2273

Merged
merged 15 commits into from
Jan 24, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2254,6 +2254,13 @@ internal void OnFedAuthInfo(SqlFedAuthInfo fedAuthInfo)
{
// GetFedAuthToken should have updated _newDbConnectionPoolAuthenticationContext.
Debug.Assert(_newDbConnectionPoolAuthenticationContext != null, "_newDbConnectionPoolAuthenticationContext should not be null.");

if (_newDbConnectionPoolAuthenticationContext != null)
{
// Try adding this new _newDbConnectionPoolAuthenticationContext to the _dbConnectionPool's AuthenticationContextKeys if it is not in there yet.
arellegue marked this conversation as resolved.
Show resolved Hide resolved
// The DbConnectionPoolAuthenticationContextKeys collection is used to refresh a cached token just before it expires within 10 minutes.
_dbConnectionPool.AuthenticationContexts.TryAdd(new DbConnectionPoolAuthenticationContextKey(fedAuthInfo.stsurl, fedAuthInfo.spn), _newDbConnectionPoolAuthenticationContext);
arellegue marked this conversation as resolved.
Show resolved Hide resolved
arellegue marked this conversation as resolved.
Show resolved Hide resolved
}
}
}
else if (!attemptRefreshTokenLocked)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2680,6 +2680,13 @@ internal void OnFedAuthInfo(SqlFedAuthInfo fedAuthInfo)
{
// GetFedAuthToken should have updated _newDbConnectionPoolAuthenticationContext.
Debug.Assert(_newDbConnectionPoolAuthenticationContext != null, "_newDbConnectionPoolAuthenticationContext should not be null.");

if (_newDbConnectionPoolAuthenticationContext != null)
{
// Try adding this new _newDbConnectionPoolAuthenticationContext to the _dbConnectionPool's AuthenticationContextKeys if it is not in there yet.
arellegue marked this conversation as resolved.
Show resolved Hide resolved
// The DbConnectionPoolAuthenticationContextKeys collection is used to refresh a cached token just before it expires within 10 minutes.
_dbConnectionPool.AuthenticationContexts.TryAdd(new DbConnectionPoolAuthenticationContextKey(fedAuthInfo.stsurl, fedAuthInfo.spn), _newDbConnectionPoolAuthenticationContext);
arellegue marked this conversation as resolved.
Show resolved Hide resolved
}
}
}
else if (!attemptRefreshTokenLocked)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,7 @@
<Compile Include="DataCommon\ProxyServer.cs" />
<Compile Include="DataCommon\SqlClientCustomTokenCredential.cs" />
<Compile Include="DataCommon\SystemDataResourceManager.cs" />
<Compile Include="SQL\AADFedAuthTokenRefreshTest\AADFedAuthTokenRefreshTest.cs" />
arellegue marked this conversation as resolved.
Show resolved Hide resolved
<Compile Include="SQL\Common\AsyncDebugScope.cs" />
<Compile Include="SQL\Common\ConnectionPoolWrapper.cs" />
<Compile Include="SQL\Common\InternalConnectionWrapper.cs" />
Expand Down Expand Up @@ -298,30 +299,19 @@
<ProjectReference Include="$(TestsPath)tools\TDS\TDS.EndPoint\TDS.EndPoint.csproj" />
<ProjectReference Include="$(TestsPath)tools\TDS\TDS.Servers\TDS.Servers.csproj" />
<ProjectReference Include="$(TestsPath)tools\TDS\TDS\TDS.csproj" />
<ProjectReference
Include="$(TestsPath)tools\Microsoft.Data.SqlClient.TestUtilities\Microsoft.Data.SqlClient.TestUtilities.csproj" />
<ProjectReference Condition="'$(TargetGroup)'=='netcoreapp' AND $(ReferenceType)=='Project'"
Include="$(NetCoreSource)src\Microsoft.Data.SqlClient.csproj" />
<ProjectReference Condition="'$(TargetGroup)'=='netfx' AND $(ReferenceType)=='Project'"
Include="$(NetFxSource)src\Microsoft.Data.SqlClient.csproj" />
<ProjectReference Condition="$(ReferenceType.Contains('NetStandard'))"
Include="$(TestsPath)NSLibrary\Microsoft.Data.SqlClient.NSLibrary.csproj" />
<ProjectReference Condition="!$(ReferenceType.Contains('Package'))"
Include="$(SqlServerSource)Microsoft.SqlServer.Server.csproj" />
<PackageReference Condition="$(ReferenceType.Contains('Package'))"
Include="Microsoft.Data.SqlClient" Version="$(TestMicrosoftDataSqlClientVersion)" />
<ProjectReference
Include="$(TestsPath)CustomConfigurableRetryLogic\CustomRetryLogicProvider.csproj" />
<ProjectReference Include="$(TestsPath)tools\Microsoft.Data.SqlClient.TestUtilities\Microsoft.Data.SqlClient.TestUtilities.csproj" />
<ProjectReference Condition="'$(TargetGroup)'=='netcoreapp' AND $(ReferenceType)=='Project'" Include="$(NetCoreSource)src\Microsoft.Data.SqlClient.csproj" />
<ProjectReference Condition="'$(TargetGroup)'=='netfx' AND $(ReferenceType)=='Project'" Include="$(NetFxSource)src\Microsoft.Data.SqlClient.csproj" />
<ProjectReference Condition="$(ReferenceType.Contains('NetStandard'))" Include="$(TestsPath)NSLibrary\Microsoft.Data.SqlClient.NSLibrary.csproj" />
<ProjectReference Condition="!$(ReferenceType.Contains('Package'))" Include="$(SqlServerSource)Microsoft.SqlServer.Server.csproj" />
<PackageReference Condition="$(ReferenceType.Contains('Package'))" Include="Microsoft.Data.SqlClient" Version="$(TestMicrosoftDataSqlClientVersion)" />
<ProjectReference Include="$(TestsPath)CustomConfigurableRetryLogic\CustomRetryLogicProvider.csproj" />
</ItemGroup>
<!-- XUnit and XUnit extensions -->
<ItemGroup>
<PackageReference Condition="$(TargetGroup) == 'netfx'"
Include="System.Runtime.InteropServices.RuntimeInformation"
Version="$(SystemRuntimeInteropServicesRuntimeInformationVersion)" />
<PackageReference Condition="$(TargetGroup) == 'netfx'" Include="System.Runtime.InteropServices.RuntimeInformation" Version="$(SystemRuntimeInteropServicesRuntimeInformationVersion)" />
<PackageReference Include="xunit" Version="$(XunitVersion)" />
<PackageReference Include="Microsoft.NETFramework.ReferenceAssemblies"
Version="$(MicrosoftNETFrameworkReferenceAssembliesVersion)"
Condition="'$(TargetGroup)' == 'netfx'">
<PackageReference Include="Microsoft.NETFramework.ReferenceAssemblies" Version="$(MicrosoftNETFrameworkReferenceAssembliesVersion)" Condition="'$(TargetGroup)' == 'netfx'">
arellegue marked this conversation as resolved.
Show resolved Hide resolved
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
<PrivateAssets>all</PrivateAssets>
</PackageReference>
Expand All @@ -334,8 +324,7 @@
<PrivateAssets>all</PrivateAssets>
</PackageReference>
<PackageReference Include="xunit.runner.utility" Version="$(XunitVersion)" />
<PackageReference Include="Microsoft.DotNet.XUnitExtensions"
Version="$(MicrosoftDotNetXUnitExtensionsVersion)" />
<PackageReference Include="Microsoft.DotNet.XUnitExtensions" Version="$(MicrosoftDotNetXUnitExtensionsVersion)" />
</ItemGroup>
<ItemGroup>
<PackageReference Include="Newtonsoft.Json" Version="$(NewtonsoftJsonVersion)" />
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
using System;
DavoudEshtehari marked this conversation as resolved.
Show resolved Hide resolved
using System.Collections;
using System.Linq;
using System.Reflection;
using System.Security.Cryptography;
using System.Text;
using Xunit;
using Xunit.Abstractions;

namespace Microsoft.Data.SqlClient.ManualTesting.Tests
{
public class AADFedAuthTokenRefreshTest
{
private readonly ITestOutputHelper _testOutputHelper;

public AADFedAuthTokenRefreshTest(ITestOutputHelper testOutputHelper)
{
_testOutputHelper = testOutputHelper;
}

[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.IsAADPasswordConnStrSetup))]
public void FedAuthTokenRefreshTest()
{
// ------------------ Use settings below for local environment testing ------------------------
arellegue marked this conversation as resolved.
Show resolved Hide resolved
//SqlConnectionStringBuilder builder = new(DataTestUtility.AADPasswordConnectionString);
//string dataSourceStr = builder.DataSource;

//// set user id and password from AADPasswordConnectionString
//string user = builder.UserID;
//string password = builder.Password;

//// Set Environment variables used for ActiveDirectoryDefault authentication type
//Environment.SetEnvironmentVariable("AZURE_USERNAME", $"{user}");
//Environment.SetEnvironmentVariable("AZURE_PASSWORD", $"{password}");

//string userEnvVar = Environment.GetEnvironmentVariable("AZURE_USERNAME");
//string passwordEnvVar = Environment.GetEnvironmentVariable("AZURE_PASSWORD");
//Assert.True($"{user}" == userEnvVar, @"AZURE_USERNAME environment variable must be set");
//Assert.True($"{password}" == passwordEnvVar, @"AZURE_PASSWORD environment variable must be set");

//// Local environment connection string
//string connStr = $"Server={dataSourceStr};Persist Security Info=False;User ID={user};MultipleActiveResultSets=False;Encrypt=True;TrustServerCertificate=False;Authentication=ActiveDirectoryDefault;Timeout=90";

// ------------------ End of local environment settings ----------------------------------------------------

// ------------------ Pipeline environment setting ---------------------------------------------------------
// Use this connection string when running in a pipeline
string connStr = DataTestUtility.AADPasswordConnectionString;
// -------------------End of Pipeline environment setting --------------------------------------------------

// Create a new connection object and open it
SqlConnection connection = new SqlConnection(connStr);
arellegue marked this conversation as resolved.
Show resolved Hide resolved
connection.Open();

// Set the token expiry to expire in 1 minute from now to force token refresh
string tokenHash1 = "";
DateTime? oldExpiry = GetOrSetTokenExpiryDateTime(connection, true, out tokenHash1);
Assert.True(oldExpiry != null, "Failed to make token expiry to expire in one minute.");

// Display old expiry in local time which should be in 1 minute from now
DateTime oldLocalExpiryTime = TimeZoneInfo.ConvertTimeFromUtc((DateTime)oldExpiry, TimeZoneInfo.Local);
LogInfo($"Token: {tokenHash1} Old Expiry: {oldLocalExpiryTime}");
TimeSpan timeDiff = oldLocalExpiryTime - DateTime.Now;
Assert.True(timeDiff.TotalSeconds <= 60, "Failed to set expiry after 1 minute from current time.");

// Check if connection is alive
string result = "";
SqlCommand cmd = connection.CreateCommand();
cmd.CommandText = "select @@version";
result = $"{cmd.ExecuteScalar()}";
Assert.True(result != string.Empty, "The connection's command must return a value");

// The new connection will use the same FedAuthToken but will refresh it first as it will expire in 1 minute.
SqlConnection connection2 = new SqlConnection(connStr);
connection2.Open();

// Check again if connection is alive
cmd = connection2.CreateCommand();
cmd.CommandText = "select 1";
result = $"{cmd.ExecuteScalar()}";
Assert.True(result != string.Empty, "The connection's command must return a value after a token refresh.");

// Get the refreshed token expiry
string tokenHash2 = "";
DateTime? newExpiry = GetOrSetTokenExpiryDateTime(connection2, false, out tokenHash2);
// Display new expiry in local time
DateTime newLocalExpiryTime = TimeZoneInfo.ConvertTimeFromUtc((DateTime)newExpiry, TimeZoneInfo.Local);
LogInfo($"Token: {tokenHash2} New Expiry: {newLocalExpiryTime}");

Assert.True(tokenHash1 == tokenHash2, "The token's hash before and after token refresh must be identical.");
Assert.True(newLocalExpiryTime > oldLocalExpiryTime, "The refreshed token must have a new or later expiry time.");

connection.Close();
connection2.Close();
}

private void LogInfo(string message)
{
//Console.WriteLine(message);
_testOutputHelper.WriteLine(message);
}

private DateTime? GetOrSetTokenExpiryDateTime(SqlConnection connection, bool setExpiry, out string tokenHash)
arellegue marked this conversation as resolved.
Show resolved Hide resolved
{
try
{
// Get the inner connection
object innerConnectionObj = connection.GetType().GetProperty("InnerConnection", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(connection);

// Get the db connection pool
object poolObj = innerConnectionObj.GetType().GetProperty("Pool", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(innerConnectionObj);

// Get the Authentication Contexts
IEnumerable authContextCollection = (IEnumerable)poolObj.GetType().GetProperty("AuthenticationContexts", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(poolObj, null);

// Get the first authentication context
object authContextObj = authContextCollection.Cast<object>().FirstOrDefault();

// Get the token object from the authentication context
object tokenObj = authContextObj.GetType().GetProperty("Value").GetValue(authContextObj, null);

DateTime expiry = (DateTime)tokenObj.GetType().GetProperty("ExpirationTime", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(tokenObj, null);

if (setExpiry)
{
// Forcing 1 minute expiry to trigger token refresh.
expiry = DateTime.UtcNow.AddMinutes(1);

// Apply the expiry to the token object
FieldInfo expirationTime = tokenObj.GetType().GetField("_expirationTime", BindingFlags.NonPublic | BindingFlags.Instance);
expirationTime.SetValue(tokenObj, expiry);
}

byte[] tokenBytes = (byte[])tokenObj.GetType().GetProperty("AccessToken", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(tokenObj, null);

tokenHash = GetTokenHash(tokenBytes);

return expiry;
}
catch (Exception)
{
tokenHash = "";
return null;
}
}

private string GetTokenHash(byte[] tokenBytes)
arellegue marked this conversation as resolved.
Show resolved Hide resolved
{
string token = Encoding.Unicode.GetString(tokenBytes);
var bytesInUtf8 = Encoding.UTF8.GetBytes(token);
using (var sha256 = SHA256.Create())
{
var hash = sha256.ComputeHash(bytesInUtf8);
return Convert.ToBase64String(hash);
}
}
}
}
Loading