Skip to content

Commit

Permalink
[Hotfix 4.1.1] | Fix hang on infinite timeout and managed SNI (#1742) (
Browse files Browse the repository at this point in the history
  • Loading branch information
DavoudEshtehari authored Aug 31, 2022
1 parent a06894f commit 8b4022b
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Net;
using System.Net.Security;
using System.Security.Cryptography.X509Certificates;

Expand Down Expand Up @@ -194,6 +195,15 @@ internal static bool ValidateSslServerCertificate(string targetServerName, X509C
}
}

internal static IPAddress[] GetDnsIpAddresses(string serverName)
{
using (TrySNIEventScope.Create(nameof(GetDnsIpAddresses)))
{
SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNICommon), EventType.INFO, "Getting DNS host entries for serverName {0}.", args0: serverName);
return Dns.GetHostAddresses(serverName);
}
}

/// <summary>
/// Sets last error encountered for SNI
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -279,14 +279,7 @@ private Socket TryConnectParallel(string hostName, int port, TimeSpan ts, bool i
Socket availableSocket = null;
Task<Socket> connectTask;

Task<IPAddress[]> serverAddrTask = Dns.GetHostAddressesAsync(hostName);
bool complete = serverAddrTask.Wait(ts);

// DNS timed out - don't block
if (!complete)
return null;

IPAddress[] serverAddresses = serverAddrTask.Result;
IPAddress[] serverAddresses = SNICommon.GetDnsIpAddresses(hostName);

if (serverAddresses.Length > MaxParallelIpAddresses)
{
Expand Down Expand Up @@ -338,14 +331,7 @@ private static Socket Connect(string serverName, int port, TimeSpan timeout, boo
{
SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.INFO, "IP preference : {0}", Enum.GetName(typeof(SqlConnectionIPAddressPreference), ipPreference));

Task<IPAddress[]> serverAddrTask = Dns.GetHostAddressesAsync(serverName);
bool complete = serverAddrTask.Wait(timeout);

// DNS timed out - don't block
if (!complete)
return null;

IPAddress[] ipAddresses = serverAddrTask.Result;
IPAddress[] ipAddresses = SNICommon.GetDnsIpAddresses(serverName);

string IPv4String = null;
string IPv6String = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,16 @@ private static byte[] SendUDPRequest(string browserHostname, int port, byte[] re
Debug.Assert(port >= 0 && port <= 65535, "Invalid port");
Debug.Assert(requestPacket != null && requestPacket.Length > 0, "requestPacket should not be null or 0-length array");

bool isIpAddress = IPAddress.TryParse(browserHostname, out IPAddress address);
if (IPAddress.TryParse(browserHostname, out IPAddress address))
{
SsrpResult response = SendUDPRequest(new IPAddress[] { address }, port, requestPacket, allIPsInParallel);
if (response != null && response.ResponsePacket != null)
return response.ResponsePacket;
else if (response != null && response.Error != null)
throw response.Error;
else
return null;
}

TimeSpan ts = default;
// In case the Timeout is Infinite, we will receive the max value of Int64 as the tick count
Expand All @@ -175,27 +184,7 @@ private static byte[] SendUDPRequest(string browserHostname, int port, byte[] re
ts = ts.Ticks < 0 ? TimeSpan.FromTicks(0) : ts;
}

IPAddress[] ipAddresses = null;
if (!isIpAddress)
{
Task<IPAddress[]> serverAddrTask = Dns.GetHostAddressesAsync(browserHostname);
bool taskComplete;
try
{
taskComplete = serverAddrTask.Wait(ts);
}
catch (AggregateException ae)
{
throw ae.InnerException;
}

// If DNS took too long, need to return instead of blocking
if (!taskComplete)
return null;

ipAddresses = serverAddrTask.Result;
}

IPAddress[] ipAddresses = SNICommon.GetDnsIpAddresses(browserHostname);
Debug.Assert(ipAddresses.Length > 0, "DNS should throw if zero addresses resolve");

switch (ipPreference)
Expand Down Expand Up @@ -272,7 +261,7 @@ private static SsrpResult SendUDPRequest(IPAddress[] ipAddresses, int port, byte
for (int i = 0; i < ipAddresses.Length; i++)
{
IPEndPoint endPoint = new IPEndPoint(ipAddresses[i], port);
tasks.Add(Task.Factory.StartNew<SsrpResult>(() => SendUDPRequest(endPoint, requestPacket)));
tasks.Add(Task.Factory.StartNew<SsrpResult>(() => SendUDPRequest(endPoint, requestPacket), cts.Token));
}

List<Task<SsrpResult>> completedTasks = new();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,21 @@ public static void EnvironmentHostNameSPIDTest()
Assert.True(false, "No non-empty hostname found for the application");
}

[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup))]
public static async void ConnectionTimeoutInfiniteTest()
{
// Exercise the special-case infinite connect timeout code path
SqlConnectionStringBuilder builder = new(DataTestUtility.TCPConnectionString)
{
ConnectTimeout = 0 // Infinite
};

using SqlConnection conn = new(builder.ConnectionString);
CancellationTokenSource cts = new(30000);
// Will throw TaskCanceledException and fail the test in the event of a hang
await conn.OpenAsync(cts.Token);
}

[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup))]
public static void ConnectionTimeoutTestWithThread()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Net;
using System.Net.Sockets;
using System.Text;
using System.Threading.Tasks;
Expand All @@ -12,17 +13,17 @@ namespace Microsoft.Data.SqlClient.ManualTesting.Tests
{
public static class InstanceNameTest
{
[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup))]
[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.IsNotAzureServer), nameof(DataTestUtility.IsNotAzureSynapse), nameof(DataTestUtility.AreConnStringsSetup))]
public static void ConnectToSQLWithInstanceNameTest()
{
SqlConnectionStringBuilder builder = new(DataTestUtility.TCPConnectionString);

bool proceed = true;
string dataSourceStr = builder.DataSource.Replace("tcp:", "");
string[] serverNamePartsByBackSlash = dataSourceStr.Split('\\');
string hostname = serverNamePartsByBackSlash[0];
if (!dataSourceStr.Contains(",") && serverNamePartsByBackSlash.Length == 2)
{
string hostname = serverNamePartsByBackSlash[0];
proceed = !string.IsNullOrWhiteSpace(hostname) && IsBrowserAlive(hostname);
}

Expand All @@ -31,6 +32,14 @@ public static void ConnectToSQLWithInstanceNameTest()
using SqlConnection connection = new(builder.ConnectionString);
connection.Open();
connection.Close();

// Exercise the IP address-specific code in SSRP
IPAddress[] addresses = Dns.GetHostAddresses(hostname);
builder.DataSource = builder.DataSource.Replace(hostname, addresses[0].ToString());
builder.TrustServerCertificate = true;
using SqlConnection connection2 = new(builder.ConnectionString);
connection2.Open();
connection2.Close();
}
}

Expand Down

0 comments on commit 8b4022b

Please sign in to comment.