Skip to content

Commit

Permalink
Port double clean fix to netfx (#2843)
Browse files Browse the repository at this point in the history
  • Loading branch information
mdaigle authored Sep 10, 2024
1 parent 069c052 commit b1d3a82
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7223,16 +7223,34 @@ internal TdsOperationStatus TryReadSqlValue(SqlBuffer value,
{
if (stateObj is not null)
{
// call to decrypt column keys has failed. The data wont be decrypted.
// Not setting the value to false, forces the driver to look for column value.
// Packet received from Key Vault will throws invalid token header.
if (stateObj.HasPendingData)
// Throwing an exception here circumvents the normal pending data checks and cleanup processes,
// so we need to ensure the appropriate state. Increment the _nextColumnDataToRead index because
// we already read the encrypted column data; Otherwise we'll double count and attempt to drain a
// corresponding number of bytes a second time. We don't want the rest of the pending data to
// interfere with future operations, so we must drain it. Set HasPendingData to false to indicate
// that we successfully drained the data.

// The SqlDataReader also maintains a state called dataReady. We need to set that to false if we've
// drained the data off the connection. Otherwise, a consumer that catches the exception may
// continue to use the reader and will timeout waiting to read data that doesn't exist.

// Order matters here. Must increment column before draining data.
// Update state objects after draining data.

if (stateObj._readerState != null)
{
// Drain the pending data now if setting the HasPendingData to false.
// SqlDataReader.TryCloseInternal can not drain if HasPendingData = false.
DrainData(stateObj);
stateObj._readerState._nextColumnDataToRead++;
}

DrainData(stateObj);

if (stateObj._readerState != null)
{
stateObj._readerState._dataReady = false;
}

stateObj.HasPendingData = false;

}
throw SQL.ColumnDecryptionFailed(columnName, null, e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
using System.Collections;
using System.Collections.Generic;
using Microsoft.Data.SqlClient.ManualTesting.Tests.AlwaysEncrypted.Setup;
using Microsoft.Data.SqlClient.ManualTesting.Tests.SystemDataInternals;
using Xunit;

namespace Microsoft.Data.SqlClient.ManualTesting.Tests.AlwaysEncrypted
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
<Compile Include="AlwaysEncrypted\ApiShould.cs" />
<Compile Include="AlwaysEncrypted\BulkCopyAE.cs" />
<Compile Include="AlwaysEncrypted\BulkCopyAEErrorMessage.cs" />
<Compile Include="AlwaysEncrypted\ColumnDecryptErrorTests.cs" />
<Compile Include="AlwaysEncrypted\End2EndSmokeTests.cs" />
<Compile Include="AlwaysEncrypted\SqlBulkCopyTruncation.cs" />
<Compile Include="AlwaysEncrypted\SqlNullValues.cs" />
Expand Down Expand Up @@ -74,7 +75,6 @@
</ItemGroup>
<ItemGroup Condition="$([MSBuild]::IsTargetFrameworkCompatible('$(TargetFramework)', 'net6.0')) AND ('$(TestSet)' == '' OR '$(TestSet)' == 'AE')">
<Compile Include="AlwaysEncrypted\DateOnlyReadTests.cs" />
<Compile Include="AlwaysEncrypted\ColumnDecryptErrorTests.cs" />
</ItemGroup>
<ItemGroup Condition="'$(TestSet)' == '' OR '$(TestSet)' == '1'">
<Compile Include="SQL\AsyncTest\AsyncTimeoutTest.cs" />
Expand Down

0 comments on commit b1d3a82

Please sign in to comment.