Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Hang on infinite timeout and managed SNI #1742

Merged
merged 2 commits into from
Aug 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Net;
using System.Net.Security;
using System.Security.Cryptography.X509Certificates;

Expand Down Expand Up @@ -194,6 +195,15 @@ internal static bool ValidateSslServerCertificate(string targetServerName, X509C
}
}

internal static IPAddress[] GetDnsIpAddresses(string serverName)
{
using (TrySNIEventScope.Create(nameof(GetDnsIpAddresses)))
{
SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNICommon), EventType.INFO, "Getting DNS host entries for serverName {0}.", args0: serverName);
return Dns.GetHostAddresses(serverName);
}
}

/// <summary>
/// Sets last error encountered for SNI
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -299,14 +299,7 @@ private Socket TryConnectParallel(string hostName, int port, TimeSpan ts, bool i
Socket availableSocket = null;
Task<Socket> connectTask;

Task<IPAddress[]> serverAddrTask = Dns.GetHostAddressesAsync(hostName);
bool complete = serverAddrTask.Wait(ts);

// DNS timed out - don't block
if (!complete)
return null;

IPAddress[] serverAddresses = serverAddrTask.Result;
IPAddress[] serverAddresses = SNICommon.GetDnsIpAddresses(hostName);

if (serverAddresses.Length > MaxParallelIpAddresses)
{
Expand Down Expand Up @@ -358,14 +351,7 @@ private static Socket Connect(string serverName, int port, TimeSpan timeout, boo
{
SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.INFO, "IP preference : {0}", Enum.GetName(typeof(SqlConnectionIPAddressPreference), ipPreference));

Task<IPAddress[]> serverAddrTask = Dns.GetHostAddressesAsync(serverName);
bool complete = serverAddrTask.Wait(timeout);

// DNS timed out - don't block
if (!complete)
return null;

IPAddress[] ipAddresses = serverAddrTask.Result;
IPAddress[] ipAddresses = SNICommon.GetDnsIpAddresses(serverName);

string IPv4String = null;
string IPv6String = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,16 @@ private static byte[] SendUDPRequest(string browserHostname, int port, byte[] re
Debug.Assert(port >= 0 && port <= 65535, "Invalid port");
Debug.Assert(requestPacket != null && requestPacket.Length > 0, "requestPacket should not be null or 0-length array");

bool isIpAddress = IPAddress.TryParse(browserHostname, out IPAddress address);
if (IPAddress.TryParse(browserHostname, out IPAddress address))
{
SsrpResult response = SendUDPRequest(new IPAddress[] { address }, port, requestPacket, allIPsInParallel);
if (response != null && response.ResponsePacket != null)
return response.ResponsePacket;
else if (response != null && response.Error != null)
throw response.Error;
else
return null;
}

TimeSpan ts = default;
// In case the Timeout is Infinite, we will receive the max value of Int64 as the tick count
Expand All @@ -181,27 +190,7 @@ private static byte[] SendUDPRequest(string browserHostname, int port, byte[] re
ts = ts.Ticks < 0 ? TimeSpan.FromTicks(0) : ts;
}

IPAddress[] ipAddresses = null;
if (!isIpAddress)
{
Task<IPAddress[]> serverAddrTask = Dns.GetHostAddressesAsync(browserHostname);
bool taskComplete;
try
{
taskComplete = serverAddrTask.Wait(ts);
}
catch (AggregateException ae)
{
throw ae.InnerException;
}

// If DNS took too long, need to return instead of blocking
if (!taskComplete)
return null;

ipAddresses = serverAddrTask.Result;
}

IPAddress[] ipAddresses = SNICommon.GetDnsIpAddresses(browserHostname);
Debug.Assert(ipAddresses.Length > 0, "DNS should throw if zero addresses resolve");

switch (ipPreference)
Expand Down Expand Up @@ -278,7 +267,7 @@ private static SsrpResult SendUDPRequest(IPAddress[] ipAddresses, int port, byte
for (int i = 0; i < ipAddresses.Length; i++)
{
IPEndPoint endPoint = new IPEndPoint(ipAddresses[i], port);
tasks.Add(Task.Factory.StartNew<SsrpResult>(() => SendUDPRequest(endPoint, requestPacket)));
tasks.Add(Task.Factory.StartNew<SsrpResult>(() => SendUDPRequest(endPoint, requestPacket), cts.Token));
}

List<Task<SsrpResult>> completedTasks = new();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,21 @@ public static void EnvironmentHostNameSPIDTest()
Assert.True(false, "No non-empty hostname found for the application");
}

[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup))]
public static async void ConnectionTimeoutInfiniteTest()
{
// Exercise the special-case infinite connect timeout code path
SqlConnectionStringBuilder builder = new(DataTestUtility.TCPConnectionString)
{
ConnectTimeout = 0 // Infinite
};

using SqlConnection conn = new(builder.ConnectionString);
CancellationTokenSource cts = new(30000);
// Will throw TaskCanceledException and fail the test in the event of a hang
await conn.OpenAsync(cts.Token);
}

[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup))]
public static void ConnectionTimeoutTestWithThread()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Net;
using System.Net.Sockets;
using System.Text;
using System.Threading.Tasks;
Expand All @@ -12,17 +13,17 @@ namespace Microsoft.Data.SqlClient.ManualTesting.Tests
{
public static class InstanceNameTest
{
[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup))]
[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.IsNotAzureServer), nameof(DataTestUtility.IsNotAzureSynapse), nameof(DataTestUtility.AreConnStringsSetup))]
public static void ConnectToSQLWithInstanceNameTest()
{
SqlConnectionStringBuilder builder = new(DataTestUtility.TCPConnectionString);

bool proceed = true;
string dataSourceStr = builder.DataSource.Replace("tcp:", "");
string[] serverNamePartsByBackSlash = dataSourceStr.Split('\\');
string hostname = serverNamePartsByBackSlash[0];
if (!dataSourceStr.Contains(",") && serverNamePartsByBackSlash.Length == 2)
{
string hostname = serverNamePartsByBackSlash[0];
proceed = !string.IsNullOrWhiteSpace(hostname) && IsBrowserAlive(hostname);
}

Expand All @@ -31,6 +32,17 @@ public static void ConnectToSQLWithInstanceNameTest()
using SqlConnection connection = new(builder.ConnectionString);
connection.Open();
connection.Close();

if (builder.Encrypt != SqlConnectionEncryptOption.Strict)
{
// Exercise the IP address-specific code in SSRP
IPAddress[] addresses = Dns.GetHostAddresses(hostname);
builder.DataSource = builder.DataSource.Replace(hostname, addresses[0].ToString());
builder.TrustServerCertificate = true;
using SqlConnection connection2 = new(builder.ConnectionString);
connection2.Open();
connection2.Close();
}
}
}

Expand Down