From 2b3864c215a630cb896a2dabc20d3acef5585296 Mon Sep 17 00:00:00 2001 From: Wraith2 Date: Mon, 3 May 2021 01:17:20 +0100 Subject: [PATCH 1/5] netcore static lambda rework --- .../Data/SqlClient/SNI/SNILoadHandle.cs | 2 +- .../Microsoft/Data/SqlClient/SqlBulkCopy.cs | 180 ++++++++++------- .../Microsoft/Data/SqlClient/SqlCommand.cs | 189 +++++++++--------- .../Microsoft/Data/SqlClient/SqlConnection.cs | 2 +- .../src/Microsoft/Data/SqlClient/SqlUtil.cs | 31 ++- .../src/Microsoft/Data/SqlClient/TdsParser.cs | 4 +- .../Data/SqlClient/TdsParserStateObject.cs | 11 +- 7 files changed, 237 insertions(+), 182 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNILoadHandle.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNILoadHandle.cs index 1fa6c58a3f..5f932703af 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNILoadHandle.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNILoadHandle.cs @@ -14,7 +14,7 @@ internal class SNILoadHandle public static readonly SNILoadHandle SingletonInstance = new SNILoadHandle(); public readonly EncryptionOptions _encryptionOption = EncryptionOptions.OFF; - public ThreadLocal _lastError = new ThreadLocal(() => { return new SNIError(SNIProviders.INVALID_PROV, 0, TdsEnums.SNI_SUCCESS, string.Empty); }); + public ThreadLocal _lastError = new ThreadLocal(static () => new SNIError(SNIProviders.INVALID_PROV, 0, TdsEnums.SNI_SUCCESS, string.Empty)); private readonly uint _status = TdsEnums.SNI_SUCCESS; diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs index b0dce3a6f9..3442bc47c6 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs @@ -1078,14 +1078,19 @@ private Task ReadFromRowSourceAsync(CancellationToken cts) if (_isAsyncBulkCopy && _dbDataReaderRowSource != null) { // This will call ReadAsync for DbDataReader (for SqlDataReader it will be truly async read; for non-SqlDataReader it may block.) - return _dbDataReaderRowSource.ReadAsync(cts).ContinueWith((t) => - { - if (t.Status == TaskStatus.RanToCompletion) + return _dbDataReaderRowSource.ReadAsync(cts).ContinueWith( + static (Task task, object state) => { - _hasMoreRowToCopy = t.Result; - } - return t; - }, TaskScheduler.Default).Unwrap(); + SqlBulkCopy sqlBulkCopy = (SqlBulkCopy)state; + if (task.Status == TaskStatus.RanToCompletion) + { + sqlBulkCopy._hasMoreRowToCopy = task.Result; + } + return task; + }, + state: this, + scheduler: TaskScheduler.Default + ).Unwrap(); } else { // This will call Read for DataRows, DataTable and IDataReader (this includes all IDataReader except DbDataReader) @@ -1940,7 +1945,7 @@ private Task WriteRowSourceToServerAsync(int columnCount, CancellationToken ctok { AsyncHelper.ContinueTaskWithState(writeTask, tcs, state: tcs, - onSuccess: state => ((TaskCompletionSource)state).SetResult(null) + onSuccess: static (object state) => ((TaskCompletionSource)state).SetResult(null) ); } }, ctoken); // We do not need to propagate exception, etc, from reconnect task, we just need to wait for it to finish. @@ -1948,7 +1953,7 @@ private Task WriteRowSourceToServerAsync(int columnCount, CancellationToken ctok } else { - AsyncHelper.WaitForCompletion(reconnectTask, BulkCopyTimeout, () => { throw SQL.CR_ReconnectTimeout(); }, rethrowExceptions: false); + AsyncHelper.WaitForCompletion(reconnectTask, BulkCopyTimeout, static () => throw SQL.CR_ReconnectTimeout(), rethrowExceptions: false); } } @@ -1969,27 +1974,32 @@ private Task WriteRowSourceToServerAsync(int columnCount, CancellationToken ctok if (resultTask != null) { finishedSynchronously = false; - return resultTask.ContinueWith((t) => - { - try + return resultTask.ContinueWith( + static (Task t, object state) => { - AbortTransaction(); // If there is one, on success transactions will be committed. - } - finally - { - _isBulkCopyingInProgress = false; - if (_parser != null) + SqlBulkCopy sqlBulkCopy = (SqlBulkCopy)state; + try { - _parser._asyncWrite = false; + sqlBulkCopy.AbortTransaction(); // If there is one, on success transactions will be committed. } - if (_parserLock != null) + finally { - _parserLock.Release(); - _parserLock = null; + sqlBulkCopy._isBulkCopyingInProgress = false; + if (sqlBulkCopy._parser != null) + { + sqlBulkCopy._parser._asyncWrite = false; + } + if (sqlBulkCopy._parserLock != null) + { + sqlBulkCopy._parserLock.Release(); + sqlBulkCopy._parserLock = null; + } } - } - return t; - }, TaskScheduler.Default).Unwrap(); + return t; + }, + state: this, + scheduler: TaskScheduler.Default + ).Unwrap(); } return null; } @@ -2254,12 +2264,13 @@ private Task CopyColumnsAsync(int col, TaskCompletionSource source = nul // This is in its own method to avoid always allocating the lambda in CopyColumnsAsync private void CopyColumnsAsyncSetupContinuation(TaskCompletionSource source, Task task, int i) { - AsyncHelper.ContinueTask(task, source, - onSuccess: () => + AsyncHelper.ContinueTaskWithState(task, source, this, + onSuccess: (object state) => { - if (i + 1 < _sortedColumnMappings.Count) + SqlBulkCopy sqlBulkCopy = (SqlBulkCopy)state; + if (i + 1 < sqlBulkCopy._sortedColumnMappings.Count) { - CopyColumnsAsync(i + 1, source); //continue from the next column + sqlBulkCopy.CopyColumnsAsync(i + 1, source); //continue from the next column } else { @@ -2401,8 +2412,9 @@ private Task CopyRowsAsync(int rowsSoFar, int totalRows, CancellationToken cts, } resultTask = source.Task; - AsyncHelper.ContinueTask(readTask, source, - onSuccess: () => CopyRowsAsync(i + 1, totalRows, cts, source) + AsyncHelper.ContinueTaskWithState(readTask, source, this, + onSuccess: (object state) => ((SqlBulkCopy)state).CopyRowsAsync(i + 1, totalRows, cts, source) + ); return resultTask; // Associated task will be completed when all rows are copied to server/exception/cancelled. } @@ -2412,20 +2424,21 @@ private Task CopyRowsAsync(int rowsSoFar, int totalRows, CancellationToken cts, source = source ?? new TaskCompletionSource(); resultTask = source.Task; - AsyncHelper.ContinueTask(task, source, - onSuccess: () => + AsyncHelper.ContinueTaskWithState(task, source, this, + onSuccess: (object state) => { - CheckAndRaiseNotification(); // Check for notification now as the current row copy is done at this moment. + SqlBulkCopy sqlBulkCopy = (SqlBulkCopy)state; + sqlBulkCopy.CheckAndRaiseNotification(); // Check for notification now as the current row copy is done at this moment. - Task readTask = ReadFromRowSourceAsync(cts); + Task readTask = sqlBulkCopy.ReadFromRowSourceAsync(cts); if (readTask == null) { - CopyRowsAsync(i + 1, totalRows, cts, source); + sqlBulkCopy.CopyRowsAsync(i + 1, totalRows, cts, source); } else { - AsyncHelper.ContinueTask(readTask, source, - onSuccess: () => CopyRowsAsync(i + 1, totalRows, cts, source) + AsyncHelper.ContinueTaskWithState(readTask, source, sqlBulkCopy, + onSuccess: (object state2) => ((SqlBulkCopy)state2).CopyRowsAsync(i + 1, totalRows, cts, source) ); } } @@ -2498,14 +2511,15 @@ private Task CopyBatchesAsync(BulkCopySimpleResultSet internalResults, string up source = new TaskCompletionSource(); } - AsyncHelper.ContinueTask(commandTask, source, - onSuccess: () => + AsyncHelper.ContinueTaskWithState(commandTask, source, this, + onSuccess: (object state) => { - Task continuedTask = CopyBatchesAsyncContinued(internalResults, updateBulkCommandText, cts, source); + SqlBulkCopy sqlBulkCopy = (SqlBulkCopy)state; + Task continuedTask = sqlBulkCopy.CopyBatchesAsyncContinued(internalResults, updateBulkCommandText, cts, source); if (continuedTask == null) { // Continuation finished sync, recall into CopyBatchesAsync to continue - CopyBatchesAsync(internalResults, updateBulkCommandText, cts, source); + sqlBulkCopy.CopyBatchesAsync(internalResults, updateBulkCommandText, cts, source); } } ); @@ -2562,18 +2576,19 @@ private Task CopyBatchesAsyncContinued(BulkCopySimpleResultSet internalResults, { // First time only source = new TaskCompletionSource(); } - AsyncHelper.ContinueTask(task, source, - onSuccess: () => + AsyncHelper.ContinueTaskWithState(task, source, this, + onSuccess: (object state) => { - Task continuedTask = CopyBatchesAsyncContinuedOnSuccess(internalResults, updateBulkCommandText, cts, source); + SqlBulkCopy sqlBulkCopy = (SqlBulkCopy)state; + Task continuedTask = sqlBulkCopy.CopyBatchesAsyncContinuedOnSuccess(internalResults, updateBulkCommandText, cts, source); if (continuedTask == null) { // Continuation finished sync, recall into CopyBatchesAsync to continue - CopyBatchesAsync(internalResults, updateBulkCommandText, cts, source); + sqlBulkCopy.CopyBatchesAsync(internalResults, updateBulkCommandText, cts, source); } }, - onFailure: _ => CopyBatchesAsyncContinuedOnError(cleanupParser: false), - onCancellation: () => CopyBatchesAsyncContinuedOnError(cleanupParser: true) + onFailure: static (Exception _, object state) => ((SqlBulkCopy)state).CopyBatchesAsyncContinuedOnError(cleanupParser: false), + onCancellation: static (object state) => ((SqlBulkCopy)state).CopyBatchesAsyncContinuedOnError(cleanupParser: true) ); return source.Task; @@ -2621,24 +2636,25 @@ private Task CopyBatchesAsyncContinuedOnSuccess(BulkCopySimpleResultSet internal source = new TaskCompletionSource(); } - AsyncHelper.ContinueTask(writeTask, source, - onSuccess: () => + AsyncHelper.ContinueTaskWithState(writeTask, source, this, + onSuccess: (object state) => { + SqlBulkCopy sqlBulkCopy = (SqlBulkCopy)state; try { - RunParser(); - CommitTransaction(); + sqlBulkCopy.RunParser(); + sqlBulkCopy.CommitTransaction(); } catch (Exception) { - CopyBatchesAsyncContinuedOnError(cleanupParser: false); + sqlBulkCopy.CopyBatchesAsyncContinuedOnError(cleanupParser: false); throw; } // Always call back into CopyBatchesAsync - CopyBatchesAsync(internalResults, updateBulkCommandText, cts, source); + sqlBulkCopy.CopyBatchesAsync(internalResults, updateBulkCommandText, cts, source); }, - onFailure: _ => CopyBatchesAsyncContinuedOnError(cleanupParser: false) + onFailure: static (Exception _, object state) => ((SqlBulkCopy)state).CopyBatchesAsyncContinuedOnError(cleanupParser: false) ); return source.Task; } @@ -2758,16 +2774,17 @@ private void WriteToServerInternalRestContinuedAsync(BulkCopySimpleResultSet int { source = new TaskCompletionSource(); } - AsyncHelper.ContinueTask(task, source, - onSuccess: () => + AsyncHelper.ContinueTaskWithState(task, source, this, + onSuccess: (object state) => { + SqlBulkCopy sqlBulkCopy = (SqlBulkCopy)state; // Bulk copy task is completed at this moment. if (task.IsCanceled) { - _localColumnMappings = null; + sqlBulkCopy._localColumnMappings = null; try { - CleanUpStateObject(); + sqlBulkCopy.CleanUpStateObject(); } finally { @@ -2780,10 +2797,10 @@ private void WriteToServerInternalRestContinuedAsync(BulkCopySimpleResultSet int } else { - _localColumnMappings = null; + sqlBulkCopy._localColumnMappings = null; try { - CleanUpStateObject(isCancelRequested: false); + sqlBulkCopy.CleanUpStateObject(isCancelRequested: false); } finally { @@ -2889,19 +2906,27 @@ private void WriteToServerInternalRestAsync(CancellationToken cts, TaskCompletio TaskCompletionSource cancellableReconnectTS = new TaskCompletionSource(); if (cts.CanBeCanceled) { - regReconnectCancel = cts.Register(s => ((TaskCompletionSource)s).TrySetCanceled(), cancellableReconnectTS); + regReconnectCancel = cts.Register(static (object tcs) => ((TaskCompletionSource)tcs).TrySetCanceled(), cancellableReconnectTS); } AsyncHelper.ContinueTaskWithState(reconnectTask, cancellableReconnectTS, state: cancellableReconnectTS, - onSuccess: (state) => { ((TaskCompletionSource)state).SetResult(null); } + onSuccess: static (object state) => ((TaskCompletionSource)state).SetResult(null) ); // No need to cancel timer since SqlBulkCopy creates specific task source for reconnection. - AsyncHelper.SetTimeoutException(cancellableReconnectTS, BulkCopyTimeout, - () => { return SQL.BulkLoadInvalidDestinationTable(_destinationTableName, SQL.CR_ReconnectTimeout()); }, CancellationToken.None); - AsyncHelper.ContinueTask(cancellableReconnectTS.Task, source, - onSuccess: () => + AsyncHelper.SetTimeoutExceptionWithState( + completion: cancellableReconnectTS, + timeout: BulkCopyTimeout, + state: _destinationTableName, + onFailure: static (object state) => SQL.BulkLoadInvalidDestinationTable((string)state, SQL.CR_ReconnectTimeout()), + cancellationToken: CancellationToken.None + ); + AsyncHelper.ContinueTaskWithState( + task: cancellableReconnectTS.Task, + completion: source, + state: regReconnectCancel, + onSuccess: (object state) => { - regReconnectCancel.Dispose(); + ((CancellationTokenRegistration)state).Dispose(); if (_parserLock != null) { _parserLock.Release(); @@ -2911,8 +2936,8 @@ private void WriteToServerInternalRestAsync(CancellationToken cts, TaskCompletio _parserLock.Wait(canReleaseFromAnyThread: true); WriteToServerInternalRestAsync(cts, source); }, - onFailure: (e) => { regReconnectCancel.Dispose(); }, - onCancellation: () => { regReconnectCancel.Dispose(); }, + onFailure: static (Exception _, object state) => ((CancellationTokenRegistration)state).Dispose(), + onCancellation: static (object state) => ((CancellationTokenRegistration)state).Dispose(), exceptionConverter: (ex) => SQL.BulkLoadInvalidDestinationTable(_destinationTableName, ex)); return; } @@ -2920,7 +2945,7 @@ private void WriteToServerInternalRestAsync(CancellationToken cts, TaskCompletio { try { - AsyncHelper.WaitForCompletion(reconnectTask, BulkCopyTimeout, () => { throw SQL.CR_ReconnectTimeout(); }); + AsyncHelper.WaitForCompletion(reconnectTask, BulkCopyTimeout, static () => throw SQL.CR_ReconnectTimeout()); } catch (SqlException ex) { @@ -2961,8 +2986,8 @@ private void WriteToServerInternalRestAsync(CancellationToken cts, TaskCompletio if (internalResultsTask != null) { - AsyncHelper.ContinueTask(internalResultsTask, source, - onSuccess: () => WriteToServerInternalRestContinuedAsync(internalResultsTask.Result, cts, source) + AsyncHelper.ContinueTaskWithState(internalResultsTask, source, this, + onSuccess: (object state) => ((SqlBulkCopy)state).WriteToServerInternalRestContinuedAsync(internalResultsTask.Result, cts, source) ); } else @@ -3034,16 +3059,17 @@ private Task WriteToServerInternalAsync(CancellationToken ctoken) else { Debug.Assert(_isAsyncBulkCopy, "Read must not return a Task in the Sync mode"); - AsyncHelper.ContinueTask(readTask, source, - onSuccess: () => + AsyncHelper.ContinueTaskWithState(readTask, source, this, + onSuccess: (object state) => { - if (!_hasMoreRowToCopy) + SqlBulkCopy sqlBulkCopy = (SqlBulkCopy)state; + if (!sqlBulkCopy._hasMoreRowToCopy) { source.SetResult(null); // No rows to copy! } else { - WriteToServerInternalRestAsync(ctoken, source); // Passing the same completion which will be completed by the Callee. + sqlBulkCopy.WriteToServerInternalRestAsync(ctoken, source); // Passing the same completion which will be completed by the Callee. } } ); diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs index 13a66ac1c4..53e6e94191 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs @@ -1244,7 +1244,7 @@ private IAsyncResult BeginExecuteNonQueryInternal(CommandBehavior behavior, Asyn if (callback != null) { globalCompletion.Task.ContinueWith( - (task, state) => ((AsyncCallback)state)(task), + static (task, state) => ((AsyncCallback)state)(task), state: callback ); } @@ -1753,7 +1753,7 @@ private IAsyncResult BeginExecuteXmlReaderInternal(CommandBehavior behavior, Asy if (callback != null) { localCompletion.Task.ContinueWith( - (task, state) => ((AsyncCallback)state)(task), + static (task, state) => ((AsyncCallback)state)(task), state: callback ); } @@ -2184,7 +2184,7 @@ private IAsyncResult BeginExecuteReaderInternal(CommandBehavior behavior, AsyncC if (callback != null) { globalCompletion.Task.ContinueWith( - (task, state) => ((AsyncCallback)state)(task), + static (task, state) => ((AsyncCallback)state)(task), state: callback ); } @@ -2505,14 +2505,19 @@ private Task InternalExecuteNonQueryAsync(CancellationToken cancellationTok /// protected override Task ExecuteDbDataReaderAsync(CommandBehavior behavior, CancellationToken cancellationToken) { - return ExecuteReaderAsync(behavior, cancellationToken).ContinueWith((result) => - { - if (result.IsFaulted) + return ExecuteReaderAsync(behavior, cancellationToken).ContinueWith( + static (Task result) => { - throw result.Exception.InnerException; - } - return result.Result; - }, CancellationToken.None, TaskContinuationOptions.ExecuteSynchronously | TaskContinuationOptions.NotOnCanceled, TaskScheduler.Default); + if (result.IsFaulted) + { + throw result.Exception.InnerException; + } + return result.Result; + }, + CancellationToken.None, + TaskContinuationOptions.ExecuteSynchronously | TaskContinuationOptions.NotOnCanceled, + TaskScheduler.Default + ); } /// @@ -3250,7 +3255,7 @@ private Task RunExecuteNonQueryTds(string methodName, bool isAsync, int timeout, } else { - AsyncHelper.WaitForCompletion(reconnectTask, timeout, () => { throw SQL.CR_ReconnectTimeout(); }); + AsyncHelper.WaitForCompletion(reconnectTask, timeout, static () => throw SQL.CR_ReconnectTimeout()); timeout = TdsParserStaticMethods.GetRemainingTimeout(timeout, reconnectionStart); } } @@ -3307,7 +3312,7 @@ private Task RunExecuteNonQueryTds(string methodName, bool isAsync, int timeout, private void RunExecuteNonQueryTdsSetupReconnnectContinuation(string methodName, bool isAsync, int timeout, bool asyncWrite, Task reconnectTask, long reconnectionStart, TaskCompletionSource completion) { CancellationTokenSource timeoutCTS = new CancellationTokenSource(); - AsyncHelper.SetTimeoutException(completion, timeout, SQL.CR_ReconnectTimeout, timeoutCTS.Token); + AsyncHelper.SetTimeoutException(completion, timeout, static () => SQL.CR_ReconnectTimeout(), timeoutCTS.Token); AsyncHelper.ContinueTask(reconnectTask, completion, () => { @@ -3326,7 +3331,7 @@ private void RunExecuteNonQueryTdsSetupReconnnectContinuation(string methodName, { AsyncHelper.ContinueTaskWithState(subTask, completion, state: completion, - onSuccess: (state) => ((TaskCompletionSource)state).SetResult(null) + onSuccess: static (object state) => ((TaskCompletionSource)state).SetResult(null) ); } } @@ -3572,70 +3577,72 @@ private SqlDataReader GetParameterEncryptionDataReader(out Task returnTask, Task SqlDataReader describeParameterEncryptionDataReader, ReadOnlyDictionary<_SqlRPC, _SqlRPC> describeParameterEncryptionRpcOriginalRpcMap, bool describeParameterEncryptionNeeded) { - returnTask = AsyncHelper.CreateContinuationTask(fetchInputParameterEncryptionInfoTask, () => - { - bool processFinallyBlockAsync = true; - bool decrementAsyncCountInFinallyBlockAsync = true; - - RuntimeHelpers.PrepareConstrainedRegions(); - try + returnTask = AsyncHelper.CreateContinuationTaskWithState(fetchInputParameterEncryptionInfoTask, this, + (object state) => { - // Check for any exceptions on network write, before reading. - CheckThrowSNIException(); + SqlCommand command = (SqlCommand)state; + bool processFinallyBlockAsync = true; + bool decrementAsyncCountInFinallyBlockAsync = true; - // If it is async, then TryFetchInputParameterEncryptionInfo-> RunExecuteReaderTds would have incremented the async count. - // Decrement it when we are about to complete async execute reader. - SqlInternalConnectionTds internalConnectionTds = _activeConnection.GetOpenTdsConnection(); - if (internalConnectionTds != null) + RuntimeHelpers.PrepareConstrainedRegions(); + try { - internalConnectionTds.DecrementAsyncCount(); - decrementAsyncCountInFinallyBlockAsync = false; - } + // Check for any exceptions on network write, before reading. + command.CheckThrowSNIException(); - // Complete executereader. - describeParameterEncryptionDataReader = - CompleteAsyncExecuteReader(forDescribeParameterEncryption: true); - Debug.Assert(null == _stateObj, "non-null state object in PrepareForTransparentEncryption."); + // If it is async, then TryFetchInputParameterEncryptionInfo-> RunExecuteReaderTds would have incremented the async count. + // Decrement it when we are about to complete async execute reader. + SqlInternalConnectionTds internalConnectionTds = command._activeConnection.GetOpenTdsConnection(); + if (internalConnectionTds != null) + { + internalConnectionTds.DecrementAsyncCount(); + decrementAsyncCountInFinallyBlockAsync = false; + } - // Read the results of describe parameter encryption. - ReadDescribeEncryptionParameterResults(describeParameterEncryptionDataReader, - describeParameterEncryptionRpcOriginalRpcMap); + // Complete executereader. + describeParameterEncryptionDataReader = command.CompleteAsyncExecuteReader(forDescribeParameterEncryption: true); + Debug.Assert(null == command._stateObj, "non-null state object in PrepareForTransparentEncryption."); -#if DEBUG - // Failpoint to force the thread to halt to simulate cancellation of SqlCommand. - if (_sleepAfterReadDescribeEncryptionParameterResults) + // Read the results of describe parameter encryption. + command.ReadDescribeEncryptionParameterResults(describeParameterEncryptionDataReader, describeParameterEncryptionRpcOriginalRpcMap); + + #if DEBUG + // Failpoint to force the thread to halt to simulate cancellation of SqlCommand. + if (_sleepAfterReadDescribeEncryptionParameterResults) + { + Thread.Sleep(10000); + } + #endif + } + catch (Exception e) { - Thread.Sleep(10000); + processFinallyBlockAsync = ADP.IsCatchableExceptionType(e); + throw; } -#endif - } - catch (Exception e) - { - processFinallyBlockAsync = ADP.IsCatchableExceptionType(e); - throw; - } - finally - { - PrepareTransparentEncryptionFinallyBlock(closeDataReader: processFinallyBlockAsync, - decrementAsyncCount: decrementAsyncCountInFinallyBlockAsync, - clearDataStructures: processFinallyBlockAsync, - wasDescribeParameterEncryptionNeeded: describeParameterEncryptionNeeded, - describeParameterEncryptionRpcOriginalRpcMap: describeParameterEncryptionRpcOriginalRpcMap, - describeParameterEncryptionDataReader: describeParameterEncryptionDataReader); - } - }, - onFailure: ((exception) => - { - if (_cachedAsyncState != null) + finally + { + command.PrepareTransparentEncryptionFinallyBlock(closeDataReader: processFinallyBlockAsync, + decrementAsyncCount: decrementAsyncCountInFinallyBlockAsync, + clearDataStructures: processFinallyBlockAsync, + wasDescribeParameterEncryptionNeeded: describeParameterEncryptionNeeded, + describeParameterEncryptionRpcOriginalRpcMap: describeParameterEncryptionRpcOriginalRpcMap, + describeParameterEncryptionDataReader: describeParameterEncryptionDataReader); + } + }, + onFailure: static (Exception exception, object state) => { - _cachedAsyncState.ResetAsyncState(); - } + SqlCommand command = (SqlCommand)state; + if (command._cachedAsyncState != null) + { + command._cachedAsyncState.ResetAsyncState(); + } - if (exception != null) - { - throw exception; + if (exception != null) + { + throw exception; + } } - })); + ); return describeParameterEncryptionDataReader; } @@ -4450,40 +4457,35 @@ private SqlDataReader RunExecuteReaderTdsWithTransparentParameterEncryption( { long parameterEncryptionStart = ADP.TimerCurrent(); TaskCompletionSource completion = new TaskCompletionSource(); - AsyncHelper.ContinueTask(describeParameterEncryptionTask, completion, - () => + AsyncHelper.ContinueTaskWithState(describeParameterEncryptionTask, completion, this, + (object state) => { + SqlCommand command = (SqlCommand)state; Task subTask = null; - GenerateEnclavePackage(); - RunExecuteReaderTds(cmdBehavior, runBehavior, returnStream, isAsync, TdsParserStaticMethods.GetRemainingTimeout(timeout, parameterEncryptionStart), out subTask, asyncWrite, inRetry, ds); + command.GenerateEnclavePackage(); + command.RunExecuteReaderTds(cmdBehavior, runBehavior, returnStream, isAsync, TdsParserStaticMethods.GetRemainingTimeout(timeout, parameterEncryptionStart), out subTask, asyncWrite, inRetry, ds); if (subTask == null) { completion.SetResult(null); } else { - AsyncHelper.ContinueTask(subTask, completion, () => completion.SetResult(null)); + AsyncHelper.ContinueTaskWithState(subTask, completion, completion, static (object state) => ((TaskCompletionSource)state).SetResult(null)); } }, - onFailure: ((exception) => + onFailure: static (Exception exception, object state) => { - if (_cachedAsyncState != null) - { - _cachedAsyncState.ResetAsyncState(); - } + ((SqlCommand)state)._cachedAsyncState?.ResetAsyncState(); if (exception != null) { throw exception; } - }), - onCancellation: (() => + }, + onCancellation: static (object state) => { - if (_cachedAsyncState != null) - { - _cachedAsyncState.ResetAsyncState(); - } - }) - ); + ((SqlCommand)state)._cachedAsyncState?.ResetAsyncState(); + } + ); task = completion.Task; return ds; } @@ -4556,7 +4558,7 @@ private SqlDataReader RunExecuteReaderTds(CommandBehavior cmdBehavior, RunBehavi } else { - AsyncHelper.WaitForCompletion(reconnectTask, timeout, () => { throw SQL.CR_ReconnectTimeout(); }); + AsyncHelper.WaitForCompletion(reconnectTask, timeout, static () => throw SQL.CR_ReconnectTimeout()); timeout = TdsParserStaticMethods.GetRemainingTimeout(timeout, reconnectionStart); } } @@ -4770,15 +4772,18 @@ private SqlDataReader RunExecuteReaderTds(CommandBehavior cmdBehavior, RunBehavi private Task RunExecuteReaderTdsSetupContinuation(RunBehavior runBehavior, SqlDataReader ds, string optionSettings, Task writeTask) { - Task task = AsyncHelper.CreateContinuationTask(writeTask, - onSuccess: () => + Task task = AsyncHelper.CreateContinuationTaskWithState( + task: writeTask, + state: _activeConnection, + onSuccess: (object state) => { - _activeConnection.GetOpenTdsConnection(); // it will throw if connection is closed + SqlConnection sqlConnection = (SqlConnection)state; + sqlConnection.GetOpenTdsConnection(); // it will throw if connection is closed cachedAsyncState.SetAsyncReaderState(ds, runBehavior, optionSettings); }, - onFailure: (exc) => + onFailure: static (Exception exc, object state) => { - _activeConnection.GetOpenTdsConnection().DecrementAsyncCount(); + ((SqlConnection)state).GetOpenTdsConnection().DecrementAsyncCount(); } ); return task; @@ -4788,7 +4793,7 @@ private Task RunExecuteReaderTdsSetupContinuation(RunBehavior runBehavior, SqlDa private void RunExecuteReaderTdsSetupReconnectContinuation(CommandBehavior cmdBehavior, RunBehavior runBehavior, bool returnStream, bool isAsync, int timeout, bool asyncWrite, bool inRetry, SqlDataReader ds, Task reconnectTask, long reconnectionStart, TaskCompletionSource completion) { CancellationTokenSource timeoutCTS = new CancellationTokenSource(); - AsyncHelper.SetTimeoutException(completion, timeout, SQL.CR_ReconnectTimeout, timeoutCTS.Token); + AsyncHelper.SetTimeoutException(completion, timeout, static () => SQL.CR_ReconnectTimeout(), timeoutCTS.Token); AsyncHelper.ContinueTask(reconnectTask, completion, () => { @@ -4808,7 +4813,7 @@ private void RunExecuteReaderTdsSetupReconnectContinuation(CommandBehavior cmdBe { AsyncHelper.ContinueTaskWithState(subTask, completion, state: completion, - onSuccess: (state) => ((TaskCompletionSource)state).SetResult(null) + onSuccess: static (object state) => ((TaskCompletionSource)state).SetResult(null) ); } } diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlConnection.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlConnection.cs index 7cdd3b56d5..239f5d9d7c 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlConnection.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlConnection.cs @@ -2139,7 +2139,7 @@ internal Task RegisterForConnectionCloseNotification(Task outerTask, ob { // Connection exists, schedule removal, will be added to ref collection after calling ValidateAndReconnect return outerTask.ContinueWith( - continuationFunction: (task, state) => + continuationFunction: static (task, state) => { Tuple parameters = (Tuple)state; SqlConnection connection = parameters.Item1; diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlUtil.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlUtil.cs index 8a0f5293b4..75391a7ec5 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlUtil.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlUtil.cs @@ -34,7 +34,7 @@ internal static Task CreateContinuationTask(Task task, Action onSuccess, Action< TaskCompletionSource completion = new TaskCompletionSource(); ContinueTaskWithState(task, completion, state: Tuple.Create(onSuccess, onFailure, completion), - onSuccess: (state) => + onSuccess: static (object state) => { var parameters = (Tuple, TaskCompletionSource>)state; Action success = parameters.Item1; @@ -42,7 +42,7 @@ internal static Task CreateContinuationTask(Task task, Action onSuccess, Action< success(); taskCompletionSource.SetResult(null); }, - onFailure: (exception, state) => + onFailure: static (Exception exception, object state) => { var parameters = (Tuple, TaskCompletionSource>)state; Action failure = parameters.Item2; @@ -64,7 +64,7 @@ internal static Task CreateContinuationTaskWithState(Task task, object state, Ac { var completion = new TaskCompletionSource(); ContinueTaskWithState(task, completion, state, - onSuccess: (continueState) => + onSuccess: (object continueState) => { onSuccess(continueState); completion.SetResult(null); @@ -205,11 +205,8 @@ internal static void WaitForCompletion(Task task, int timeout, Action onTimeout } if (!task.IsCompleted) { - task.ContinueWith(t => { var ignored = t.Exception; }); //Ensure the task does not leave an unobserved exception - if (onTimeout != null) - { - onTimeout(); - } + task.ContinueWith(static t => { var ignored = t.Exception; }); //Ensure the task does not leave an unobserved exception + onTimeout?.Invoke(); } } @@ -226,6 +223,24 @@ internal static void SetTimeoutException(TaskCompletionSource completion }); } } + + internal static void SetTimeoutExceptionWithState(TaskCompletionSource completion, int timeout, object state, Func onFailure, CancellationToken cancellationToken) + { + if (timeout > 0) + { + Task.Delay(timeout * 1000, cancellationToken).ContinueWith( + (task, state) => + { + if (!task.IsCanceled && !completion.Task.IsCompleted) + { + completion.TrySetException(onFailure(state)); + } + }, + state: state, + cancellationToken: CancellationToken.None + ); + } + } } internal static class SQL diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs index e0ebfdf669..81db5847f4 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -8833,7 +8833,7 @@ internal Task TdsExecuteSQLBatch(string text, int timeout, SqlNotificationReques bool taskReleaseConnectionLock = releaseConnectionLock; releaseConnectionLock = false; return executeTask.ContinueWith( - (task, state) => + static (Task task, object state) => { Debug.Assert(!task.IsCanceled, "Task should not be canceled"); var parameters = (Tuple)state; @@ -9059,7 +9059,7 @@ internal Task TdsExecuteRPC(SqlCommand cmd, _SqlRPC[] rpcArray, int timeout, boo if (releaseConnectionLock) { task.ContinueWith( - (_, state) => ((SqlInternalConnectionTds)state)._parserLock.Release(), + static (Task _, object state) => ((SqlInternalConnectionTds)state)._parserLock.Release(), state: _connHandler, TaskScheduler.Default ); diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs index f3299816ed..2fab84baa5 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs @@ -1015,7 +1015,16 @@ internal Task ExecuteFlush() } else { - return AsyncHelper.CreateContinuationTask(writePacketTask, () => { HasPendingData = true; _messageStatus = 0; }); + return AsyncHelper.CreateContinuationTaskWithState( + task: writePacketTask, + state: this, + onSuccess: static (object state) => + { + TdsParserStateObject stateObject = (TdsParserStateObject)state; + stateObject.HasPendingData = true; + stateObject._messageStatus = 0; + } + ); } } } From 5a9e99792055a8c96f81ae5693b463e742b654af Mon Sep 17 00:00:00 2001 From: Wraith2 Date: Mon, 3 May 2021 21:41:18 +0100 Subject: [PATCH 2/5] minor netcore api cleanup --- .../src/Microsoft/Data/SqlClient/SqlUtil.cs | 40 ++++++++++--------- 1 file changed, 22 insertions(+), 18 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlUtil.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlUtil.cs index 75391a7ec5..b4146492ce 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlUtil.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlUtil.cs @@ -81,12 +81,12 @@ internal static Task CreateContinuationTask(Task task, Action on } internal static void ContinueTask(Task task, - TaskCompletionSource completion, - Action onSuccess, - Action onFailure = null, - Action onCancellation = null, - Func exceptionConverter = null - ) + TaskCompletionSource completion, + Action onSuccess, + Action onFailure = null, + Action onCancellation = null, + Func exceptionConverter = null + ) { task.ContinueWith( tsk => @@ -145,7 +145,7 @@ internal static void ContinueTaskWithState(Task task, ) { task.ContinueWith( - tsk => + (Task tsk, object state2) => { if (tsk.Exception != null) { @@ -156,7 +156,7 @@ internal static void ContinueTaskWithState(Task task, } try { - onFailure?.Invoke(exc, state); + onFailure?.Invoke(exc, state2); } finally { @@ -167,7 +167,7 @@ internal static void ContinueTaskWithState(Task task, { try { - onCancellation?.Invoke(state); + onCancellation?.Invoke(state2); } finally { @@ -178,14 +178,16 @@ internal static void ContinueTaskWithState(Task task, { try { - onSuccess(state); + onSuccess(state2); } catch (Exception e) { completion.SetException(e); } } - }, TaskScheduler.Default + }, + state: state, + scheduler: TaskScheduler.Default ); } @@ -210,17 +212,19 @@ internal static void WaitForCompletion(Task task, int timeout, Action onTimeout } } - internal static void SetTimeoutException(TaskCompletionSource completion, int timeout, Func exc, CancellationToken ctoken) + internal static void SetTimeoutException(TaskCompletionSource completion, int timeout, Func onFailure, CancellationToken ctoken) { if (timeout > 0) { - Task.Delay(timeout * 1000, ctoken).ContinueWith((tsk) => - { - if (!tsk.IsCanceled && !completion.Task.IsCompleted) + Task.Delay(timeout * 1000, ctoken).ContinueWith( + (Task task) => { - completion.TrySetException(exc()); + if (!task.IsCanceled && !completion.Task.IsCompleted) + { + completion.TrySetException(onFailure()); + } } - }); + ); } } @@ -229,7 +233,7 @@ internal static void SetTimeoutExceptionWithState(TaskCompletionSource c if (timeout > 0) { Task.Delay(timeout * 1000, cancellationToken).ContinueWith( - (task, state) => + (Task task, object state) => { if (!task.IsCanceled && !completion.Task.IsCompleted) { From e28eec438b27b3521a6585a0dc94a31b1e19ff20 Mon Sep 17 00:00:00 2001 From: Wraith2 Date: Mon, 3 May 2021 21:50:05 +0100 Subject: [PATCH 3/5] netfx static lambda rework --- .../Microsoft/Data/SqlClient/SqlBulkCopy.cs | 200 +++++++++--------- .../Microsoft/Data/SqlClient/SqlCommand.cs | 76 +++---- .../src/Microsoft/Data/SqlClient/SqlUtil.cs | 152 ++++++++++++- .../src/Microsoft/Data/SqlClient/TdsParser.cs | 27 ++- .../Data/SqlClient/TdsParserStateObject.cs | 5 +- 5 files changed, 299 insertions(+), 161 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs index 19ad9e6b53..3a2f91e7e7 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs @@ -1118,14 +1118,18 @@ private Task ReadFromRowSourceAsync(CancellationToken cts) if (_isAsyncBulkCopy && _dbDataReaderRowSource != null) { // This will call ReadAsync for DbDataReader (for SqlDataReader it will be truly async read; for non-SqlDataReader it may block.) - return _dbDataReaderRowSource.ReadAsync(cts).ContinueWith((t) => - { - if (t.Status == TaskStatus.RanToCompletion) + return _dbDataReaderRowSource.ReadAsync(cts).ContinueWith( + static (Task task, object state) => { - _hasMoreRowToCopy = t.Result; - } - return t; - }, TaskScheduler.Default).Unwrap(); + if (task.Status == TaskStatus.RanToCompletion) + { + ((SqlBulkCopy)state)._hasMoreRowToCopy = task.Result; + } + return task; + }, + state: this, + scheduler: TaskScheduler.Default + ).Unwrap(); } else { // This will call Read for DataRows, DataTable and IDataReader (this includes all IDataReader except DbDataReader) @@ -2023,8 +2027,8 @@ private Task WriteRowSourceToServerAsync(int columnCount, CancellationToken ctok } else { - AsyncHelper.ContinueTask(writeTask, tcs, - onSuccess: () => tcs.SetResult(null) + AsyncHelper.ContinueTaskWithState(writeTask, tcs, tcs, + onSuccess: static (object state) => ((TaskCompletionSource)state).SetResult(null) ); } }, ctoken); // We do not need to propagate exception, etc, from reconnect task, we just need to wait for it to finish. @@ -2032,7 +2036,7 @@ private Task WriteRowSourceToServerAsync(int columnCount, CancellationToken ctok } else { - AsyncHelper.WaitForCompletion(reconnectTask, BulkCopyTimeout, () => { throw SQL.CR_ReconnectTimeout(); }, rethrowExceptions: false); + AsyncHelper.WaitForCompletion(reconnectTask, BulkCopyTimeout, static () => throw SQL.CR_ReconnectTimeout(), rethrowExceptions: false); } } @@ -2066,27 +2070,32 @@ private Task WriteRowSourceToServerAsync(int columnCount, CancellationToken ctok if (resultTask != null) { finishedSynchronously = false; - return resultTask.ContinueWith((t) => - { - try - { - AbortTransaction(); // if there is one, on success transactions will be commited - } - finally + return resultTask.ContinueWith( + static (Task task, object state) => { - _isBulkCopyingInProgress = false; - if (_parser != null) + SqlBulkCopy sqlBulkCopy = (SqlBulkCopy)state; + try { - _parser._asyncWrite = false; + sqlBulkCopy.AbortTransaction(); // if there is one, on success transactions will be commited } - if (_parserLock != null) + finally { - _parserLock.Release(); - _parserLock = null; + sqlBulkCopy._isBulkCopyingInProgress = false; + if (sqlBulkCopy._parser != null) + { + sqlBulkCopy._parser._asyncWrite = false; + } + if (sqlBulkCopy._parserLock != null) + { + sqlBulkCopy._parserLock.Release(); + sqlBulkCopy._parserLock = null; + } } - } - return t; - }, TaskScheduler.Default).Unwrap(); + return task; + }, + state: this, + scheduler: TaskScheduler.Default + ).Unwrap(); } return null; } @@ -2360,12 +2369,13 @@ private Task CopyColumnsAsync(int col, TaskCompletionSource source = nul // This is in its own method to avoid always allocating the lambda in CopyColumnsAsync private void CopyColumnsAsyncSetupContinuation(TaskCompletionSource source, Task task, int i) { - AsyncHelper.ContinueTask(task, source, - onSuccess: () => + AsyncHelper.ContinueTaskWithState(task, source, this, + onSuccess: (object state) => { - if (i + 1 < _sortedColumnMappings.Count) + SqlBulkCopy sqlBulkCopy = (SqlBulkCopy)state; + if (i + 1 < sqlBulkCopy._sortedColumnMappings.Count) { - CopyColumnsAsync(i + 1, source); //continue from the next column + sqlBulkCopy.CopyColumnsAsync(i + 1, source); //continue from the next column } else { @@ -2468,26 +2478,6 @@ private Task CheckForCancellation(CancellationToken cts, TaskCompletionSource ContinueTaskPend(Task task, TaskCompletionSource source, Func> action) - { - if (task == null) - { - return action(); - } - else - { - Debug.Assert(source != null, "source should already be initialized if task is not null"); - AsyncHelper.ContinueTask(task, source, - onSuccess: () => - { - TaskCompletionSource newSource = action(); - Debug.Assert(newSource == null, "Shouldn't create a new source when one already exists"); - } - ); - } - return null; - } - // Copies all the rows in a batch. // Maintains state machine with state variable: rowSoFar. // Returned Task could be null in two cases: (1) _isAsyncBulkCopy == false, or (2) _isAsyncBulkCopy == true but all async writes finished synchronously. @@ -2528,8 +2518,8 @@ private Task CopyRowsAsync(int rowsSoFar, int totalRows, CancellationToken cts, } resultTask = source.Task; - AsyncHelper.ContinueTask(readTask, source, - onSuccess: () => CopyRowsAsync(i + 1, totalRows, cts, source), + AsyncHelper.ContinueTaskWithState(readTask, source, this, + onSuccess: (object state) => ((SqlBulkCopy)state).CopyRowsAsync(i + 1, totalRows, cts, source), connectionToDoom: _connection.GetOpenTdsConnection() ); return resultTask; // Associated task will be completed when all rows are copied to server/exception/cancelled. @@ -2540,20 +2530,21 @@ private Task CopyRowsAsync(int rowsSoFar, int totalRows, CancellationToken cts, source = source ?? new TaskCompletionSource(); resultTask = source.Task; - AsyncHelper.ContinueTask(task, source, - onSuccess: () => + AsyncHelper.ContinueTaskWithState(task, source, this, + onSuccess: (object state) => { - CheckAndRaiseNotification(); // Check for notification now as the current row copy is done at this moment. + SqlBulkCopy sqlBulkCopy = (SqlBulkCopy)state; + sqlBulkCopy.CheckAndRaiseNotification(); // Check for notification now as the current row copy is done at this moment. - Task readTask = ReadFromRowSourceAsync(cts); + Task readTask = sqlBulkCopy.ReadFromRowSourceAsync(cts); if (readTask == null) { - CopyRowsAsync(i + 1, totalRows, cts, source); + sqlBulkCopy.CopyRowsAsync(i + 1, totalRows, cts, source); } else { - AsyncHelper.ContinueTask(readTask, source, - onSuccess: () => CopyRowsAsync(i + 1, totalRows, cts, source), + AsyncHelper.ContinueTaskWithState(readTask, source, sqlBulkCopy, + onSuccess: (object state2) => ((SqlBulkCopy)state2).CopyRowsAsync(i + 1, totalRows, cts, source), connectionToDoom: _connection.GetOpenTdsConnection() ); } @@ -2628,14 +2619,15 @@ private Task CopyBatchesAsync(BulkCopySimpleResultSet internalResults, string up source = new TaskCompletionSource(); } - AsyncHelper.ContinueTask(commandTask, source, - onSuccess: () => + AsyncHelper.ContinueTaskWithState(commandTask, source, this, + onSuccess: (object state) => { - Task continuedTask = CopyBatchesAsyncContinued(internalResults, updateBulkCommandText, cts, source); + SqlBulkCopy sqlBulkCopy = (SqlBulkCopy)state; + Task continuedTask = sqlBulkCopy.CopyBatchesAsyncContinued(internalResults, updateBulkCommandText, cts, source); if (continuedTask == null) { // Continuation finished sync, recall into CopyBatchesAsync to continue - CopyBatchesAsync(internalResults, updateBulkCommandText, cts, source); + sqlBulkCopy.CopyBatchesAsync(internalResults, updateBulkCommandText, cts, source); } }, connectionToDoom: _connection.GetOpenTdsConnection() @@ -2693,20 +2685,21 @@ private Task CopyBatchesAsyncContinued(BulkCopySimpleResultSet internalResults, { // First time only source = new TaskCompletionSource(); } - AsyncHelper.ContinueTask(task, source, - onSuccess: () => + AsyncHelper.ContinueTaskWithState(task, source, this, + onSuccess: (object state) => { - Task continuedTask = CopyBatchesAsyncContinuedOnSuccess(internalResults, updateBulkCommandText, cts, source); + SqlBulkCopy sqlBulkCopy = (SqlBulkCopy)state; + Task continuedTask = sqlBulkCopy.CopyBatchesAsyncContinuedOnSuccess(internalResults, updateBulkCommandText, cts, source); if (continuedTask == null) { // Continuation finished sync, recall into CopyBatchesAsync to continue - CopyBatchesAsync(internalResults, updateBulkCommandText, cts, source); + sqlBulkCopy.CopyBatchesAsync(internalResults, updateBulkCommandText, cts, source); } }, - connectionToDoom: _connection.GetOpenTdsConnection(), - onFailure: _ => CopyBatchesAsyncContinuedOnError(cleanupParser: false), - onCancellation: () => CopyBatchesAsyncContinuedOnError(cleanupParser: true) - ); + onFailure: static (Exception _, object state) => ((SqlBulkCopy)state).CopyBatchesAsyncContinuedOnError(cleanupParser: false), + onCancellation: (object state) => ((SqlBulkCopy)state).CopyBatchesAsyncContinuedOnError(cleanupParser: true) +, + connectionToDoom: _connection.GetOpenTdsConnection()); return source.Task; } @@ -2753,25 +2746,26 @@ private Task CopyBatchesAsyncContinuedOnSuccess(BulkCopySimpleResultSet internal source = new TaskCompletionSource(); } - AsyncHelper.ContinueTask(writeTask, source, - onSuccess: () => + AsyncHelper.ContinueTaskWithState(writeTask, source, this, + onSuccess: (object state) => { + SqlBulkCopy sqlBulkCopy = (SqlBulkCopy)state; try { - RunParser(); - CommitTransaction(); + sqlBulkCopy.RunParser(); + sqlBulkCopy.CommitTransaction(); } catch (Exception) { - CopyBatchesAsyncContinuedOnError(cleanupParser: false); + sqlBulkCopy.CopyBatchesAsyncContinuedOnError(cleanupParser: false); throw; } // Always call back into CopyBatchesAsync - CopyBatchesAsync(internalResults, updateBulkCommandText, cts, source); + sqlBulkCopy.CopyBatchesAsync(internalResults, updateBulkCommandText, cts, source); }, - connectionToDoom: _connection.GetOpenTdsConnection(), - onFailure: _ => CopyBatchesAsyncContinuedOnError(cleanupParser: false) + onFailure: static (Exception _, object state) => ((SqlBulkCopy)state).CopyBatchesAsyncContinuedOnError(cleanupParser: false), + connectionToDoom: _connection.GetOpenTdsConnection() ); return source.Task; } @@ -2906,16 +2900,17 @@ private void WriteToServerInternalRestContinuedAsync(BulkCopySimpleResultSet int { source = new TaskCompletionSource(); } - AsyncHelper.ContinueTask(task, source, - onSuccess: () => + AsyncHelper.ContinueTaskWithState(task, source, this, + onSuccess: (object state) => { + SqlBulkCopy sqlBulkCopy = (SqlBulkCopy)state; // Bulk copy task is completed at this moment. if (task.IsCanceled) { - _localColumnMappings = null; + sqlBulkCopy._localColumnMappings = null; try { - CleanUpStateObject(); + sqlBulkCopy.CleanUpStateObject(); } finally { @@ -2928,10 +2923,10 @@ private void WriteToServerInternalRestContinuedAsync(BulkCopySimpleResultSet int } else { - _localColumnMappings = null; + sqlBulkCopy._localColumnMappings = null; try { - CleanUpStateObject(isCancelRequested: false); + sqlBulkCopy.CleanUpStateObject(isCancelRequested: false); } finally { @@ -3034,22 +3029,22 @@ private void WriteToServerInternalRestAsync(CancellationToken cts, TaskCompletio { if (_isAsyncBulkCopy) { - CancellationTokenRegistration regReconnectCancel = new CancellationTokenRegistration(); + StrongBox regReconnectCancel = new StrongBox(new CancellationTokenRegistration()); TaskCompletionSource cancellableReconnectTS = new TaskCompletionSource(); if (cts.CanBeCanceled) { - regReconnectCancel = cts.Register(() => cancellableReconnectTS.TrySetCanceled()); + regReconnectCancel.Value = cts.Register(() => cancellableReconnectTS.TrySetCanceled()); } - AsyncHelper.ContinueTask(reconnectTask, cancellableReconnectTS, - onSuccess: () => { cancellableReconnectTS.SetResult(null); } + AsyncHelper.ContinueTaskWithState(reconnectTask, cancellableReconnectTS, cancellableReconnectTS, + onSuccess: static (object state) => ((TaskCompletionSource)state).SetResult(null) ); // No need to cancel timer since SqlBulkCopy creates specific task source for reconnection AsyncHelper.SetTimeoutException(cancellableReconnectTS, BulkCopyTimeout, () => { return SQL.BulkLoadInvalidDestinationTable(_destinationTableName, SQL.CR_ReconnectTimeout()); }, CancellationToken.None); - AsyncHelper.ContinueTask(cancellableReconnectTS.Task, source, - onSuccess: () => + AsyncHelper.ContinueTaskWithState(cancellableReconnectTS.Task, source, regReconnectCancel, + onSuccess: (object state) => { - regReconnectCancel.Dispose(); + ((StrongBox)state).Value.Dispose(); if (_parserLock != null) { _parserLock.Release(); @@ -3060,9 +3055,9 @@ private void WriteToServerInternalRestAsync(CancellationToken cts, TaskCompletio WriteToServerInternalRestAsync(cts, source); }, connectionToAbort: _connection, - onFailure: (e) => { regReconnectCancel.Dispose(); }, - onCancellation: () => { regReconnectCancel.Dispose(); }, - exceptionConverter: (ex) => SQL.BulkLoadInvalidDestinationTable(_destinationTableName, ex) + onFailure: static (Exception _, object state) => ((StrongBox)state).Value.Dispose(), + onCancellation: static (object state) => ((StrongBox)state).Value.Dispose(), + exceptionConverter: (Exception ex, object state) => SQL.BulkLoadInvalidDestinationTable(_destinationTableName, ex) ); return; } @@ -3070,7 +3065,7 @@ private void WriteToServerInternalRestAsync(CancellationToken cts, TaskCompletio { try { - AsyncHelper.WaitForCompletion(reconnectTask, BulkCopyTimeout, () => { throw SQL.CR_ReconnectTimeout(); }); + AsyncHelper.WaitForCompletion(reconnectTask, BulkCopyTimeout, static () => throw SQL.CR_ReconnectTimeout()); } catch (SqlException ex) { @@ -3111,8 +3106,8 @@ private void WriteToServerInternalRestAsync(CancellationToken cts, TaskCompletio if (internalResultsTask != null) { - AsyncHelper.ContinueTask(internalResultsTask, source, - onSuccess: () => WriteToServerInternalRestContinuedAsync(internalResultsTask.Result, cts, source), + AsyncHelper.ContinueTaskWithState(internalResultsTask, source, this, + onSuccess: (object state) => ((SqlBulkCopy)state).WriteToServerInternalRestContinuedAsync(internalResultsTask.Result, cts, source), connectionToDoom: _connection.GetOpenTdsConnection() ); } @@ -3185,16 +3180,17 @@ private Task WriteToServerInternalAsync(CancellationToken ctoken) else { Debug.Assert(_isAsyncBulkCopy, "Read must not return a Task in the Sync mode"); - AsyncHelper.ContinueTask(readTask, source, - onSuccess: () => + AsyncHelper.ContinueTaskWithState(readTask, source, this, + onSuccess: (object state) => { - if (!_hasMoreRowToCopy) + SqlBulkCopy sqlBulkCopy = (SqlBulkCopy)state; + if (!sqlBulkCopy._hasMoreRowToCopy) { source.SetResult(null); // No rows to copy! } else { - WriteToServerInternalRestAsync(ctoken, source); // Passing the same completion which will be completed by the Callee. + sqlBulkCopy.WriteToServerInternalRestAsync(ctoken, source); // Passing the same completion which will be completed by the Callee. } }, connectionToDoom: _connection.GetOpenTdsConnection() diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlCommand.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlCommand.cs index e962a5695c..15a014004d 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlCommand.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlCommand.cs @@ -1541,7 +1541,7 @@ private IAsyncResult BeginExecuteNonQueryInternal(CommandBehavior behavior, Asyn Task execNQ = InternalExecuteNonQuery(localCompletion, ADP.BeginExecuteNonQuery, false, timeout, out usedCache, asyncWrite, inRetry: inRetry); if (execNQ != null) { - AsyncHelper.ContinueTask(execNQ, localCompletion, () => BeginExecuteNonQueryInternalReadStage(localCompletion)); + AsyncHelper.ContinueTaskWithState(execNQ, localCompletion, this, (object state) => ((SqlCommand)state).BeginExecuteNonQueryInternalReadStage(localCompletion)); } else { @@ -2151,7 +2151,7 @@ private IAsyncResult BeginExecuteXmlReaderInternal(CommandBehavior behavior, Asy if (writeTask != null) { - AsyncHelper.ContinueTask(writeTask, localCompletion, () => BeginExecuteXmlReaderInternalReadStage(localCompletion)); + AsyncHelper.ContinueTaskWithState(writeTask, localCompletion, this, (object state) => ((SqlCommand)state).BeginExecuteXmlReaderInternalReadStage(localCompletion)); } else { @@ -2640,7 +2640,7 @@ private IAsyncResult BeginExecuteReaderInternal(CommandBehavior behavior, AsyncC if (writeTask != null) { - AsyncHelper.ContinueTask(writeTask, localCompletion, () => BeginExecuteReaderInternalReadStage(localCompletion)); + AsyncHelper.ContinueTaskWithState(writeTask, localCompletion, this, (object state) => ((SqlCommand)state).BeginExecuteReaderInternalReadStage(localCompletion)); } else { @@ -2980,14 +2980,19 @@ private Task InternalExecuteNonQueryAsync(CancellationToken cancellationTok /// protected override Task ExecuteDbDataReaderAsync(CommandBehavior behavior, CancellationToken cancellationToken) { - return ExecuteReaderAsync(behavior, cancellationToken).ContinueWith((result) => - { - if (result.IsFaulted) + return ExecuteReaderAsync(behavior, cancellationToken).ContinueWith( + static (Task result) => { - throw result.Exception.InnerException; - } - return result.Result; - }, CancellationToken.None, TaskContinuationOptions.ExecuteSynchronously | TaskContinuationOptions.NotOnCanceled, TaskScheduler.Default); + if (result.IsFaulted) + { + throw result.Exception.InnerException; + } + return result.Result; + }, + CancellationToken.None, + TaskContinuationOptions.ExecuteSynchronously | TaskContinuationOptions.NotOnCanceled, + TaskScheduler.Default + ); } private Task InternalExecuteReaderWithRetryAsync(CommandBehavior behavior, CancellationToken cancellationToken) @@ -3732,7 +3737,7 @@ private Task RunExecuteNonQueryTds(string methodName, bool async, int timeout, b _activeConnection.RegisterWaitingForReconnect(completion.Task); _reconnectionCompletionSource = completion; CancellationTokenSource timeoutCTS = new CancellationTokenSource(); - AsyncHelper.SetTimeoutException(completion, timeout, SQL.CR_ReconnectTimeout, timeoutCTS.Token); + AsyncHelper.SetTimeoutException(completion, timeout, static () => SQL.CR_ReconnectTimeout(), timeoutCTS.Token); AsyncHelper.ContinueTask(reconnectTask, completion, () => { @@ -3749,14 +3754,16 @@ private Task RunExecuteNonQueryTds(string methodName, bool async, int timeout, b } else { - AsyncHelper.ContinueTask(subTask, completion, () => completion.SetResult(null)); + AsyncHelper.ContinueTaskWithState(subTask, completion, completion, static (object state) => ((TaskCompletionSource)state).SetResult(null)); } - }, connectionToAbort: _activeConnection); + }, + connectionToAbort: _activeConnection + ); return completion.Task; } else { - AsyncHelper.WaitForCompletion(reconnectTask, timeout, () => { throw SQL.CR_ReconnectTimeout(); }); + AsyncHelper.WaitForCompletion(reconnectTask, timeout, static () => throw SQL.CR_ReconnectTimeout()); timeout = TdsParserStaticMethods.GetRemainingTimeout(timeout, reconnectionStart); } } @@ -5100,39 +5107,32 @@ private SqlDataReader RunExecuteReaderTdsWithTransparentParameterEncryption( { long parameterEncryptionStart = ADP.TimerCurrent(); TaskCompletionSource completion = new TaskCompletionSource(); - AsyncHelper.ContinueTask(describeParameterEncryptionTask, completion, - () => + AsyncHelper.ContinueTaskWithState(describeParameterEncryptionTask, completion, this, + (object state) => { + SqlCommand command = (SqlCommand)state; Task subTask = null; - GenerateEnclavePackage(); - RunExecuteReaderTds(cmdBehavior, runBehavior, returnStream, async, TdsParserStaticMethods.GetRemainingTimeout(timeout, parameterEncryptionStart), out subTask, asyncWrite, inRetry, ds); + command.GenerateEnclavePackage(); + command.RunExecuteReaderTds(cmdBehavior, runBehavior, returnStream, async, TdsParserStaticMethods.GetRemainingTimeout(timeout, parameterEncryptionStart), out subTask, asyncWrite, inRetry, ds); if (subTask == null) { completion.SetResult(null); } else { - AsyncHelper.ContinueTask(subTask, completion, () => completion.SetResult(null)); + AsyncHelper.ContinueTaskWithState(subTask, completion, completion, static (object state2) => ((TaskCompletionSource)state2).SetResult(null)); } - }, connectionToDoom: null, - onFailure: ((exception) => + }, + onFailure: static (Exception exception, object state) => { - if (_cachedAsyncState != null) - { - _cachedAsyncState.ResetAsyncState(); - } + ((SqlCommand)state)._cachedAsyncState?.ResetAsyncState(); if (exception != null) { throw exception; } - }), - onCancellation: (() => - { - if (_cachedAsyncState != null) - { - _cachedAsyncState.ResetAsyncState(); - } - }), + }, + onCancellation: static (object state) => ((SqlCommand)state)._cachedAsyncState?.ResetAsyncState(), + connectionToDoom: null, connectionToAbort: _activeConnection); task = completion.Task; return ds; @@ -5201,7 +5201,7 @@ private SqlDataReader RunExecuteReaderTds(CommandBehavior cmdBehavior, RunBehavi _activeConnection.RegisterWaitingForReconnect(completion.Task); _reconnectionCompletionSource = completion; CancellationTokenSource timeoutCTS = new CancellationTokenSource(); - AsyncHelper.SetTimeoutException(completion, timeout, SQL.CR_ReconnectTimeout, timeoutCTS.Token); + AsyncHelper.SetTimeoutException(completion, timeout, static () => SQL.CR_ReconnectTimeout(), timeoutCTS.Token); AsyncHelper.ContinueTask(reconnectTask, completion, () => { @@ -5219,15 +5219,17 @@ private SqlDataReader RunExecuteReaderTds(CommandBehavior cmdBehavior, RunBehavi } else { - AsyncHelper.ContinueTask(subTask, completion, () => completion.SetResult(null)); + AsyncHelper.ContinueTaskWithState(subTask, completion, completion, static (object state) => ((TaskCompletionSource)state).SetResult(null)); } - }, connectionToAbort: _activeConnection); + }, + connectionToAbort: _activeConnection + ); task = completion.Task; return ds; } else { - AsyncHelper.WaitForCompletion(reconnectTask, timeout, () => { throw SQL.CR_ReconnectTimeout(); }); + AsyncHelper.WaitForCompletion(reconnectTask, timeout, static () => throw SQL.CR_ReconnectTimeout()); timeout = TdsParserStaticMethods.GetRemainingTimeout(timeout, reconnectionStart); } } diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlUtil.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlUtil.cs index bdec925681..1fd3767bbe 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlUtil.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlUtil.cs @@ -33,8 +33,14 @@ internal static Task CreateContinuationTask(Task task, Action onSuccess, SqlInte { TaskCompletionSource completion = new TaskCompletionSource(); ContinueTask(task, completion, - () => { onSuccess(); completion.SetResult(null); }, - connectionToDoom, onFailure); + onSuccess: () => + { + onSuccess(); + completion.SetResult(null); + }, + onFailure: onFailure, + connectionToDoom: connectionToDoom + ); return completion.Task; } } @@ -45,14 +51,14 @@ internal static Task CreateContinuationTask(Task task, Action on } internal static void ContinueTask(Task task, - TaskCompletionSource completion, - Action onSuccess, - SqlInternalConnectionTds connectionToDoom = null, - Action onFailure = null, - Action onCancellation = null, - Func exceptionConverter = null, - SqlConnection connectionToAbort = null - ) + TaskCompletionSource completion, + Action onSuccess, + Action onFailure = null, + Action onCancellation = null, + Func exceptionConverter = null, + SqlInternalConnectionTds connectionToDoom = null, + SqlConnection connectionToAbort = null + ) { Debug.Assert((connectionToAbort == null) || (connectionToDoom == null), "Should not specify both connectionToDoom and connectionToAbort"); task.ContinueWith( @@ -172,6 +178,132 @@ internal static void ContinueTask(Task task, ); } + internal static void ContinueTaskWithState(Task task, + TaskCompletionSource completion, + object state, + Action onSuccess, + Action onFailure = null, + Action onCancellation = null, + Func exceptionConverter = null, + SqlInternalConnectionTds connectionToDoom = null, + SqlConnection connectionToAbort = null + ) + { + Debug.Assert((connectionToAbort == null) || (connectionToDoom == null), "Should not specify both connectionToDoom and connectionToAbort"); + task.ContinueWith( + (Task tsk, object state) => + { + if (tsk.Exception != null) + { + Exception exc = tsk.Exception.InnerException; + if (exceptionConverter != null) + { + exc = exceptionConverter(exc, state); + } + try + { + onFailure?.Invoke(exc, state); + } + finally + { + completion.TrySetException(exc); + } + } + else if (tsk.IsCanceled) + { + try + { + onCancellation?.Invoke(state); + } + finally + { + completion.TrySetCanceled(); + } + } + else + { + if (connectionToDoom != null || connectionToAbort != null) + { + RuntimeHelpers.PrepareConstrainedRegions(); + try + { +#if DEBUG + TdsParser.ReliabilitySection tdsReliabilitySection = new TdsParser.ReliabilitySection(); + RuntimeHelpers.PrepareConstrainedRegions(); + try + { + tdsReliabilitySection.Start(); +#endif //DEBUG + onSuccess(state); +#if DEBUG + } + finally + { + tdsReliabilitySection.Stop(); + } +#endif //DEBUG + } + catch (System.OutOfMemoryException e) + { + if (connectionToDoom != null) + { + connectionToDoom.DoomThisConnection(); + } + else + { + connectionToAbort.Abort(e); + } + completion.SetException(e); + throw; + } + catch (System.StackOverflowException e) + { + if (connectionToDoom != null) + { + connectionToDoom.DoomThisConnection(); + } + else + { + connectionToAbort.Abort(e); + } + completion.SetException(e); + throw; + } + catch (System.Threading.ThreadAbortException e) + { + if (connectionToDoom != null) + { + connectionToDoom.DoomThisConnection(); + } + else + { + connectionToAbort.Abort(e); + } + completion.SetException(e); + throw; + } + catch (Exception e) + { + completion.SetException(e); + } + } + else + { // no connection to doom - reliability section not required + try + { + onSuccess(state); + } + catch (Exception e) + { + completion.SetException(e); + } + } + } + }, + state: state, + scheduler: TaskScheduler.Default + ); + } internal static void WaitForCompletion(Task task, int timeout, Action onTimeout = null, bool rethrowExceptions = true) { diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs index 6b4d322c1f..005ca92acd 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -10376,19 +10376,25 @@ internal Task TdsExecuteRPC(SqlCommand cmd, _SqlRPC[] rpcArray, int timeout, boo task = completion.Task; } - AsyncHelper.ContinueTask(writeParamTask, completion, - () => TdsExecuteRPC(cmd, rpcArray, timeout, inSchema, notificationRequest, stateObj, isCommandProc, sync, completion, - startRpc: ii, startParam: i + 1), - connectionToDoom: _connHandler, - onFailure: exc => TdsExecuteRPC_OnFailure(exc, stateObj)); + AsyncHelper.ContinueTaskWithState(writeParamTask, completion, this, + (object state) => + { + TdsParser tdsParser = (TdsParser)state; + TdsExecuteRPC(cmd, rpcArray, timeout, inSchema, notificationRequest, stateObj, isCommandProc, sync, completion, + startRpc: ii, startParam: i + 1); + }, + onFailure: (Exception exc, object state) => ((TdsParser)state).TdsExecuteRPC_OnFailure(exc, stateObj), + connectionToDoom: _connHandler + ); // Take care of releasing the locks if (releaseConnectionLock) { - task.ContinueWith(_ => - { - _connHandler._parserLock.Release(); - }, TaskScheduler.Default); + task.ContinueWith( + static (Task _, object state) => ((TdsParser)state)._connHandler._parserLock.Release(), + state: this, + scheduler: TaskScheduler.Default + ); releaseConnectionLock = false; } @@ -11953,7 +11959,8 @@ private Task GetTerminationTask(Task unterminatedWriteTask, object value, MetaTy { return AsyncHelper.CreateContinuationTask(unterminatedWriteTask, WriteInt, 0, stateObj, - connectionToDoom: _connHandler); + connectionToDoom: _connHandler + ); } } else diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs index ebd75aaefa..f4d48ff4fc 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs @@ -3359,11 +3359,12 @@ internal Task WriteByteArray(Byte[] b, int len, int offsetBuffer, bool canAccumu } // This is in its own method to avoid always allocating the lambda in WriteByteArray - private void WriteByteArraySetupContinuation(Byte[] b, int len, TaskCompletionSource completion, int offset, Task packetTask) + private void WriteByteArraySetupContinuation(byte[] b, int len, TaskCompletionSource completion, int offset, Task packetTask) { AsyncHelper.ContinueTask(packetTask, completion, () => WriteByteArray(b, len: len, offsetBuffer: offset, canAccumulate: false, completion: completion), - connectionToDoom: _parser.Connection); + connectionToDoom: _parser.Connection + ); } // Dumps contents of buffer to SNI for network write. From 4ffd7fd3c1022ce86b2f0beb5b6d79221456d905 Mon Sep 17 00:00:00 2001 From: Wraith2 Date: Mon, 3 May 2021 21:52:38 +0100 Subject: [PATCH 4/5] netcore add strongbox to avoid boxing --- .../src/Microsoft/Data/SqlClient/SqlBulkCopy.cs | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs index 3442bc47c6..8d01563b37 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs @@ -9,6 +9,7 @@ using System.Data.Common; using System.Data.SqlTypes; using System.Diagnostics; +using System.Runtime.CompilerServices; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -2902,11 +2903,11 @@ private void WriteToServerInternalRestAsync(CancellationToken cts, TaskCompletio { if (_isAsyncBulkCopy) { - CancellationTokenRegistration regReconnectCancel = new CancellationTokenRegistration(); + StrongBox regReconnectCancel = new StrongBox(new CancellationTokenRegistration()); TaskCompletionSource cancellableReconnectTS = new TaskCompletionSource(); if (cts.CanBeCanceled) { - regReconnectCancel = cts.Register(static (object tcs) => ((TaskCompletionSource)tcs).TrySetCanceled(), cancellableReconnectTS); + regReconnectCancel.Value = cts.Register(static (object tcs) => ((TaskCompletionSource)tcs).TrySetCanceled(), cancellableReconnectTS); } AsyncHelper.ContinueTaskWithState(reconnectTask, cancellableReconnectTS, state: cancellableReconnectTS, @@ -2926,7 +2927,7 @@ private void WriteToServerInternalRestAsync(CancellationToken cts, TaskCompletio state: regReconnectCancel, onSuccess: (object state) => { - ((CancellationTokenRegistration)state).Dispose(); + ((StrongBox)state).Value.Dispose(); if (_parserLock != null) { _parserLock.Release(); @@ -2936,8 +2937,8 @@ private void WriteToServerInternalRestAsync(CancellationToken cts, TaskCompletio _parserLock.Wait(canReleaseFromAnyThread: true); WriteToServerInternalRestAsync(cts, source); }, - onFailure: static (Exception _, object state) => ((CancellationTokenRegistration)state).Dispose(), - onCancellation: static (object state) => ((CancellationTokenRegistration)state).Dispose(), + onFailure: static (Exception _, object state) => ((StrongBox)state).Value.Dispose(), + onCancellation: static (object state) => ((StrongBox)state).Value.Dispose(), exceptionConverter: (ex) => SQL.BulkLoadInvalidDestinationTable(_destinationTableName, ex)); return; } From c90d33550972f91f43ae8983dbc42b8760e4c940 Mon Sep 17 00:00:00 2001 From: Wraith2 Date: Fri, 11 Jun 2021 19:01:30 +0100 Subject: [PATCH 5/5] fixup 2 anonymous typed lambdas to be explicit --- .../netfx/src/Microsoft/Data/SqlClient/SqlDataReader.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDataReader.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDataReader.cs index 5bbbc07904..cddba0e8f6 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDataReader.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDataReader.cs @@ -5090,7 +5090,7 @@ private Task GetBytesAsyncReadDataStage(int i, byte[] buffer, int index, in else { // setup for cleanup\completing - retryTask.ContinueWith((t) => CompleteRetryable(t, source, timeoutCancellationSource), TaskScheduler.Default); + retryTask.ContinueWith((Task t) => CompleteRetryable(t, source, timeoutCancellationSource), TaskScheduler.Default); return source.Task; } } @@ -5654,7 +5654,7 @@ private Task InvokeRetryable(Func> moreFunc, TaskCompletionS } else { - task.ContinueWith((t) => CompleteRetryable(t, source, objectToDispose), TaskScheduler.Default); + task.ContinueWith((Task t) => CompleteRetryable(t, source, objectToDispose), TaskScheduler.Default); } } catch (AggregateException e)