From 897789b9983c30f34caea1d113bf2e72d586e114 Mon Sep 17 00:00:00 2001 From: Wraith2 Date: Sun, 14 Jun 2020 21:46:26 +0100 Subject: [PATCH] fix SqlSequentialStream multipacket reads stalling and add covering test --- .../Microsoft/Data/SqlClient/SqlDataReader.cs | 19 +++- .../SQL/DataStreamTest/DataStreamTest.cs | 93 ++++++++++++++++++- 2 files changed, 106 insertions(+), 6 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs index 10c624e97a..4896d8e9cb 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs @@ -4517,14 +4517,23 @@ private Task GetBytesAsyncReadDataStage(GetBytesAsyncCallContext context, b SetTimeout(_defaultTimeoutMilliseconds); // Try to read without any continuations (all the data may already be in the stateObj's buffer) - if (!TryGetBytesInternalSequential(context.columnIndex, context.buffer, context.index, context.length, out bytesRead)) + bool filledBuffer = context._reader.TryGetBytesInternalSequential( + context.columnIndex, + context.buffer, + context.index + context.totalBytesRead, + context.length - context.totalBytesRead, + out bytesRead + ); + context.totalBytesRead += bytesRead; + Debug.Assert(context.totalBytesRead <= context.length, "Read more bytes than required"); + + if (!filledBuffer) { // This will be the 'state' for the callback - int totalBytesRead = bytesRead; - if (!isContinuation) { // This is the first async operation which is happening - setup the _currentTask and timeout + Debug.Assert(context._source==null, "context._source should not be non-null when trying to change to async"); source = new TaskCompletionSource(); Task original = Interlocked.CompareExchange(ref _currentTask, source.Task, null); if (original != null) @@ -4532,7 +4541,7 @@ private Task GetBytesAsyncReadDataStage(GetBytesAsyncCallContext context, b source.SetException(ADP.ExceptionWithStackTrace(ADP.AsyncOperationPending())); return source.Task; } - + context._source = source; // Check if cancellation due to close is requested (this needs to be done after setting _currentTask) if (_cancelAsyncOnCloseToken.IsCancellationRequested) { @@ -4561,7 +4570,7 @@ private Task GetBytesAsyncReadDataStage(GetBytesAsyncCallContext context, b } else { - Debug.Assert(context._source != null, "context.source should not be null when continuing"); + Debug.Assert(context._source != null, "context._source should not be null when continuing"); // setup for cleanup/completing retryTask.ContinueWith( continuationAction: AAsyncCallContext.s_completeCallback, diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/DataStreamTest/DataStreamTest.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/DataStreamTest/DataStreamTest.cs index 0872e457d8..1eab7e65c3 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/DataStreamTest/DataStreamTest.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/DataStreamTest/DataStreamTest.cs @@ -37,6 +37,97 @@ public static void RunAllTestsForSingleServer_TCP() RunAllTestsForSingleServer(DataTestUtility.TCPConnectionString); } + [ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup))] + public static async Task AsyncMultiPacketStreamRead() + { + int packetSize = 514; // force small packet size so we can quickly check multi packet reads + + SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder(DataTestUtility.TCPConnectionString); + builder.PacketSize = 514; + string connectionString = builder.ToString(); + + byte[] inputData = null; + byte[] outputData = null; + string tableName = DataTestUtility.GetUniqueNameForSqlServer("data"); + + using (SqlConnection connection = new SqlConnection(connectionString)) + { + await connection.OpenAsync(); + + try + { + inputData = CreateBinaryTable(connection, tableName, packetSize); + + using (SqlCommand command = new SqlCommand($"SELECT foo FROM {tableName}", connection)) + using (SqlDataReader reader = await command.ExecuteReaderAsync(System.Data.CommandBehavior.SequentialAccess)) + { + await reader.ReadAsync(); + + using (Stream stream = reader.GetStream(0)) + using (CancellationTokenSource cancellationTokenSource = new CancellationTokenSource(TimeSpan.FromSeconds(60))) + using (MemoryStream memory = new MemoryStream(16 * 1024)) + { + await stream.CopyToAsync(memory, 37, cancellationTokenSource.Token); // prime number sized buffer to cause many cross packet partial reads + outputData = memory.ToArray(); + } + } + } + finally + { + DataTestUtility.DropTable(connection, tableName); + } + } + + Assert.NotNull(outputData); + int sharedLength = Math.Min(inputData.Length, outputData.Length); + if (sharedLength < outputData.Length) + { + Assert.False(true, $"output is longer than input, input={inputData.Length} bytes, output={outputData.Length} bytes"); + } + if (sharedLength < inputData.Length) + { + Assert.False(true, $"input is longer than output, input={inputData.Length} bytes, output={outputData.Length} bytes"); + } + for (int index = 0; index < sharedLength; index++) + { + if (inputData[index] != outputData[index]) // avoid formatting the output string unless there is a difference + { + Assert.True(false, $"input and output differ at index {index}, input={inputData[index]}, output={outputData[index]}"); + } + } + + } + + private static byte[] CreateBinaryTable(SqlConnection connection, string tableName, int packetSize) + { + byte[] pattern = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13 }; + byte[] data = new byte[packetSize * 10]; + int position = 0; + while (position < data.Length) + { + int copyCount = Math.Min(pattern.Length, data.Length - position); + Array.Copy(pattern, 0, data, position, copyCount); + position += copyCount; + } + + using (var cmd = connection.CreateCommand()) + { + cmd.CommandText = $@" +IF OBJECT_ID('dbo.{tableName}', 'U') IS NOT NULL +DROP TABLE {tableName}; +CREATE TABLE {tableName} (id INT, foo VARBINARY(MAX)) +"; + cmd.ExecuteNonQuery(); + + cmd.CommandText = $"INSERT INTO {tableName} (id, foo) VALUES (@id, @foo)"; + cmd.Parameters.AddWithValue("id", 1); + cmd.Parameters.AddWithValue("foo", data); + cmd.ExecuteNonQuery(); + } + + return data; + } + private static void RunAllTestsForSingleServer(string connectionString, bool usingNamePipes = false) { RowBuffer(connectionString); @@ -1811,7 +1902,7 @@ private static void TestXEventsStreaming(string connectionString) SqlDataReader reader = cmd.ExecuteReader(System.Data.CommandBehavior.SequentialAccess); for (int i = 0; i < streamXeventCount && reader.Read(); i++) { - Int32 colType = reader.GetInt32(0); + int colType = reader.GetInt32(0); int cb = (int)reader.GetBytes(1, 0, null, 0, 0); byte[] bytes = new byte[cb];