Skip to content

Commit

Permalink
[Release 2.1] Fix | Fixes Kerberos auth when SPN does not contain port (
Browse files Browse the repository at this point in the history
  • Loading branch information
cheenamalhotra committed Feb 26, 2021
1 parent 4c45dce commit ff19a0b
Show file tree
Hide file tree
Showing 7 changed files with 67 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -135,15 +135,18 @@ private static SecurityStatusPal EstablishSecurityContext(
}
catch (Exception ex)
{
if (NetEventSource.IsEnabled) NetEventSource.Error(null, ex);
if (NetEventSource.IsEnabled)
{
NetEventSource.Error(null, ex);
}
return new SecurityStatusPal(SecurityStatusPalErrorCode.InternalError, ex);
}
}

internal static SecurityStatusPal InitializeSecurityContext(
SafeFreeCredentials credentialsHandle,
ref SafeDeleteContext securityContext,
string spn,
string[] spns,
ContextFlagsPal requestedContextFlags,
SecurityBuffer[] inSecurityBufferArray,
SecurityBuffer outSecurityBuffer,
Expand All @@ -156,20 +159,33 @@ internal static SecurityStatusPal InitializeSecurityContext(
}

SafeFreeNegoCredentials negoCredentialsHandle = (SafeFreeNegoCredentials)credentialsHandle;
SecurityStatusPal status = default;

if (negoCredentialsHandle.IsDefault && string.IsNullOrEmpty(spn))
foreach (string spn in spns)
{
throw new PlatformNotSupportedException(Strings.net_nego_not_supported_empty_target_with_defaultcreds);
}
if (negoCredentialsHandle.IsDefault && string.IsNullOrEmpty(spn))
{
throw new PlatformNotSupportedException(Strings.net_nego_not_supported_empty_target_with_defaultcreds);
}

SecurityStatusPal status = EstablishSecurityContext(
negoCredentialsHandle,
ref securityContext,
spn,
requestedContextFlags,
((inSecurityBufferArray != null && inSecurityBufferArray.Length != 0) ? inSecurityBufferArray[0] : null),
outSecurityBuffer,
ref contextFlags);
status = EstablishSecurityContext(
negoCredentialsHandle,
ref securityContext,
spn,
requestedContextFlags,
((inSecurityBufferArray != null && inSecurityBufferArray.Length != 0) ? inSecurityBufferArray[0] : null),
outSecurityBuffer,
ref contextFlags);

if (status.ErrorCode != SecurityStatusPalErrorCode.InternalError)
{
break; // Successful case, exit the loop with current SPN.
}
else
{
securityContext = null; // Reset security context to be generated again for next SPN.
}
}

// Confidentiality flag should not be set if not requested
if (status.ErrorCode == SecurityStatusPalErrorCode.CompleteNeeded)
Expand All @@ -180,7 +196,6 @@ internal static SecurityStatusPal InitializeSecurityContext(
throw new PlatformNotSupportedException(Strings.net_nego_protection_level_not_supported);
}
}

return status;
}

