Skip to content
This repository has been archived by the owner on Jan 23, 2023. It is now read-only.
/ corefx Public archive

Optimize SqlClient query memory allocation #34047

Merged
merged 8 commits into from
Dec 14, 2018
180 changes: 111 additions & 69 deletions src/System.Data.SqlClient/src/System/Data/SqlClient/SqlCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ private enum EXECTYPE
// The OnReturnValue function will test this flag to determine whether the returned value is a _prepareHandle or something else.
//
// _prepareHandle - the handle of a prepared command. Apparently there can be multiple prepared commands at a time - a feature that we do not support yet.

private static readonly object s_cachedInvalidPrepareHandle = (object)-1;
private bool _inPrepare = false;
private int _prepareHandle = -1;
private object _prepareHandle = s_cachedInvalidPrepareHandle; // this is an int which is used in the object typed SqlParameter.Value field, avoid repeated boxing by storing in a box
private bool _hiddenPrepare = false;
private int _preparedConnectionCloseCount = -1;
private int _preparedConnectionReconnectCount = -1;
Expand Down Expand Up @@ -83,7 +83,7 @@ internal bool InPrepare
}

// Cached info for async executions
private class CachedAsyncState
private sealed class CachedAsyncState
{
private int _cachedAsyncCloseCount = -1; // value of the connection's CloseCount property when the asyncResult was set; tracks when connections are closed after an async operation
private TaskCompletionSource<object> _cachedAsyncResult = null;
Expand Down Expand Up @@ -261,7 +261,7 @@ private SqlCommand(SqlCommand from) : this()
// Don't allow the connection to be changed while in an async operation.
if (_activeConnection != value && _activeConnection != null)
{ // If new value...
if (cachedAsyncState.PendingAsyncOperation)
if (_cachedAsyncState != null && _cachedAsyncState.PendingAsyncOperation)
{ // If in pending async state, throw.
throw SQL.CannotModifyPropertyAsyncOperationInProgress();
}
Expand Down Expand Up @@ -292,7 +292,7 @@ private SqlCommand(SqlCommand from) : this()
finally
{
// clean prepare status (even successful Unprepare does not do that)
_prepareHandle = -1;
_prepareHandle = s_cachedInvalidPrepareHandle;
_execType = EXECTYPE.UNPREPARED;
}
}
Expand Down Expand Up @@ -672,7 +672,7 @@ internal void Unprepare()
if ((_activeConnection.CloseCount != _preparedConnectionCloseCount) || (_activeConnection.ReconnectCount != _preparedConnectionReconnectCount))
{
// reset our handle
_prepareHandle = -1;
_prepareHandle = s_cachedInvalidPrepareHandle;
}

_cachedMetaData = null;
Expand Down Expand Up @@ -918,7 +918,13 @@ public IAsyncResult BeginExecuteNonQuery(AsyncCallback callback, object stateObj
cachedAsyncState.SetActiveConnectionAndResult(completion, nameof(EndExecuteNonQuery), _activeConnection);
if (execNQ != null)
{
AsyncHelper.ContinueTask(execNQ, completion, () => BeginExecuteNonQueryInternalReadStage(completion));
AsyncHelper.ContinueTaskWithState(execNQ, completion,
state: Tuple.Create(this, completion),
onSuccess: state => {
var parameters = (Tuple<SqlCommand, TaskCompletionSource<object>>)state;
parameters.Item1.BeginExecuteNonQueryInternalReadStage(parameters.Item2);
}
);
}
else
{
Expand All @@ -941,7 +947,10 @@ public IAsyncResult BeginExecuteNonQuery(AsyncCallback callback, object stateObj
// Add callback after work is done to avoid overlapping Begin\End methods
if (callback != null)
{
completion.Task.ContinueWith((t) => callback(t), TaskScheduler.Default);
completion.Task.ContinueWith(
(task,state) => ((AsyncCallback)state)(task),
state: callback
);
}

return completion.Task;
Expand Down Expand Up @@ -1170,7 +1179,10 @@ private Task InternalExecuteNonQuery(TaskCompletionSource<object> completion, bo
{
if (task != null)
{
task = AsyncHelper.CreateContinuationTask(task, () => reader.Close());
task = AsyncHelper.CreateContinuationTaskWithState(task,
state: reader,
onSuccess: state => ((SqlDataReader)state).Close()
);
}
else
{
Expand Down Expand Up @@ -1265,7 +1277,13 @@ public IAsyncResult BeginExecuteXmlReader(AsyncCallback callback, object stateOb
cachedAsyncState.SetActiveConnectionAndResult(completion, nameof(EndExecuteXmlReader), _activeConnection);
if (writeTask != null)
{
AsyncHelper.ContinueTask(writeTask, completion, () => BeginExecuteXmlReaderInternalReadStage(completion));
AsyncHelper.ContinueTaskWithState(writeTask, completion,
state: Tuple.Create(this, completion),
onSuccess: state => {
saurabh500 marked this conversation as resolved.
Show resolved Hide resolved
var parameters = (Tuple<SqlCommand, TaskCompletionSource<object>>)state;
parameters.Item1.BeginExecuteXmlReaderInternalReadStage(parameters.Item2);
}
);
}
else
{
Expand Down Expand Up @@ -1528,7 +1546,13 @@ internal IAsyncResult BeginExecuteReader(CommandBehavior behavior, AsyncCallback
cachedAsyncState.SetActiveConnectionAndResult(completion, nameof(EndExecuteReader), _activeConnection);
if (writeTask != null)
{
AsyncHelper.ContinueTask(writeTask, completion, () => BeginExecuteReaderInternalReadStage(completion));
AsyncHelper.ContinueTaskWithState(writeTask, completion,
state: Tuple.Create(this, completion),
onSuccess: state => {
var parameters = (Tuple<SqlCommand, TaskCompletionSource<object>>)state;
parameters.Item1.BeginExecuteReaderInternalReadStage(parameters.Item2);
}
);
}
else
{
Expand Down Expand Up @@ -2328,27 +2352,7 @@ private Task RunExecuteNonQueryTds(string methodName, bool async, int timeout, b
TaskCompletionSource<object> completion = new TaskCompletionSource<object>();
_activeConnection.RegisterWaitingForReconnect(completion.Task);
_reconnectionCompletionSource = completion;
CancellationTokenSource timeoutCTS = new CancellationTokenSource();
AsyncHelper.SetTimeoutException(completion, timeout, SQL.CR_ReconnectTimeout, timeoutCTS.Token);
AsyncHelper.ContinueTask(reconnectTask, completion,
() =>
{
if (completion.Task.IsCompleted)
{
return;
}
Interlocked.CompareExchange(ref _reconnectionCompletionSource, null, completion);
timeoutCTS.Cancel();
Task subTask = RunExecuteNonQueryTds(methodName, async, TdsParserStaticMethods.GetRemainingTimeout(timeout, reconnectionStart), asyncWrite);
if (subTask == null)
{
completion.SetResult(null);
}
else
{
AsyncHelper.ContinueTask(subTask, completion, () => completion.SetResult(null));
}
}, connectionToAbort: _activeConnection);
RunExecuteNonQueryTdsSetupReconnnectContinuation(methodName, async, timeout, asyncWrite, reconnectTask, reconnectionStart, completion);
return completion.Task;
}
else
Expand Down Expand Up @@ -2401,6 +2405,31 @@ private Task RunExecuteNonQueryTds(string methodName, bool async, int timeout, b
return null;
}

// This is in its own method to avoid always allocating the lambda in RunExecuteNonQueryTds, cannot use ContinueTaskWithState because of MarshalByRef and the CompareExchange
private void RunExecuteNonQueryTdsSetupReconnnectContinuation(string methodName, bool async, int timeout, bool asyncWrite, Task reconnectTask, long reconnectionStart, TaskCompletionSource<object> completion)
{
CancellationTokenSource timeoutCTS = new CancellationTokenSource();
AsyncHelper.SetTimeoutException(completion, timeout, SQL.CR_ReconnectTimeout, timeoutCTS.Token);
AsyncHelper.ContinueTask(reconnectTask, completion,
() =>
{
if (completion.Task.IsCompleted)
{
return;
}
Interlocked.CompareExchange(ref _reconnectionCompletionSource, null, completion);
timeoutCTS.Cancel();
Task subTask = RunExecuteNonQueryTds(methodName, async, TdsParserStaticMethods.GetRemainingTimeout(timeout, reconnectionStart), asyncWrite);
if (subTask == null)
{
completion.SetResult(null);
}
else
{
AsyncHelper.ContinueTask(subTask, completion, () => completion.SetResult(null));
}
}, connectionToAbort: _activeConnection);
}

internal SqlDataReader RunExecuteReader(CommandBehavior cmdBehavior, RunBehavior runBehavior, bool returnStream, [CallerMemberName] string method = "")
{
Expand Down Expand Up @@ -2470,28 +2499,7 @@ private SqlDataReader RunExecuteReaderTds(CommandBehavior cmdBehavior, RunBehavi
TaskCompletionSource<object> completion = new TaskCompletionSource<object>();
_activeConnection.RegisterWaitingForReconnect(completion.Task);
_reconnectionCompletionSource = completion;
CancellationTokenSource timeoutCTS = new CancellationTokenSource();
AsyncHelper.SetTimeoutException(completion, timeout, SQL.CR_ReconnectTimeout, timeoutCTS.Token);
AsyncHelper.ContinueTask(reconnectTask, completion,
() =>
{
if (completion.Task.IsCompleted)
{
return;
}
Interlocked.CompareExchange(ref _reconnectionCompletionSource, null, completion);
timeoutCTS.Cancel();
Task subTask;
RunExecuteReaderTds(cmdBehavior, runBehavior, returnStream, async, TdsParserStaticMethods.GetRemainingTimeout(timeout, reconnectionStart), out subTask, asyncWrite, ds);
if (subTask == null)
{
completion.SetResult(null);
}
else
{
AsyncHelper.ContinueTask(subTask, completion, () => completion.SetResult(null));
}
}, connectionToAbort: _activeConnection);
RunExecuteReaderTdsSetupReconnectContinuation(cmdBehavior, runBehavior, returnStream, async, timeout, asyncWrite, ds, reconnectTask, reconnectionStart, completion);
task = completion.Task;
return ds;
}
Expand Down Expand Up @@ -2568,7 +2576,7 @@ private SqlDataReader RunExecuteReaderTds(CommandBehavior cmdBehavior, RunBehavi

if (_execType == EXECTYPE.PREPARED)
{
Debug.Assert(this.IsPrepared && (_prepareHandle != -1), "invalid attempt to call sp_execute without a handle!");
Debug.Assert(this.IsPrepared && ((int)_prepareHandle != -1), "invalid attempt to call sp_execute without a handle!");
rpc = BuildExecute(inSchema);
}
else if (_execType == EXECTYPE.PREPAREPENDING)
Expand Down Expand Up @@ -2627,15 +2635,7 @@ private SqlDataReader RunExecuteReaderTds(CommandBehavior cmdBehavior, RunBehavi
decrementAsyncCountOnFailure = false;
if (writeTask != null)
{
task = AsyncHelper.CreateContinuationTask(writeTask, () =>
{
_activeConnection.GetOpenTdsConnection(); // it will throw if connection is closed
cachedAsyncState.SetAsyncReaderState(ds, runBehavior, optionSettings);
},
onFailure: (exc) =>
{
_activeConnection.GetOpenTdsConnection().DecrementAsyncCount();
});
task = RunExecuteReaderTdsSetupContinuation(runBehavior, ds, optionSettings, writeTask);
}
else
{
Expand Down Expand Up @@ -2674,6 +2674,48 @@ private SqlDataReader RunExecuteReaderTds(CommandBehavior cmdBehavior, RunBehavi
return ds;
}

// This is in its own method to avoid always allocating the lambda in RunExecuteReaderTds
private Task RunExecuteReaderTdsSetupContinuation(RunBehavior runBehavior, SqlDataReader ds, string optionSettings, Task writeTask)
{
Task task = AsyncHelper.CreateContinuationTask(writeTask, () =>
{
_activeConnection.GetOpenTdsConnection(); // it will throw if connection is closed
cachedAsyncState.SetAsyncReaderState(ds, runBehavior, optionSettings);
},
onFailure: (exc) =>
{
_activeConnection.GetOpenTdsConnection().DecrementAsyncCount();
});
return task;
}

// This is in its own method to avoid always allocating the lambda in RunExecuteReaderTds
private void RunExecuteReaderTdsSetupReconnectContinuation(CommandBehavior cmdBehavior, RunBehavior runBehavior, bool returnStream, bool async, int timeout, bool asyncWrite, SqlDataReader ds, Task reconnectTask, long reconnectionStart, TaskCompletionSource<object> completion)
{
CancellationTokenSource timeoutCTS = new CancellationTokenSource();
AsyncHelper.SetTimeoutException(completion, timeout, SQL.CR_ReconnectTimeout, timeoutCTS.Token);
AsyncHelper.ContinueTask(reconnectTask, completion,
() =>
{
if (completion.Task.IsCompleted)
{
return;
}
Interlocked.CompareExchange(ref _reconnectionCompletionSource, null, completion);
timeoutCTS.Cancel();
Task subTask;
RunExecuteReaderTds(cmdBehavior, runBehavior, returnStream, async, TdsParserStaticMethods.GetRemainingTimeout(timeout, reconnectionStart), out subTask, asyncWrite, ds);
if (subTask == null)
{
completion.SetResult(null);
}
else
{
AsyncHelper.ContinueTask(subTask, completion, () => completion.SetResult(null));
}
}, connectionToAbort: _activeConnection
);
}

private SqlDataReader CompleteAsyncExecuteReader()
{
Expand Down Expand Up @@ -2863,16 +2905,16 @@ private void ValidateCommand(bool async, [CallerMemberName] string method = "")

private void ValidateAsyncCommand()
{
if (cachedAsyncState.PendingAsyncOperation)
if (_cachedAsyncState != null && _cachedAsyncState.PendingAsyncOperation)
Wraith2 marked this conversation as resolved.
Show resolved Hide resolved
{ // Enforce only one pending async execute at a time.
if (cachedAsyncState.IsActiveConnectionValid(_activeConnection))
if (_cachedAsyncState.IsActiveConnectionValid(_activeConnection))
{
throw SQL.PendingBeginXXXExists();
}
else
{
_stateObj = null; // Session was re-claimed by session pool upon connection close.
cachedAsyncState.ResetAsyncState();
_cachedAsyncState.ResetAsyncState();
}
}
}
Expand Down Expand Up @@ -3371,7 +3413,7 @@ private void BuildRPC(bool inSchema, SqlParameterCollection parameters, ref _Sql

private _SqlRPC BuildExecute(bool inSchema)
{
Debug.Assert(_prepareHandle != -1, "Invalid call to sp_execute without a valid handle!");
Debug.Assert((int)_prepareHandle != -1, "Invalid call to sp_execute without a valid handle!");
int j = 1;

int count = CountSendableParameters(_parameters);
Expand Down Expand Up @@ -3401,7 +3443,7 @@ private _SqlRPC BuildExecute(bool inSchema)
private void BuildExecuteSql(CommandBehavior behavior, string commandText, SqlParameterCollection parameters, ref _SqlRPC rpc)
{

Debug.Assert(_prepareHandle == -1, "This command has an existing handle, use sp_execute!");
Debug.Assert((int)_prepareHandle == -1, "This command has an existing handle, use sp_execute!");
Debug.Assert(CommandType.Text == this.CommandType, "invalid use of sp_executesql for stored proc invocation!");
int j;
SqlParameter sqlParam;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -911,7 +911,8 @@ internal Task ValidateAndReconnect(Action beforeDisconnect, int timeout)
catch (SqlException)
{
}
runningReconnect = Task.Run(() => ReconnectAsync(timeout));
// use Task.Factory.StartNew with state overload instead of Task.Run to avoid anonymous closure context capture in method scope and avoid the unneeded allocation
runningReconnect = Task.Factory.StartNew(state => ReconnectAsync((int)state), timeout, CancellationToken.None, TaskCreationOptions.DenyChildAttach, TaskScheduler.Default);
// if current reconnect is not null, somebody already started reconnection task - some kind of race condition
Debug.Assert(_currentReconnectionTask == null, "Duplicate reconnection tasks detected");
_currentReconnectionTask = runningReconnect;
Expand Down
Loading