Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Hotfix 4.1.1] | Parallelize SSRP requests when MSF is specified (#1578) #1708

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ internal class SNICommon
internal const int ConnTimeoutError = 11;
internal const int ConnNotUsableError = 19;
internal const int InvalidConnStringError = 25;
internal const int ErrorLocatingServerInstance = 26;
internal const int HandshakeFailureError = 31;
internal const int InternalExceptionError = 35;
internal const int ConnOpenFailedError = 40;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ private static bool IsErrorStatus(SecurityStatusPalErrorCode errorCode)
/// <param name="isIntegratedSecurity"></param>
/// <param name="ipPreference">IP address preference</param>
/// <param name="cachedFQDN">Used for DNS Cache</param>
/// <param name="pendingDNSInfo">Used for DNS Cache</param>
/// <param name="pendingDNSInfo">Used for DNS Cache</param>
/// <returns>SNI handle</returns>
internal static SNIHandle CreateConnectionHandle(string fullServerName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[][] spnBuffer,
bool flushCache, bool async, bool parallel, bool isIntegratedSecurity, SqlConnectionIPAddressPreference ipPreference, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo)
Expand Down Expand Up @@ -263,7 +263,7 @@ private static byte[][] GetSqlServerSPNs(string hostNameOrAddress, string portOr
/// <param name="parallel">Should MultiSubnetFailover be used</param>
/// <param name="ipPreference">IP address preference</param>
/// <param name="cachedFQDN">Key for DNS Cache</param>
/// <param name="pendingDNSInfo">Used for DNS Cache</param>
/// <param name="pendingDNSInfo">Used for DNS Cache</param>
/// <returns>SNITCPHandle</returns>
private static SNITCPHandle CreateTcpHandle(DataSource details, long timerExpire, bool parallel, SqlConnectionIPAddressPreference ipPreference, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo)
{
Expand All @@ -285,12 +285,12 @@ private static SNITCPHandle CreateTcpHandle(DataSource details, long timerExpire
try
{
port = isAdminConnection ?
SSRP.GetDacPortByInstanceName(hostName, details.InstanceName) :
SSRP.GetPortByInstanceName(hostName, details.InstanceName);
SSRP.GetDacPortByInstanceName(hostName, details.InstanceName, timerExpire, parallel, ipPreference) :
SSRP.GetPortByInstanceName(hostName, details.InstanceName, timerExpire, parallel, ipPreference);
}
catch (SocketException se)
{
SNILoadHandle.SingletonInstance.LastError = new SNIError(SNIProviders.TCP_PROV, SNICommon.InvalidConnStringError, se);
SNILoadHandle.SingletonInstance.LastError = new SNIError(SNIProviders.TCP_PROV, SNICommon.ErrorLocatingServerInstance, se);
return null;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,9 @@ public SNITCPHandle(string serverName, int port, long timerExpire, bool parallel
bool reportError = true;

SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.INFO, "Connection Id {0}, Connecting to serverName {1} and port {2}", args0: _connectionId, args1: serverName, args2: port);
// We will always first try to connect with serverName as before and let the DNS server to resolve the serverName.
// If the DSN resolution fails, we will try with IPs in the DNS cache if existed. We try with cached IPs based on IPAddressPreference.
// The exceptions will be throw to upper level and be handled as before.
// We will always first try to connect with serverName as before and let DNS resolve the serverName.
// If DNS resolution fails, we will try with IPs in the DNS cache if they exist. We try with cached IPs based on IPAddressPreference.
// Exceptions will be thrown to the caller and be handled as before.
try
{
if (parallel)
Expand Down Expand Up @@ -280,7 +280,12 @@ private Socket TryConnectParallel(string hostName, int port, TimeSpan ts, bool i
Task<Socket> connectTask;

Task<IPAddress[]> serverAddrTask = Dns.GetHostAddressesAsync(hostName);
serverAddrTask.Wait(ts);
bool complete = serverAddrTask.Wait(ts);

// DNS timed out - don't block
if (!complete)
return null;

IPAddress[] serverAddresses = serverAddrTask.Result;

if (serverAddresses.Length > MaxParallelIpAddresses)
Expand Down Expand Up @@ -324,7 +329,6 @@ private Socket TryConnectParallel(string hostName, int port, TimeSpan ts, bool i

availableSocket = connectTask.Result;
return availableSocket;

}

// Connect to server with hostName and port.
Expand All @@ -334,7 +338,14 @@ private static Socket Connect(string serverName, int port, TimeSpan timeout, boo
{
SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.INFO, "IP preference : {0}", Enum.GetName(typeof(SqlConnectionIPAddressPreference), ipPreference));

IPAddress[] ipAddresses = Dns.GetHostAddresses(serverName);
Task<IPAddress[]> serverAddrTask = Dns.GetHostAddressesAsync(serverName);
bool complete = serverAddrTask.Wait(timeout);

// DNS timed out - don't block
if (!complete)
return null;

IPAddress[] ipAddresses = serverAddrTask.Result;

string IPv4String = null;
string IPv6String = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Net;
using System.Net.Sockets;
using System.Text;
using System.Threading;
using System.Threading.Tasks;

namespace Microsoft.Data.SqlClient.SNI
Expand All @@ -21,8 +24,11 @@ internal class SSRP
/// </summary>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not related to the ported PR, this class is sealed in other branch.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's sealed in the main branch by #1430. For backporting, we should consider a PR from the main branch to keep track of the changes. In my opinion, this is a helper class with static members only and we can avoid this amendment.

/// <param name="browserHostName">SQL Sever Browser hostname</param>
/// <param name="instanceName">instance name to find port number</param>
/// <param name="timerExpire">Connection timer expiration</param>
/// <param name="allIPsInParallel">query all resolved IP addresses in parallel</param>
/// <param name="ipPreference">IP address preference</param>
/// <returns>port number for given instance name</returns>
internal static int GetPortByInstanceName(string browserHostName, string instanceName)
internal static int GetPortByInstanceName(string browserHostName, string instanceName, long timerExpire, bool allIPsInParallel, SqlConnectionIPAddressPreference ipPreference)
{
Debug.Assert(!string.IsNullOrWhiteSpace(browserHostName), "browserHostName should not be null, empty, or whitespace");
Debug.Assert(!string.IsNullOrWhiteSpace(instanceName), "instanceName should not be null, empty, or whitespace");
Expand All @@ -32,7 +38,7 @@ internal static int GetPortByInstanceName(string browserHostName, string instanc
byte[] responsePacket = null;
try
{
responsePacket = SendUDPRequest(browserHostName, SqlServerBrowserPort, instanceInfoRequest);
responsePacket = SendUDPRequest(browserHostName, SqlServerBrowserPort, instanceInfoRequest, timerExpire, allIPsInParallel, ipPreference);
}
catch (SocketException se)
{
Expand Down Expand Up @@ -87,14 +93,17 @@ private static byte[] CreateInstanceInfoRequest(string instanceName)
/// </summary>
/// <param name="browserHostName">SQL Sever Browser hostname</param>
/// <param name="instanceName">instance name to lookup DAC port</param>
/// <param name="timerExpire">Connection timer expiration</param>
/// <param name="allIPsInParallel">query all resolved IP addresses in parallel</param>
/// <param name="ipPreference">IP address preference</param>
/// <returns>DAC port for given instance name</returns>
internal static int GetDacPortByInstanceName(string browserHostName, string instanceName)
internal static int GetDacPortByInstanceName(string browserHostName, string instanceName, long timerExpire, bool allIPsInParallel, SqlConnectionIPAddressPreference ipPreference)
{
Debug.Assert(!string.IsNullOrWhiteSpace(browserHostName), "browserHostName should not be null, empty, or whitespace");
Debug.Assert(!string.IsNullOrWhiteSpace(instanceName), "instanceName should not be null, empty, or whitespace");

byte[] dacPortInfoRequest = CreateDacPortInfoRequest(instanceName);
byte[] responsePacket = SendUDPRequest(browserHostName, SqlServerBrowserPort, dacPortInfoRequest);
byte[] responsePacket = SendUDPRequest(browserHostName, SqlServerBrowserPort, dacPortInfoRequest, timerExpire, allIPsInParallel, ipPreference);

const byte SvrResp = 0x05;
const byte ProtocolVersion = 0x01;
Expand Down Expand Up @@ -131,43 +140,198 @@ private static byte[] CreateDacPortInfoRequest(string instanceName)
return requestPacket;
}

private class SsrpResult
{
public byte[] ResponsePacket;
public Exception Error;
}

/// <summary>
/// Sends request to server, and receives response from server by UDP.
/// </summary>
/// <param name="browserHostname">UDP server hostname</param>
/// <param name="port">UDP server port</param>
/// <param name="requestPacket">request packet</param>
/// <param name="timerExpire">Connection timer expiration</param>
/// <param name="allIPsInParallel">query all resolved IP addresses in parallel</param>
/// <param name="ipPreference">IP address preference</param>
/// <returns>response packet from UDP server</returns>
private static byte[] SendUDPRequest(string browserHostname, int port, byte[] requestPacket)
private static byte[] SendUDPRequest(string browserHostname, int port, byte[] requestPacket, long timerExpire, bool allIPsInParallel, SqlConnectionIPAddressPreference ipPreference)
{
using (TrySNIEventScope.Create(nameof(SSRP)))
{
Debug.Assert(!string.IsNullOrWhiteSpace(browserHostname), "browserhostname should not be null, empty, or whitespace");
Debug.Assert(port >= 0 && port <= 65535, "Invalid port");
Debug.Assert(requestPacket != null && requestPacket.Length > 0, "requestPacket should not be null or 0-length array");

const int sendTimeOutMs = 1000;
const int receiveTimeOutMs = 1000;
bool isIpAddress = IPAddress.TryParse(browserHostname, out IPAddress address);

IPAddress address = null;
bool isIpAddress = IPAddress.TryParse(browserHostname, out address);
TimeSpan ts = default;
// In case the Timeout is Infinite, we will receive the max value of Int64 as the tick count
// The infinite Timeout is a function of ConnectionString Timeout=0
if (long.MaxValue != timerExpire)
{
ts = DateTime.FromFileTime(timerExpire) - DateTime.Now;
ts = ts.Ticks < 0 ? TimeSpan.FromTicks(0) : ts;
}

byte[] responsePacket = null;
using (UdpClient client = new UdpClient(!isIpAddress ? AddressFamily.InterNetwork : address.AddressFamily))
IPAddress[] ipAddresses = null;
if (!isIpAddress)
{
Task<IPAddress[]> serverAddrTask = Dns.GetHostAddressesAsync(browserHostname);
bool taskComplete;
try
{
taskComplete = serverAddrTask.Wait(ts);
}
catch (AggregateException ae)
{
throw ae.InnerException;
}

// If DNS took too long, need to return instead of blocking
if (!taskComplete)
return null;

ipAddresses = serverAddrTask.Result;
}

Debug.Assert(ipAddresses.Length > 0, "DNS should throw if zero addresses resolve");

switch (ipPreference)
{
Task<int> sendTask = client.SendAsync(requestPacket, requestPacket.Length, browserHostname, port);
case SqlConnectionIPAddressPreference.IPv4First:
{
SsrpResult response4 = SendUDPRequest(ipAddresses.Where(i => i.AddressFamily == AddressFamily.InterNetwork).ToArray(), port, requestPacket, allIPsInParallel);
if (response4 != null && response4.ResponsePacket != null)
return response4.ResponsePacket;

SsrpResult response6 = SendUDPRequest(ipAddresses.Where(i => i.AddressFamily == AddressFamily.InterNetworkV6).ToArray(), port, requestPacket, allIPsInParallel);
if (response6 != null && response6.ResponsePacket != null)
return response6.ResponsePacket;

// No responses so throw first error
if (response4 != null && response4.Error != null)
throw response4.Error;
else if (response6 != null && response6.Error != null)
throw response6.Error;

break;
}
case SqlConnectionIPAddressPreference.IPv6First:
{
SsrpResult response6 = SendUDPRequest(ipAddresses.Where(i => i.AddressFamily == AddressFamily.InterNetworkV6).ToArray(), port, requestPacket, allIPsInParallel);
if (response6 != null && response6.ResponsePacket != null)
return response6.ResponsePacket;

SsrpResult response4 = SendUDPRequest(ipAddresses.Where(i => i.AddressFamily == AddressFamily.InterNetwork).ToArray(), port, requestPacket, allIPsInParallel);
if (response4 != null && response4.ResponsePacket != null)
return response4.ResponsePacket;

// No responses so throw first error
if (response6 != null && response6.Error != null)
throw response6.Error;
else if (response4 != null && response4.Error != null)
throw response4.Error;

break;
}
default:
{
SsrpResult response = SendUDPRequest(ipAddresses, port, requestPacket, true); // allIPsInParallel);
if (response != null && response.ResponsePacket != null)
return response.ResponsePacket;
else if (response != null && response.Error != null)
throw response.Error;

break;
}
}

return null;
}
}

/// <summary>
/// Sends request to server, and receives response from server by UDP.
/// </summary>
/// <param name="ipAddresses">IP Addresses</param>
/// <param name="port">UDP server port</param>
/// <param name="requestPacket">request packet</param>
/// <param name="allIPsInParallel">query all resolved IP addresses in parallel</param>
/// <returns>response packet from UDP server</returns>
private static SsrpResult SendUDPRequest(IPAddress[] ipAddresses, int port, byte[] requestPacket, bool allIPsInParallel)
{
if (ipAddresses.Length == 0)
return null;

if (allIPsInParallel) // Used for MultiSubnetFailover
{
List<Task<SsrpResult>> tasks = new(ipAddresses.Length);
CancellationTokenSource cts = new CancellationTokenSource();
for (int i = 0; i < ipAddresses.Length; i++)
{
IPEndPoint endPoint = new IPEndPoint(ipAddresses[i], port);
tasks.Add(Task.Factory.StartNew<SsrpResult>(() => SendUDPRequest(endPoint, requestPacket)));
}

List<Task<SsrpResult>> completedTasks = new();
while (tasks.Count > 0)
{
int first = Task.WaitAny(tasks.ToArray());
if (tasks[first].Result.ResponsePacket != null)
{
cts.Cancel();
return tasks[first].Result;
}
else
{
completedTasks.Add(tasks[first]);
tasks.Remove(tasks[first]);
}
}

Debug.Assert(completedTasks.Count > 0, "completedTasks should never be 0");

// All tasks failed. Return the error from the first failure.
return completedTasks[0].Result;
}
else
{
// If not parallel, use the first IP address provided
IPEndPoint endPoint = new IPEndPoint(ipAddresses[0], port);
return SendUDPRequest(endPoint, requestPacket);
}
}

private static SsrpResult SendUDPRequest(IPEndPoint endPoint, byte[] requestPacket)
{
const int sendTimeOutMs = 1000;
const int receiveTimeOutMs = 1000;

SsrpResult result = new();

try
{
using (UdpClient client = new UdpClient(endPoint.AddressFamily))
{
Task<int> sendTask = client.SendAsync(requestPacket, requestPacket.Length, endPoint);
Task<UdpReceiveResult> receiveTask = null;

SqlClientEventSource.Log.TrySNITraceEvent(nameof(SSRP), EventType.INFO, "Waiting for UDP Client to fetch Port info.");
if (sendTask.Wait(sendTimeOutMs) && (receiveTask = client.ReceiveAsync()).Wait(receiveTimeOutMs))
{
SqlClientEventSource.Log.TrySNITraceEvent(nameof(SSRP), EventType.INFO, "Received Port info from UDP Client.");
responsePacket = receiveTask.Result.Buffer;
result.ResponsePacket = receiveTask.Result.Buffer;
}
}

return responsePacket;
}
catch (Exception e)
{
result.Error = e;
}

return result;
}
}
}
Loading