Expand Down Expand Up @@ -224,7 +239,7 @@ internal static SafeFreeCredentials AcquireCredentialsHandle(string package, boo
new SafeFreeNegoCredentials(false, string.Empty, string.Empty, string.Empty) :
new SafeFreeNegoCredentials(ntlmOnly, credential.UserName, credential.Password, credential.Domain);
}
catch(Exception ex)
catch (Exception ex)
{
throw new Win32Exception(NTE_FAIL, ex.Message);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ internal static string QueryContextAuthenticationPackage(SafeDeleteContext secur
internal static SecurityStatusPal InitializeSecurityContext(
SafeFreeCredentials credentialsHandle,
ref SafeDeleteContext securityContext,
string spn,
string[] spn,
ContextFlagsPal requestedContextFlags,
SecurityBuffer[] inSecurityBufferArray,
SecurityBuffer outSecurityBuffer,
Expand All @@ -81,7 +81,7 @@ internal static SecurityStatusPal InitializeSecurityContext(
GlobalSSPI.SSPIAuth,
credentialsHandle,
ref securityContext,
spn,
spn[0],
ContextFlagsAdapterPal.GetInteropFromContextFlagsPal(requestedContextFlags),
Interop.SspiCli.Endianness.SECURITY_NETWORK_DREP,
inSecurityBufferArray,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ internal uint DisableSsl(SNIHandle handle)
/// <param name="sendBuff">Send buffer</param>
/// <param name="serverName">Service Principal Name buffer</param>
/// <returns>SNI error code</returns>
internal void GenSspiClientContext(SspiClientContextStatus sspiClientContextStatus, byte[] receivedBuff, ref byte[] sendBuff, byte[] serverName)
internal void GenSspiClientContext(SspiClientContextStatus sspiClientContextStatus, byte[] receivedBuff, ref byte[] sendBuff, byte[][] serverName)
{
SafeDeleteContext securityContext = sspiClientContextStatus.SecurityContext;
ContextFlagsPal contextFlags = sspiClientContextStatus.ContextFlags;
Expand Down Expand Up @@ -104,12 +104,15 @@ internal void GenSspiClientContext(SspiClientContextStatus sspiClientContextStat
| ContextFlagsPal.Delegate
| ContextFlagsPal.MutualAuth;

string serverSPN = System.Text.Encoding.UTF8.GetString(serverName);

string[] serverSPNs = new string[serverName.Length];
for (int i = 0; i < serverName.Length; i++)
{
serverSPNs[i] = System.Text.Encoding.UTF8.GetString(serverName[i]);
}
SecurityStatusPal statusCode = NegotiateStreamPal.InitializeSecurityContext(
credentialsHandle,
ref securityContext,
serverSPN,
serverSPNs,
requestedContextFlags,
inSecurityBufferArray,
outSecurityBuffer,
Expand Down Expand Up @@ -253,7 +256,7 @@ internal uint WritePacket(SNIHandle handle, SNIPacket packet, bool sync)
/// <param name="cachedFQDN">Used for DNS Cache</param>
/// <param name="pendingDNSInfo">Used for DNS Cache</param>
/// <returns>SNI handle</returns>
internal SNIHandle CreateConnectionHandle(string fullServerName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[] spnBuffer, bool flushCache, bool async, bool parallel, bool isIntegratedSecurity, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo)
internal SNIHandle CreateConnectionHandle(string fullServerName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[][] spnBuffer, bool flushCache, bool async, bool parallel, bool isIntegratedSecurity, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo)
{
instanceName = new byte[1];

Expand Down Expand Up @@ -294,7 +297,7 @@ internal SNIHandle CreateConnectionHandle(string fullServerName, bool ignoreSniO
{
try
{
spnBuffer = GetSqlServerSPN(details);
spnBuffer = GetSqlServerSPNs(details);
}
catch (Exception e)
{
Expand All @@ -305,7 +308,7 @@ internal SNIHandle CreateConnectionHandle(string fullServerName, bool ignoreSniO
return sniHandle;
}

private static byte[] GetSqlServerSPN(DataSource dataSource)
private static byte[][] GetSqlServerSPNs(DataSource dataSource)
{
Debug.Assert(!string.IsNullOrWhiteSpace(dataSource.ServerName));

Expand All @@ -319,16 +322,11 @@ private static byte[] GetSqlServerSPN(DataSource dataSource)
{
postfix = dataSource.InstanceName;
}
// For handling tcp:<hostname> format
else if (dataSource._connectionProtocol == DataSource.Protocol.TCP)
{
postfix = DefaultSqlServerPort.ToString();
}

return GetSqlServerSPN(hostName, postfix);
return GetSqlServerSPNs(hostName, postfix, dataSource._connectionProtocol);
}

private static byte[] GetSqlServerSPN(string hostNameOrAddress, string portOrInstanceName)
private static byte[][] GetSqlServerSPNs(string hostNameOrAddress, string portOrInstanceName, DataSource.Protocol protocol)
{
Debug.Assert(!string.IsNullOrWhiteSpace(hostNameOrAddress));
IPHostEntry hostEntry = null;
Expand All @@ -347,16 +345,22 @@ private static byte[] GetSqlServerSPN(string hostNameOrAddress, string portOrIns
// If the DNS lookup failed, then resort to using the user provided hostname to construct the SPN.
fullyQualifiedDomainName = hostEntry?.HostName ?? hostNameOrAddress;
}

string serverSpn = SqlServerSpnHeader + "/" + fullyQualifiedDomainName;

if (!string.IsNullOrWhiteSpace(portOrInstanceName))
{
serverSpn += ":" + portOrInstanceName;
}
else
else if (protocol == DataSource.Protocol.None || protocol == DataSource.Protocol.TCP) // Default is TCP
{
serverSpn += $":{DefaultSqlServerPort}";
string serverSpnWithDefaultPort = serverSpn + $":{DefaultSqlServerPort}";
// Set both SPNs with and without Port as Port is optional for default instance
return new byte[][] { Encoding.UTF8.GetBytes(serverSpn), Encoding.UTF8.GetBytes(serverSpnWithDefaultPort) };
}
return Encoding.UTF8.GetBytes(serverSpn);
// else Named Pipes do not need to valid port

return new byte[][] { Encoding.UTF8.GetBytes(serverSpn) };
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ internal sealed partial class TdsParser

private bool _isDenali = false;

private byte[] _sniSpnBuffer = null;
private byte[][] _sniSpnBuffer = null;

// SqlStatistics
private SqlStatistics _statistics = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -789,7 +789,7 @@ private void ResetCancelAndProcessAttention()
}
}

internal abstract void CreatePhysicalSNIHandle(string serverName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[] spnBuffer, bool flushCache, bool async, bool fParallel, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo, bool isIntegratedSecurity = false);
internal abstract void CreatePhysicalSNIHandle(string serverName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[][] spnBuffer, bool flushCache, bool async, bool fParallel, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo, bool isIntegratedSecurity = false);

internal abstract void AssignPendingDNSInfo(string userProtocol, string DNSCacheKey, ref SQLDNSInfo pendingDNSInfo);

Expand Down Expand Up @@ -831,7 +831,7 @@ private void ResetCancelAndProcessAttention()

protected abstract void RemovePacketFromPendingList(PacketHandle pointer);

internal abstract uint GenerateSspiClientContext(byte[] receivedBuff, uint receivedLength, ref byte[] sendBuff, ref uint sendLength, byte[] _sniSpnBuffer);
internal abstract uint GenerateSspiClientContext(byte[] receivedBuff, uint receivedLength, ref byte[] sendBuff, ref uint sendLength, byte[][] _sniSpnBuffer);

internal bool Deactivate()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ internal SNIMarsHandle CreateMarsSession(object callbackObject, bool async)
protected override uint SNIPacketGetData(PacketHandle packet, byte[] _inBuff, ref uint dataSize)
=> SNIProxy.GetInstance().PacketGetData(packet.ManagedPacket, _inBuff, ref dataSize);

internal override void CreatePhysicalSNIHandle(string serverName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[] spnBuffer, bool flushCache, bool async, bool parallel, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo, bool isIntegratedSecurity)
internal override void CreatePhysicalSNIHandle(string serverName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[][] spnBuffer, bool flushCache, bool async, bool parallel, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo, bool isIntegratedSecurity)
{
_sessionHandle = SNIProxy.GetInstance().CreateConnectionHandle(serverName, ignoreSniOpenTimeout, timerExpire, out instanceName, ref spnBuffer, flushCache, async, parallel, isIntegratedSecurity, cachedFQDN, ref pendingDNSInfo);
if (_sessionHandle == null)
Expand Down Expand Up @@ -215,7 +215,7 @@ internal override uint EnableMars(ref uint info)

internal override uint SetConnectionBufferSize(ref uint unsignedPacketSize) => SNIProxy.GetInstance().SetConnectionBufferSize(Handle, unsignedPacketSize);

internal override uint GenerateSspiClientContext(byte[] receivedBuff, uint receivedLength, ref byte[] sendBuff, ref uint sendLength, byte[] _sniSpnBuffer)
internal override uint GenerateSspiClientContext(byte[] receivedBuff, uint receivedLength, ref byte[] sendBuff, ref uint sendLength, byte[][] _sniSpnBuffer)
{
if (_sspiClientContextStatus == null)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,12 @@ internal override void AssignPendingDNSInfo(string userProtocol, string DNSCache

if (string.IsNullOrEmpty(userProtocol))
{

result = SNINativeMethodWrapper.SniGetProviderNumber(Handle, ref providerNumber);
Debug.Assert(result == TdsEnums.SNI_SUCCESS, "Unexpected failure state upon calling SniGetProviderNumber");
_parser.isTcpProtocol = (providerNumber == SNINativeMethodWrapper.ProviderEnum.TCP_PROV);
}
else if (userProtocol == TdsEnums.TCP)
else if (userProtocol == TdsEnums.TCP)
{
_parser.isTcpProtocol = true;
}
Expand Down Expand Up @@ -138,14 +138,14 @@ private SNINativeMethodWrapper.ConsumerInfo CreateConsumerInfo(bool async)
return myInfo;
}

internal override void CreatePhysicalSNIHandle(string serverName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[] spnBuffer, bool flushCache, bool async, bool fParallel, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo, bool isIntegratedSecurity)
internal override void CreatePhysicalSNIHandle(string serverName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[][] spnBuffer, bool flushCache, bool async, bool fParallel, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo, bool isIntegratedSecurity)
{
// We assume that the loadSSPILibrary has been called already. now allocate proper length of buffer
spnBuffer = null;
spnBuffer = new byte[1][];
if (isIntegratedSecurity)
{
// now allocate proper length of buffer
spnBuffer = new byte[SNINativeMethodWrapper.SniMaxComposedSpnLength];
spnBuffer[0] = new byte[SNINativeMethodWrapper.SniMaxComposedSpnLength];
}

SNINativeMethodWrapper.ConsumerInfo myInfo = CreateConsumerInfo(async);
Expand All @@ -172,7 +172,7 @@ internal override void CreatePhysicalSNIHandle(string serverName, bool ignoreSni
SQLDNSInfo cachedDNSInfo;
bool ret = SQLFallbackDNSCache.Instance.GetDNSInfo(cachedFQDN, out cachedDNSInfo);

_sessionHandle = new SNIHandle(myInfo, serverName, spnBuffer, ignoreSniOpenTimeout, checked((int)timeout), out instanceName, flushCache, !async, fParallel, cachedDNSInfo);
_sessionHandle = new SNIHandle(myInfo, serverName, spnBuffer[0], ignoreSniOpenTimeout, checked((int)timeout), out instanceName, flushCache, !async, fParallel, cachedDNSInfo);
}

protected override uint SNIPacketGetData(PacketHandle packet, byte[] _inBuff, ref uint dataSize)
Expand Down Expand Up @@ -385,8 +385,8 @@ internal override uint EnableSsl(ref uint info)
internal override uint SetConnectionBufferSize(ref uint unsignedPacketSize)
=> SNINativeMethodWrapper.SNISetInfo(Handle, SNINativeMethodWrapper.QTypes.SNI_QUERY_CONN_BUFSIZE, ref unsignedPacketSize);

internal override uint GenerateSspiClientContext(byte[] receivedBuff, uint receivedLength, ref byte[] sendBuff, ref uint sendLength, byte[] _sniSpnBuffer)
=> SNINativeMethodWrapper.SNISecGenClientContext(Handle, receivedBuff, receivedLength, sendBuff, ref sendLength, _sniSpnBuffer);
internal override uint GenerateSspiClientContext(byte[] receivedBuff, uint receivedLength, ref byte[] sendBuff, ref uint sendLength, byte[][] _sniSpnBuffer)
=> SNINativeMethodWrapper.SNISecGenClientContext(Handle, receivedBuff, receivedLength, sendBuff, ref sendLength, _sniSpnBuffer[0]);

internal override uint WaitForSSLHandShakeToComplete(out int protocolVersion)
{
Expand Down Expand Up @@ -421,7 +421,7 @@ internal override uint WaitForSSLHandShakeToComplete(out int protocolVersion)
protocolVersion = (int)SslProtocols.Ssl2;
#pragma warning restore CS0618 // Type or member is obsolete : SSL is depricated
}
else if(nativeProtocol.HasFlag(NativeProtocols.SP_PROT_NONE))
else if (nativeProtocol.HasFlag(NativeProtocols.SP_PROT_NONE))
{
protocolVersion = (int)SslProtocols.None;
}
Expand Down

0 comments on commit ff19a0b

Please sign in to comment.