Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Advance column index to avoid double clean. #2825

Merged
merged 7 commits into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6388,16 +6388,36 @@ internal TdsOperationStatus TryReadSqlValue(SqlBuffer value, SqlMetaDataPriv md,
{
if (stateObj is not null)
{
// call to decrypt column keys has failed. The data wont be decrypted.
// Not setting the value to false, forces the driver to look for column value.
// Packet received from Key Vault will throws invalid token header.
if (stateObj.HasPendingData)
// Throwing an exception here circumvents the normal pending data checks and cleanup processes,
// so we need to ensure the appropriate state. Increment the _nextColumnDataToRead index because
// we already read the encrypted column data; Otherwise we'll double count and attempt to drain a
// corresponding number of bytes a second time. We don't want the rest of the pending data to
// interfere with future operations, so we must drain it. Set HasPendingData to false to indicate
// that we successfully drained the data.

// The SqlDataReader also maintains a state called dataReady. We need to set that to false if we've
// drained the data off the connection. Otherwise, a consumer that catches the exception may
// continue to use the reader and will timeout waiting to read data that doesn't exist.

// Order matters here. Must increment column before draining data.
// Update state objects after draining data.



if (stateObj._readerState != null)
{
// Drain the pending data now if setting the HasPendingData to false.
// SqlDataReader.TryCloseInternal can not drain if HasPendingData = false.
DrainData(stateObj);
stateObj._readerState._nextColumnDataToRead++;
}

DrainData(stateObj);

if (stateObj._readerState != null)
{
stateObj._readerState._dataReady = false;
}

stateObj.HasPendingData = false;

}
throw SQL.ColumnDecryptionFailed(columnName, null, e);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections;
using System.Collections.Generic;
using Microsoft.Data.SqlClient.ManualTesting.Tests.AlwaysEncrypted.Setup;
using Microsoft.Data.SqlClient.ManualTesting.Tests.SystemDataInternals;
using Xunit;

namespace Microsoft.Data.SqlClient.ManualTesting.Tests.AlwaysEncrypted
{
public sealed class ColumnDecryptErrorTests : IClassFixture<SQLSetupStrategyAzureKeyVault>, IDisposable
{
private SQLSetupStrategyAzureKeyVault fixture;

private readonly string tableName;

public ColumnDecryptErrorTests(SQLSetupStrategyAzureKeyVault context)
{
fixture = context;
tableName = fixture.ColumnDecryptErrorTestTable.Name;
}

/*
* This test ensures that column decryption errors and connection pooling play nicely together.
* When a decryption error is encountered, we expect the connection to be drained of data and
* properly reset before being returned to the pool. If this doesn't happen, then random bytes
* may be left in the connection's state. These can interfere with the next operation that utilizes
* the connection.
*
* We test that state is properly reset by triggering the same error condition twice. Routing column key discovery
* away from AKV toward a dummy key store achieves this. Each connection pulls from a pool of max
* size one to ensure we are using the same internal connection/socket both times. We expect to
* receive the "Failed to decrypt column" exception twice. If the state were not cleaned properly,
* the second error would be different because the TDS stream would be unintelligible.
*
* Finally, we assert that restoring the connection to AKV allows a successful query.
*/
[ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.IsTargetReadyForAeWithKeyStore), nameof(DataTestUtility.IsAKVSetupAvailable))]
[ClassData(typeof(TestQueries))]
public void TestCleanConnectionAfterDecryptFail(string connString, string selectQuery, int totalColumnsInSelect, string[] types)
{
// Arrange
Assert.False(string.IsNullOrWhiteSpace(selectQuery), "FAILED: select query should not be null or empty.");
Assert.True(totalColumnsInSelect <= 3, "FAILED: totalColumnsInSelect should <= 3.");

using (SqlConnection sqlConnection = new SqlConnection(connString))
{
sqlConnection.Open();

Table.DeleteData(tableName, sqlConnection);

Customer customer = new Customer(
45,
"Microsoft",
"Corporation");

DatabaseHelper.InsertCustomerData(sqlConnection, null, tableName, customer);
}


// Act - Trigger a column decrypt error on the connection
Dictionary<String, SqlColumnEncryptionKeyStoreProvider> keyStoreProviders = new()
{
{ "AZURE_KEY_VAULT", new DummyKeyStoreProvider() }
};

String poolEnabledConnString = new SqlConnectionStringBuilder(connString) { Pooling = true, MaxPoolSize = 1 }.ToString();

using (SqlConnection sqlConnection = new SqlConnection(poolEnabledConnString))
{
sqlConnection.Open();
sqlConnection.RegisterColumnEncryptionKeyStoreProvidersOnConnection(keyStoreProviders);

using SqlCommand sqlCommand = new SqlCommand(string.Format(selectQuery, tableName),
sqlConnection, null, SqlCommandColumnEncryptionSetting.Enabled);

using SqlDataReader sqlDataReader = sqlCommand.ExecuteReader();

Assert.True(sqlDataReader.HasRows, "FAILED: Select statement did not return any rows.");

while (sqlDataReader.Read())
{
var error = Assert.Throws<SqlException>(() => DatabaseHelper.CompareResults(sqlDataReader, types, totalColumnsInSelect));
Assert.Contains("Failed to decrypt column", error.Message);
}
}


// Assert
using (SqlConnection sqlConnection = new SqlConnection(poolEnabledConnString))
{
sqlConnection.Open();
sqlConnection.RegisterColumnEncryptionKeyStoreProvidersOnConnection(keyStoreProviders);

using SqlCommand sqlCommand = new SqlCommand(string.Format(selectQuery, tableName),
sqlConnection, null, SqlCommandColumnEncryptionSetting.Enabled);
using SqlDataReader sqlDataReader = sqlCommand.ExecuteReader();

Assert.True(sqlDataReader.HasRows, "FAILED: Select statement did not return any rows.");

while (sqlDataReader.Read())
{
var error = Assert.Throws<SqlException>(() => DatabaseHelper.CompareResults(sqlDataReader, types, totalColumnsInSelect));
Assert.Contains("Failed to decrypt column", error.Message);
}
}

using (SqlConnection sqlConnection = new SqlConnection(poolEnabledConnString))
{
sqlConnection.Open();

using SqlCommand sqlCommand = new SqlCommand(string.Format(selectQuery, tableName),
sqlConnection, null, SqlCommandColumnEncryptionSetting.Enabled);
using SqlDataReader sqlDataReader = sqlCommand.ExecuteReader();

Assert.True(sqlDataReader.HasRows, "FAILED: Select statement did not return any rows.");

while (sqlDataReader.Read())
{
DatabaseHelper.CompareResults(sqlDataReader, types, totalColumnsInSelect);
}
}
}


public void Dispose()
{
foreach (string connStrAE in DataTestUtility.AEConnStringsSetup)
{
using (SqlConnection sqlConnection = new SqlConnection(connStrAE))
{
sqlConnection.Open();
Table.DeleteData(fixture.ColumnDecryptErrorTestTable.Name, sqlConnection);
}
}
}

private sealed class DummyKeyStoreProvider : SqlColumnEncryptionKeyStoreProvider
{
public override byte[] DecryptColumnEncryptionKey(string masterKeyPath, string encryptionAlgorithm, byte[] encryptedColumnEncryptionKey)
{
// Must be 32 to match the key length expected for the 'AEAD_AES_256_CBC_HMAC_SHA256' algorithm
return new byte[32];
}

public override byte[] EncryptColumnEncryptionKey(string masterKeyPath, string encryptionAlgorithm, byte[] columnEncryptionKey)
{
return new byte[32];
}
}
}

public class TestQueries : IEnumerable<object[]>
{
public IEnumerator<object[]> GetEnumerator()
{
foreach (string connStrAE in DataTestUtility.AEConnStrings)
{
yield return new object[] { connStrAE, @"select CustomerId, FirstName, LastName from [{0}] ", 3, new string[] { @"int", @"string", @"string" } };
yield return new object[] { connStrAE, @"select CustomerId, FirstName from [{0}] ", 2, new string[] { @"int", @"string" } };
yield return new object[] { connStrAE, @"select LastName from [{0}] ", 1, new string[] { @"string" } };
}
}
IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ public class SQLSetupStrategy : IDisposable
public Table ApiTestTable { get; private set; }
public Table BulkCopyAEErrorMessageTestTable { get; private set; }
public Table BulkCopyAETestTable { get; private set; }
public Table ColumnDecryptErrorTestTable { get; private set; }
public Table SqlParameterPropertiesTable { get; private set; }
public Table DateOnlyTestTable { get; private set; }
public Table End2EndSmokeTable { get; private set; }
Expand Down Expand Up @@ -127,6 +128,9 @@ protected List<Table> CreateTables(IList<ColumnEncryptionKey> columnEncryptionKe
BulkCopyAETestTable = new BulkCopyAETestTable(GenerateUniqueName("BulkCopyAETestTable"), columnEncryptionKeys[0], columnEncryptionKeys[1]);
tables.Add(BulkCopyAETestTable);

ColumnDecryptErrorTestTable = new ColumnDecryptErrorTestTable(GenerateUniqueName("ColumnDecryptErrorTestTable"), columnEncryptionKeys[0], columnEncryptionKeys[1]);
tables.Add(ColumnDecryptErrorTestTable);

SqlParameterPropertiesTable = new SqlParameterPropertiesTable(GenerateUniqueName("SqlParameterPropertiesTable"));
tables.Add(SqlParameterPropertiesTable);

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

namespace Microsoft.Data.SqlClient.ManualTesting.Tests.AlwaysEncrypted.Setup
{
public class ColumnDecryptErrorTestTable : Table
{
private const string ColumnEncryptionAlgorithmName = @"AEAD_AES_256_CBC_HMAC_SHA_256";
public ColumnEncryptionKey columnEncryptionKey1;
public ColumnEncryptionKey columnEncryptionKey2;
private bool useDeterministicEncryption;

public ColumnDecryptErrorTestTable(string tableName, ColumnEncryptionKey columnEncryptionKey1, ColumnEncryptionKey columnEncryptionKey2, bool useDeterministicEncryption = false) : base(tableName)
{
this.columnEncryptionKey1 = columnEncryptionKey1;
this.columnEncryptionKey2 = columnEncryptionKey2;
this.useDeterministicEncryption = useDeterministicEncryption;
}

public override void Create(SqlConnection sqlConnection)
{
string encryptionType = useDeterministicEncryption ? "DETERMINISTIC" : DataTestUtility.EnclaveEnabled ? "RANDOMIZED" : "DETERMINISTIC";
string sql =
$@"CREATE TABLE [dbo].[{Name}]
(
[CustomerId] [int] ENCRYPTED WITH (COLUMN_ENCRYPTION_KEY = [{columnEncryptionKey1.Name}], ENCRYPTION_TYPE = {encryptionType}, ALGORITHM = '{ColumnEncryptionAlgorithmName}'),
[FirstName] [nvarchar](50) COLLATE Latin1_General_BIN2 ENCRYPTED WITH (COLUMN_ENCRYPTION_KEY = [{columnEncryptionKey2.Name}], ENCRYPTION_TYPE = DETERMINISTIC, ALGORITHM = '{ColumnEncryptionAlgorithmName}'),
[LastName] [nvarchar](50) COLLATE Latin1_General_BIN2 ENCRYPTED WITH (COLUMN_ENCRYPTION_KEY = [{columnEncryptionKey2.Name}], ENCRYPTION_TYPE = DETERMINISTIC, ALGORITHM = '{ColumnEncryptionAlgorithmName}')
)";

using (SqlCommand command = sqlConnection.CreateCommand())
{
command.CommandText = sql;
command.ExecuteNonQuery();
}
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
<Compile Include="AlwaysEncrypted\TestFixtures\Setup\BulkCopyAETestTable.cs" />
<Compile Include="AlwaysEncrypted\TestFixtures\Setup\BulkCopyAEErrorMessageTestTable.cs" />
<Compile Include="AlwaysEncrypted\TestFixtures\Setup\BulkCopyTruncationTables.cs" />
<Compile Include="AlwaysEncrypted\TestFixtures\Setup\ColumnDecryptErrorTestTable.cs" />
<Compile Include="AlwaysEncrypted\TestFixtures\Setup\DateOnlyTestTable.cs" />
<Compile Include="AlwaysEncrypted\TestFixtures\Setup\SqlNullValuesTable.cs" />
<Compile Include="AlwaysEncrypted\TestFixtures\Setup\SqlParameterPropertiesTable.cs" />
Expand All @@ -73,6 +74,7 @@
</ItemGroup>
<ItemGroup Condition="$([MSBuild]::IsTargetFrameworkCompatible('$(TargetFramework)', 'net6.0')) AND ('$(TestSet)' == '' OR '$(TestSet)' == 'AE')">
<Compile Include="AlwaysEncrypted\DateOnlyReadTests.cs" />
<Compile Include="AlwaysEncrypted\ColumnDecryptErrorTests.cs" />
</ItemGroup>
<ItemGroup Condition="'$(TestSet)' == '' OR '$(TestSet)' == '1'">
<Compile Include="SQL\AsyncTest\AsyncTimeoutTest.cs" />
Expand Down
Loading