Skip to content

Commit

Permalink
Fix | Fix the issue with Socke.Connect in managed SNI (dotnet#2777)
Browse files Browse the repository at this point in the history
  • Loading branch information
Javad Rahnama authored Aug 15, 2024
1 parent d3658ed commit 619fa74
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -292,20 +292,14 @@ internal static bool ValidateSslServerCertificate(Guid connectionId, string targ

internal static IPAddress[] GetDnsIpAddresses(string serverName, TimeoutTimer timeout)
{
using (TrySNIEventScope.Create(nameof(SNICommon)))
IPAddress[] ipAddresses = GetDnsIpAddresses(serverName);

// We cannot timeout accurately in sync code above, so throw TimeoutException if we've now exceeded the timeout.
if (timeout.IsExpired)
{
int remainingTimeout = timeout.MillisecondsRemainingInt;
SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNICommon), EventType.INFO,
"Getting DNS host entries for serverName {0} within {1} milliseconds.",
args0: serverName,
args1: remainingTimeout);
using CancellationTokenSource cts = new CancellationTokenSource(remainingTimeout);
// using this overload to support netstandard
Task<IPAddress[]> task = Dns.GetHostAddressesAsync(serverName);
task.ConfigureAwait(false);
task.Wait(cts.Token);
return task.Result;
throw new TimeoutException();
}
return ipAddresses;
}

internal static IPAddress[] GetDnsIpAddresses(string serverName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -426,33 +426,16 @@ private static Socket Connect(string serverName, int port, TimeoutTimer timeout,
bool isConnected;
try // catching SocketException with SocketErrorCode == WouldBlock to run Socket.Select
{
if (isInfiniteTimeout)
socket.Connect(ipAddress, port);
if (!isInfiniteTimeout)
{
socket.Connect(ipAddress, port);
}
else
{
if (timeout.IsExpired)
{
return null;
}
// Socket.Connect does not support infinite timeouts, so we use Task to simulate it
Task socketConnectTask = new Task(() => socket.Connect(ipAddress, port));
socketConnectTask.ConfigureAwait(false);
socketConnectTask.Start();
int remainingTimeout = timeout.MillisecondsRemainingInt;
if (!socketConnectTask.Wait(remainingTimeout))
{
throw ADP.TimeoutException($"The socket couldn't connect during the expected {remainingTimeout} remaining time.");
}
throw SQL.SocketDidNotThrow();
}

isConnected = true;
}
catch (AggregateException aggregateException) when (!isInfiniteTimeout
&& aggregateException.InnerException is SocketException socketException
&& socketException.SocketErrorCode == SocketError.WouldBlock)
catch (SocketException socketException) when (!isInfiniteTimeout &&
socketException.SocketErrorCode == SocketError.WouldBlock)
{
// https://github.com/dotnet/SqlClient/issues/826#issuecomment-736224118
// Socket.Select is used because it supports timeouts, while Socket.Connect does not
Expand Down Expand Up @@ -509,11 +492,11 @@ private static Socket Connect(string serverName, int port, TimeoutTimer timeout,
return socket;
}
}
catch (AggregateException aggregateException) when (aggregateException.InnerException is SocketException socketException)
catch (SocketException e)
{
SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.ERR, "THIS EXCEPTION IS BEING SWALLOWED: {0}", args0: socketException?.Message);
SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.ERR, "THIS EXCEPTION IS BEING SWALLOWED: {0}", args0: e?.Message);
SqlClientEventSource.Log.TryAdvancedTraceEvent(
$"{nameof(SNITCPHandle)}.{nameof(Connect)}{EventType.ERR}THIS EXCEPTION IS BEING SWALLOWED: {socketException}");
$"{nameof(SNITCPHandle)}.{nameof(Connect)}{EventType.ERR}THIS EXCEPTION IS BEING SWALLOWED: {e}");
}
finally
{
Expand Down

0 comments on commit 619fa74

Please sign in to comment.