Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Port CoreFx PR 38271: Fix Statement Command Cancellation (Managed SNI) #248

Merged
merged 8 commits into from
Jan 8, 2020
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
using System.Net.Security;
using System.Security.Authentication;
using System.Security.Cryptography.X509Certificates;
using System.Threading;

namespace Microsoft.Data.SqlClient.SNI
{
Expand All @@ -22,6 +23,7 @@ internal sealed class SNINpHandle : SNIHandle

private readonly string _targetServer;
private readonly object _callbackObject;
private readonly object _sendSync;

private Stream _stream;
private NamedPipeClientStream _pipeStream;
Expand All @@ -38,6 +40,7 @@ internal sealed class SNINpHandle : SNIHandle

public SNINpHandle(string serverName, string pipeName, long timerExpire, object callbackObject)
{
_sendSync = new object();
_targetServer = serverName;
_callbackObject = callbackObject;

Expand Down Expand Up @@ -206,20 +209,48 @@ public override uint ReceiveAsync(ref SNIPacket packet)

public override uint Send(SNIPacket packet)
{
lock (this)
bool releaseLock = false;
try
{
try
// is the packet is marked out out-of-band (attention packets only) it must be
// sent immediately even if a send of recieve operation is already in progress
// because out of band packets are used to cancel ongoing operations
// so try to take the lock if possible but continue even if it can't be taken
if (packet.IsOutOfBand)
{
packet.WriteToStream(_stream);
return TdsEnums.SNI_SUCCESS;
Monitor.TryEnter(this, ref releaseLock);
}
catch (ObjectDisposedException ode)
else
{
return ReportErrorAndReleasePacket(packet, ode);
Monitor.Enter(this);
releaseLock = true;
}
catch (IOException ioe)

// this lock ensures that two packets are not being written to the transport at the same time
// so that sending a standard and an out-of-band packet are both written atomically no data is
// interleaved
lock (_sendSync)
{
try
{
packet.WriteToStream(_stream);
return TdsEnums.SNI_SUCCESS;
}
catch (ObjectDisposedException ode)
{
return ReportErrorAndReleasePacket(packet, ode);
}
catch (IOException ioe)
{
return ReportErrorAndReleasePacket(packet, ioe);
}
}
}
finally
{
if (releaseLock)
{
return ReportErrorAndReleasePacket(packet, ioe);
Monitor.Exit(this);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ public SNIPacket(int headerSize, int dataSize)
/// </summary>
public int DataLeft => (_dataLength - _dataOffset);

/// <summary>
/// Indicates that the packet should be sent out of band bypassing the normal send-recieve lock
/// </summary>
public bool IsOutOfBand { get; set; }

/// <summary>
/// Length of data
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ internal sealed class SNITCPHandle : SNIHandle
{
private readonly string _targetServer;
private readonly object _callbackObject;
private readonly object _sendSync;
private readonly Socket _socket;
private NetworkStream _tcpStream;

Expand Down Expand Up @@ -104,6 +105,7 @@ public SNITCPHandle(string serverName, int port, long timerExpire, object callba
{
_callbackObject = callbackObject;
_targetServer = serverName;
_sendSync = new object();

try
{
Expand Down Expand Up @@ -432,24 +434,52 @@ public override void SetBufferSize(int bufferSize)
/// <returns>SNI error code</returns>
public override uint Send(SNIPacket packet)
{
lock (this)
bool releaseLock = false;
try
{
try
// is the packet is marked out out-of-band (attention packets only) it must be
// sent immediately even if a send of recieve operation is already in progress
// because out of band packets are used to cancel ongoing operations
// so try to take the lock if possible but continue even if it can't be taken
if (packet.IsOutOfBand)
{
packet.WriteToStream(_stream);
return TdsEnums.SNI_SUCCESS;
Monitor.TryEnter(this, ref releaseLock);
}
catch (ObjectDisposedException ode)
else
{
return ReportTcpSNIError(ode);
Monitor.Enter(this);
releaseLock = true;
}
catch (SocketException se)

// this lock ensures that two packets are not being written to the transport at the same time
// so that sending a standard and an out-of-band packet are both written atomically no data is
// interleaved
lock (_sendSync)
{
return ReportTcpSNIError(se);
try
{
packet.WriteToStream(_stream);
return TdsEnums.SNI_SUCCESS;
}
catch (ObjectDisposedException ode)
{
return ReportTcpSNIError(ode);
}
catch (SocketException se)
{
return ReportTcpSNIError(se);
}
catch (IOException ioe)
{
return ReportTcpSNIError(ioe);
}
}
catch (IOException ioe)
}
finally
{
if (releaseLock)
{
return ReportTcpSNIError(ioe);
Monitor.Exit(this);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ internal override PacketHandle CreateAndSetAttentionPacket()
{
PacketHandle packetHandle = GetResetWritePacket(TdsEnums.HEADER_LEN);
SetPacketData(packetHandle, SQL.AttentionHeader, TdsEnums.HEADER_LEN);
packetHandle.ManagedPacket.IsOutOfBand = true;
return packetHandle;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,21 +125,23 @@ private static void MultiThreadedCancel(string constr, bool async)
using (SqlConnection con = new SqlConnection(constr))
{
con.Open();
var command = con.CreateCommand();
command.CommandText = "select * from orders; waitfor delay '00:00:08'; select * from customers";
using (var command = con.CreateCommand())
{
command.CommandText = "select * from orders; waitfor delay '00:00:08'; select * from customers";

Barrier threadsReady = new Barrier(2);
object state = new Tuple<bool, SqlCommand, Barrier>(async, command, threadsReady);
Barrier threadsReady = new Barrier(2);
object state = new Tuple<bool, SqlCommand, Barrier>(async, command, threadsReady);

Task[] tasks = new Task[2];
tasks[0] = new Task(ExecuteCommandCancelExpected, state);
tasks[1] = new Task(CancelSharedCommand, state);
tasks[0].Start();
tasks[1].Start();
Task[] tasks = new Task[2];
tasks[0] = new Task(ExecuteCommandCancelExpected, state);
tasks[1] = new Task(CancelSharedCommand, state);
tasks[0].Start();
tasks[1].Start();

Task.WaitAll(tasks, 15 * 1000);
Task.WaitAll(tasks, 15 * 1000);

SqlCommandCancelTest.VerifyConnection(command);
SqlCommandCancelTest.VerifyConnection(command);
}
}
}

Expand All @@ -148,14 +150,16 @@ private static void TimeoutCancel(string constr)
using (SqlConnection con = new SqlConnection(constr))
{
con.Open();
SqlCommand cmd = con.CreateCommand();
cmd.CommandTimeout = 1;
cmd.CommandText = "WAITFOR DELAY '00:00:30';select * from Customers";
using (SqlCommand cmd = con.CreateCommand())
{
cmd.CommandTimeout = 1;
cmd.CommandText = "WAITFOR DELAY '00:00:30';select * from Customers";

string errorMessage = SystemDataResourceManager.Instance.SQL_Timeout_Execution;
DataTestUtility.ExpectFailure<SqlException>(() => cmd.ExecuteReader(), new string[] { errorMessage });
string errorMessage = SystemDataResourceManager.Instance.SQL_Timeout_Execution;
DataTestUtility.ExpectFailure<SqlException>(() => cmd.ExecuteReader(), new string[] { errorMessage });

VerifyConnection(cmd);
VerifyConnection(cmd);
}
}
}

Expand Down Expand Up @@ -253,34 +257,99 @@ private static void TimeOutDuringRead(string constr)
{
// Start the command
conn.Open();
SqlCommand cmd = new SqlCommand("SELECT @p", conn);
cmd.Parameters.AddWithValue("p", new byte[20000]);
SqlDataReader reader = cmd.ExecuteReader();
reader.Read();

// Tweak the timeout to 1ms, stop the proxy from proxying and then try GetValue (which should timeout)
reader.SetDefaultTimeout(1);
proxy.PauseCopying();
string errorMessage = SystemDataResourceManager.Instance.SQL_Timeout_Execution;
Exception exception = Assert.Throws<SqlException>(() => reader.GetValue(0));
Assert.Contains(errorMessage, exception.Message);

// Return everything to normal and close
proxy.ResumeCopying();
reader.SetDefaultTimeout(30000);
reader.Dispose();
using (SqlCommand cmd = new SqlCommand("SELECT @p", conn))
{
cmd.Parameters.AddWithValue("p", new byte[20000]);
using (SqlDataReader reader = cmd.ExecuteReader())
{
reader.Read();

// Tweak the timeout to 1ms, stop the proxy from proxying and then try GetValue (which should timeout)
reader.SetDefaultTimeout(1);
proxy.PauseCopying();
string errorMessage = SystemDataResourceManager.Instance.SQL_Timeout_Execution;
Exception exception = Assert.Throws<SqlException>(() => reader.GetValue(0));
Assert.Contains(errorMessage, exception.Message);

// Return everything to normal and close
proxy.ResumeCopying();
reader.SetDefaultTimeout(30000);
reader.Dispose();
}
}
}

proxy.Stop();
}
catch
{
// In case of error, stop the proxy and dump its logs (hopefully this will help with debugging
proxy.Stop();
Console.WriteLine(proxy.GetServerEventLog());
Assert.True(false, "Error while reading through proxy");
throw;
}
finally
{
proxy.Stop();
}
}

[CheckConnStrSetupFact]
public static void CancelDoesNotWait()
{
const int delaySeconds = 30;
const int cancelSeconds = 1;

using (SqlConnection conn = new SqlConnection(s_connStr))
using (var cmd = new SqlCommand($"WAITFOR DELAY '00:00:{delaySeconds:D2}'", conn))
{
conn.Open();

Task.Delay(TimeSpan.FromSeconds(cancelSeconds * 2))
cheenamalhotra marked this conversation as resolved.
Show resolved Hide resolved
.ContinueWith(t => cmd.Cancel());

DateTime started = DateTime.UtcNow;
DateTime ended = DateTime.UtcNow;
Exception exception = null;
try
{
cmd.ExecuteNonQuery();
}
catch (Exception ex)
{
exception = ex;
}
ended = DateTime.UtcNow;

Assert.NotNull(exception);
Assert.InRange((ended - started).TotalSeconds, cancelSeconds, delaySeconds - 1);
}
}

[CheckConnStrSetupFact]
public static async Task AsyncCancelDoesNotWait()
{
const int delaySeconds = 30;
const int cancelSeconds = 1;

using (SqlConnection conn = new SqlConnection(s_connStr))
using (var cmd = new SqlCommand($"WAITFOR DELAY '00:00:{delaySeconds:D2}'", conn))
{
await conn.OpenAsync();

DateTime started = DateTime.UtcNow;
Exception exception = null;
try
{
await cmd.ExecuteNonQueryAsync(new CancellationTokenSource(2000).Token);
}
catch (Exception ex)
{
exception = ex;
}
DateTime ended = DateTime.UtcNow;

Assert.NotNull(exception);
Assert.InRange((ended - started).TotalSeconds, cancelSeconds, delaySeconds - 1);
}
}
}
}