Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix | AE enclave retry logic not working for async queries #1988

Merged
merged 6 commits into from
Apr 18, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ the enclave attestation protocol as well as the logic for creating and caching e
<param name="counter">A counter that the enclave provider is expected to increment each time SqlClient retrieves the session from the cache. The purpose of this field is to prevent replay attacks.</param>
<param name="customData">A set of extra data needed for attestating the enclave.</param>
David-Engel marked this conversation as resolved.
Show resolved Hide resolved
<param name="customDataLength">The length of the extra data needed for attestating the enclave.</param>
<param name="isRetry">Indicates if this is a retry from a failed call.</param>
<summary>When overridden in a derived class, looks up an existing enclave session information in the enclave session cache. If the enclave provider doesn't implement enclave session caching, this method is expected to return <see langword="null" /> in the <paramref name="sqlEnclaveSession" /> parameter.
</summary>
<remarks>To be added.</remarks>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ namespace Microsoft.Data.SqlClient
internal abstract partial class SqlColumnEncryptionEnclaveProvider
{
/// <include file='../../../../../../../doc/snippets/Microsoft.Data.SqlClient/SqlColumnEncryptionEnclaveProvider.xml' path='docs/members[@name="SqlColumnEncryptionEnclaveProvider"]/GetEnclaveSession/*'/>
internal abstract void GetEnclaveSession(EnclaveSessionParameters enclaveSessionParameters, bool generateCustomData, out SqlEnclaveSession sqlEnclaveSession, out long counter, out byte[] customData, out int customDataLength);
internal abstract void GetEnclaveSession(EnclaveSessionParameters enclaveSessionParameters, bool generateCustomData, out SqlEnclaveSession sqlEnclaveSession, out long counter, out byte[] customData, out int customDataLength, bool isRetry);
David-Engel marked this conversation as resolved.
Show resolved Hide resolved

/// <include file='../../../../../../../doc/snippets/Microsoft.Data.SqlClient/SqlColumnEncryptionEnclaveProvider.xml' path='docs/members[@name="SqlColumnEncryptionEnclaveProvider"]/GetAttestationParameters/*'/>
internal abstract SqlEnclaveAttestationParameters GetAttestationParameters(string attestationUrl, byte[] customData, int customDataLength);
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ internal abstract class SqlColumnEncryptionEnclaveProvider
{

/// <include file='../../../../../../../doc/snippets/Microsoft.Data.SqlClient/SqlColumnEncryptionEnclaveProvider.xml' path='docs/members[@name="SqlColumnEncryptionEnclaveProvider"]/GetEnclaveSession/*'/>
internal abstract void GetEnclaveSession(EnclaveSessionParameters enclaveSessionParameters, bool generateCustomData, out SqlEnclaveSession sqlEnclaveSession, out long counter, out byte[] customData, out int customDataLength);
internal abstract void GetEnclaveSession(EnclaveSessionParameters enclaveSessionParameters, bool generateCustomData, out SqlEnclaveSession sqlEnclaveSession, out long counter, out byte[] customData, out int customDataLength, bool isRetry);

/// <include file='../../../../../../../doc/snippets/Microsoft.Data.SqlClient/SqlColumnEncryptionEnclaveProvider.xml' path='docs/members[@name="SqlColumnEncryptionEnclaveProvider"]/GetAttestationParameters/*'/>
internal abstract SqlEnclaveAttestationParameters GetAttestationParameters(string attestationUrl, byte[] customData, int customDataLength);
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ internal class AzureAttestationEnclaveProvider : EnclaveProviderBase
#region Internal methods
// When overridden in a derived class, looks up an existing enclave session information in the enclave session cache.
// If the enclave provider doesn't implement enclave session caching, this method is expected to return null in the sqlEnclaveSession parameter.
internal override void GetEnclaveSession(EnclaveSessionParameters enclaveSessionParameters, bool generateCustomData, out SqlEnclaveSession sqlEnclaveSession, out long counter, out byte[] customData, out int customDataLength)
internal override void GetEnclaveSession(EnclaveSessionParameters enclaveSessionParameters, bool generateCustomData, out SqlEnclaveSession sqlEnclaveSession, out long counter, out byte[] customData, out int customDataLength, bool isRetry)
{
GetEnclaveSessionHelper(enclaveSessionParameters, generateCustomData, out sqlEnclaveSession, out counter, out customData, out customDataLength);
GetEnclaveSessionHelper(enclaveSessionParameters, generateCustomData, out sqlEnclaveSession, out counter, out customData, out customDataLength, isRetry);
}

// Gets the information that SqlClient subsequently uses to initiate the process of attesting the enclave and to establish a secure session with the enclave.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@ internal sealed partial class EnclaveDelegate
/// <param name="attestationParameters">attestation parameters</param>
/// <param name="customData">A set of extra data needed for attestating the enclave.</param>
/// <param name="customDataLength">The length of the extra data needed for attestating the enclave.</param>
/// <param name="isRetry">Indicates if this is a retry from a failed call.</param>
internal void CreateEnclaveSession(SqlConnectionAttestationProtocol attestationProtocol, string enclaveType, EnclaveSessionParameters enclaveSessionParameters,
byte[] attestationInfo, SqlEnclaveAttestationParameters attestationParameters, byte[] customData, int customDataLength)
byte[] attestationInfo, SqlEnclaveAttestationParameters attestationParameters, byte[] customData, int customDataLength, bool isRetry)
{
lock (_lock)
{
Expand All @@ -35,7 +36,8 @@ internal void CreateEnclaveSession(SqlConnectionAttestationProtocol attestationP
sqlEnclaveSession: out SqlEnclaveSession sqlEnclaveSession,
counter: out _,
customData: out _,
customDataLength: out _
customDataLength: out _,
isRetry: isRetry
);

if (sqlEnclaveSession != null)
Expand All @@ -60,15 +62,15 @@ internal void CreateEnclaveSession(SqlConnectionAttestationProtocol attestationP
}
}

internal void GetEnclaveSession(SqlConnectionAttestationProtocol attestationProtocol, string enclaveType, EnclaveSessionParameters enclaveSessionParameters, bool generateCustomData, out SqlEnclaveSession sqlEnclaveSession, out byte[] customData, out int customDataLength)
internal void GetEnclaveSession(SqlConnectionAttestationProtocol attestationProtocol, string enclaveType, EnclaveSessionParameters enclaveSessionParameters, bool generateCustomData, out SqlEnclaveSession sqlEnclaveSession, out byte[] customData, out int customDataLength, bool isRetry)
{
GetEnclaveSession(attestationProtocol, enclaveType, enclaveSessionParameters, generateCustomData, out sqlEnclaveSession, out _, out customData, out customDataLength, throwIfNull: false);
GetEnclaveSession(attestationProtocol, enclaveType, enclaveSessionParameters, generateCustomData, out sqlEnclaveSession, out _, out customData, out customDataLength, throwIfNull: false, isRetry);
}

private void GetEnclaveSession(SqlConnectionAttestationProtocol attestationProtocol, string enclaveType, EnclaveSessionParameters enclaveSessionParameters, bool generateCustomData, out SqlEnclaveSession sqlEnclaveSession, out long counter, out byte[] customData, out int customDataLength, bool throwIfNull)
private void GetEnclaveSession(SqlConnectionAttestationProtocol attestationProtocol, string enclaveType, EnclaveSessionParameters enclaveSessionParameters, bool generateCustomData, out SqlEnclaveSession sqlEnclaveSession, out long counter, out byte[] customData, out int customDataLength, bool throwIfNull, bool isRetry)
{
SqlColumnEncryptionEnclaveProvider sqlColumnEncryptionEnclaveProvider = GetEnclaveProvider(attestationProtocol, enclaveType);
sqlColumnEncryptionEnclaveProvider.GetEnclaveSession(enclaveSessionParameters, generateCustomData, out sqlEnclaveSession, out counter, out customData, out customDataLength);
sqlColumnEncryptionEnclaveProvider.GetEnclaveSession(enclaveSessionParameters, generateCustomData, out sqlEnclaveSession, out counter, out customData, out customDataLength, isRetry);

if (throwIfNull && sqlEnclaveSession == null)
{
Expand Down Expand Up @@ -147,7 +149,8 @@ internal EnclavePackage GenerateEnclavePackage(SqlConnectionAttestationProtocol
counter: out counter,
customData: out _,
customDataLength: out _,
throwIfNull: true
throwIfNull: true,
isRetry: false
);
}
catch (Exception e)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,14 @@ internal byte[] GetSerializedAttestationParameters(
/// <param name="attestationParameters">attestation parameters</param>
/// <param name="customData">A set of extra data needed for attestating the enclave.</param>
/// <param name="customDataLength">The length of the extra data needed for attestating the enclave.</param>
/// <param name="isRetry">Indicates if this is a retry from a failed call.</param>
internal void CreateEnclaveSession(SqlConnectionAttestationProtocol attestationProtocol, string enclaveType, EnclaveSessionParameters enclaveSessionParameters,
byte[] attestationInfo, SqlEnclaveAttestationParameters attestationParameters, byte[] customData, int customDataLength)
byte[] attestationInfo, SqlEnclaveAttestationParameters attestationParameters, byte[] customData, int customDataLength, bool isRetry)
{
throw new PlatformNotSupportedException();
}

internal void GetEnclaveSession(SqlConnectionAttestationProtocol attestationProtocol, string enclaveType, EnclaveSessionParameters enclaveSessionParameters, bool generateCustomData, out SqlEnclaveSession sqlEnclaveSession, out byte[] customData, out int customDataLength)
internal void GetEnclaveSession(SqlConnectionAttestationProtocol attestationProtocol, string enclaveType, EnclaveSessionParameters enclaveSessionParameters, bool generateCustomData, out SqlEnclaveSession sqlEnclaveSession, out byte[] customData, out int customDataLength, bool isRetry)
{
throw new PlatformNotSupportedException();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ internal abstract class EnclaveProviderBase : SqlColumnEncryptionEnclaveProvider

#region protected methods
// Helper method to get the enclave session from the cache if present
protected void GetEnclaveSessionHelper(EnclaveSessionParameters enclaveSessionParameters, bool shouldGenerateNonce, out SqlEnclaveSession sqlEnclaveSession, out long counter, out byte[] customData, out int customDataLength)
protected void GetEnclaveSessionHelper(EnclaveSessionParameters enclaveSessionParameters, bool shouldGenerateNonce, out SqlEnclaveSession sqlEnclaveSession, out long counter, out byte[] customData, out int customDataLength, bool isRetry)
{
customData = null;
customDataLength = 0;
Expand All @@ -107,7 +107,7 @@ protected void GetEnclaveSessionHelper(EnclaveSessionParameters enclaveSessionPa
{
sameThreadRetry = true;
}
else
else if (!isRetry)
{
// We are explicitly not signalling the event here, as we want to hold the event till driver calls CreateEnclaveSession
// If we signal the event now, then multiple thread end up calling GetAttestationParameters which triggers the attestation workflow.
Expand All @@ -124,7 +124,7 @@ protected void GetEnclaveSessionHelper(EnclaveSessionParameters enclaveSessionPa

// In case of multi-threaded application, first thread will set the event and all the subsequent threads will wait here either until the enclave
// session is created or timeout happens.
if (sessionCacheLockTaken || sameThreadRetry)
if (sessionCacheLockTaken || sameThreadRetry || isRetry)
{
// While the current thread is waiting for event to be signaled and in the meanwhile we already completed the attestation on different thread
// then we need to signal the event here
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ internal class NoneAttestationEnclaveProvider : EnclaveProviderBase

// When overridden in a derived class, looks up an existing enclave session information in the enclave session cache.
// If the enclave provider doesn't implement enclave session caching, this method is expected to return null in the sqlEnclaveSession parameter.
internal override void GetEnclaveSession(EnclaveSessionParameters enclaveSessionParameters, bool generateCustomData, out SqlEnclaveSession sqlEnclaveSession, out long counter, out byte[] customData, out int customDataLength)
internal override void GetEnclaveSession(EnclaveSessionParameters enclaveSessionParameters, bool generateCustomData, out SqlEnclaveSession sqlEnclaveSession, out long counter, out byte[] customData, out int customDataLength, bool isRetry)
{
GetEnclaveSessionHelper(enclaveSessionParameters, false, out sqlEnclaveSession, out counter, out customData, out customDataLength);
GetEnclaveSessionHelper(enclaveSessionParameters, false, out sqlEnclaveSession, out counter, out customData, out customDataLength, isRetry);
}

// Gets the information that SqlClient subsequently uses to initiate the process of attesting the enclave and to establish a secure session with the enclave.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,9 @@ internal abstract class VirtualizationBasedSecurityEnclaveProviderBase : Enclave

// When overridden in a derived class, looks up an existing enclave session information in the enclave session cache.
// If the enclave provider doesn't implement enclave session caching, this method is expected to return null in the sqlEnclaveSession parameter.
internal override void GetEnclaveSession(EnclaveSessionParameters enclaveSessionParameters, bool generateCustomData, out SqlEnclaveSession sqlEnclaveSession, out long counter, out byte[] customData, out int customDataLength)
internal override void GetEnclaveSession(EnclaveSessionParameters enclaveSessionParameters, bool generateCustomData, out SqlEnclaveSession sqlEnclaveSession, out long counter, out byte[] customData, out int customDataLength, bool isRetry)
{
GetEnclaveSessionHelper(enclaveSessionParameters, false, out sqlEnclaveSession, out counter, out customData, out customDataLength);
GetEnclaveSessionHelper(enclaveSessionParameters, false, out sqlEnclaveSession, out counter, out customData, out customDataLength, isRetry);
}

// Gets the information that SqlClient subsequently uses to initiate the process of attesting the enclave and to establish a secure session with the enclave.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2336,6 +2336,92 @@ public void TestRetryWhenAEParameterMetadataCacheIsStale(string connectionString
cmd.ExecuteNonQuery();
}

[ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringSetupForAE), nameof(DataTestUtility.EnclaveEnabled))]
[ClassData(typeof(AEConnectionStringProvider))]
public void TestRetryWhenAEEnclaveCacheIsStale(string connectionString)
David-Engel marked this conversation as resolved.
Show resolved Hide resolved
{
CleanUpTable(connectionString, _tableName);

const int customerId = 50;
IList<object> values = GetValues(dataHint: customerId);
InsertRows(tableName: _tableName, numberofRows: 1, values: values, connection: connectionString);

ApiTestTable table = _fixture.ApiTestTable as ApiTestTable;
string enclaveSelectQuery = $@"SELECT CustomerId, FirstName, LastName FROM [{_tableName}] WHERE CustomerId > @CustomerId";
string alterCekQueryFormatString = "ALTER TABLE [{0}] " +
"ALTER COLUMN [CustomerId] [int] " +
"ENCRYPTED WITH (COLUMN_ENCRYPTION_KEY = [{1}], " +
"ENCRYPTION_TYPE = Randomized, " +
"ALGORITHM = 'AEAD_AES_256_CBC_HMAC_SHA_256'); " +
"ALTER DATABASE SCOPED CONFIGURATION CLEAR PROCEDURE_CACHE;";

using SqlConnection sqlConnection = new(connectionString);
sqlConnection.Open();

// change the CEK and encryption type to randomized for the CustomerId column to ensure enclaves are used
using SqlCommand cmd = new SqlCommand(
string.Format(alterCekQueryFormatString, _tableName, table.columnEncryptionKey2.Name),
sqlConnection,
null,
SqlCommandColumnEncryptionSetting.Enabled);
cmd.ExecuteNonQuery();

// execute the select query to create the cache entry
cmd.CommandText = enclaveSelectQuery;
cmd.Parameters.AddWithValue("@CustomerId", 0);
using (SqlDataReader reader = cmd.ExecuteReader())
{
while (reader.Read())
{
Assert.Equal(customerId, (int)reader[0]);
}
reader.Close();
}

CommandHelper.InvalidateEnclaveSession(cmd);

// Execute again to exercise the session retry logic
using (SqlDataReader reader = cmd.ExecuteReader())
{
while (reader.Read())
{
Assert.Equal(customerId, (int)reader[0]);
}
reader.Close();
}

CommandHelper.InvalidateEnclaveSession(cmd);

// Execute again to exercise the async session retry logic
Task readAsyncTask = ReadAsync(cmd, values, CommandBehavior.Default);
readAsyncTask.Wait();
David-Engel marked this conversation as resolved.
Show resolved Hide resolved

#if DEBUG
CommandHelper.ForceThrowDuringGenerateEnclavePackage(cmd);

// Execute again to exercise the session retry logic
using (SqlDataReader reader = cmd.ExecuteReader())
{
while (reader.Read())
{
Assert.Equal(customerId, (int)reader[0]);
}
reader.Close();
}

CommandHelper.ForceThrowDuringGenerateEnclavePackage(cmd);

// Execute again to exercise the async session retry logic
Task readAsyncTask2 = ReadAsync(cmd, values, CommandBehavior.Default);
readAsyncTask2.Wait();
#endif

// revert the CEK change to the CustomerId column
cmd.Parameters.Clear();
cmd.CommandText = string.Format(alterCekQueryFormatString, _tableName, table.columnEncryptionKey1.Name);
cmd.ExecuteNonQuery();
}

private void ExecuteQueryThatRequiresCustomKeyStoreProvider(SqlConnection connection)
{
using (SqlCommand command = CreateCommandThatRequiresCustomKeyStoreProvider(connection))
Expand Down
Loading