Skip to content

Commit

Permalink
Move SPN to ReadOnlyList
Browse files Browse the repository at this point in the history
  • Loading branch information
twsouthwick committed Jan 19, 2024
1 parent 2b73e80 commit f40acdc
Show file tree
Hide file tree
Showing 17 changed files with 159 additions and 113 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using System.ComponentModel;
using System.Diagnostics;
using Microsoft.Win32.SafeHandles;
Expand Down Expand Up @@ -146,7 +147,7 @@ private static SecurityStatusPal EstablishSecurityContext(
internal static SecurityStatusPal InitializeSecurityContext(
SafeFreeCredentials credentialsHandle,
ref SafeDeleteContext securityContext,
string[] spns,
IReadOnlyList<string> spns,
ContextFlagsPal requestedContextFlags,
SecurityBuffer[] inSecurityBufferArray,
SecurityBuffer outSecurityBuffer,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Globalization;
using System.ComponentModel;
using Microsoft.Data;
using System.Collections.Generic;

namespace System.Net.Security
{
Expand Down Expand Up @@ -71,7 +72,7 @@ internal static string QueryContextAuthenticationPackage(SafeDeleteContext secur
internal static SecurityStatusPal InitializeSecurityContext(
SafeFreeCredentials credentialsHandle,
ref SafeDeleteContext securityContext,
string[] spn,
IReadOnlyList<string> spn,
ContextFlagsPal requestedContextFlags,
SecurityBuffer[] inSecurityBufferArray,
SecurityBuffer outSecurityBuffer,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Buffers;
using System.Runtime.InteropServices;
using System.Text;
using Microsoft.Data.Common;
Expand Down Expand Up @@ -398,7 +399,7 @@ internal static unsafe uint SNIOpenSyncEx(
ConsumerInfo consumerInfo,
string constring,
ref IntPtr pConn,
byte[] spnBuffer,
ref string spn,
byte[] instanceName,
bool fOverrideCache,
bool fSync,
Expand Down Expand Up @@ -436,13 +437,25 @@ internal static unsafe uint SNIOpenSyncEx(
clientConsumerInfo.DNSCacheInfo.wszCachedTcpIPv6 = cachedDNSInfo?.AddrIPv6;
clientConsumerInfo.DNSCacheInfo.wszCachedTcpPort = cachedDNSInfo?.Port;

if (spnBuffer != null)
if (spn != null)
{
fixed (byte* pin_spnBuffer = &spnBuffer[0])
var array = ArrayPool<byte>.Shared.Rent(SniMaxComposedSpnLength);
try
{
clientConsumerInfo.szSPN = pin_spnBuffer;
clientConsumerInfo.cchSPN = (uint)spnBuffer.Length;
return SNIOpenSyncExWrapper(ref clientConsumerInfo, out pConn);
fixed (byte* pin_spnBuffer = &array[0])
{
clientConsumerInfo.szSPN = pin_spnBuffer;
clientConsumerInfo.cchSPN = (uint)SniMaxComposedSpnLength;
var result = SNIOpenSyncExWrapper(ref clientConsumerInfo, out pConn);

spn = Encoding.Unicode.GetString(array, 0, (int)clientConsumerInfo.cchSPN).TrimEnd();

return result;
}
}
finally
{
ArrayPool<byte>.Shared.Return(array);
}
}
else
Expand Down Expand Up @@ -471,24 +484,36 @@ internal static unsafe void SNIPacketSetData(SNIPacket packet, byte[] data, int
}
}

internal static unsafe uint SNISecGenClientContext(SNIHandle pConnectionObject, ReadOnlySpan<byte> inBuff, Span<byte> outBuff, ref uint sendLength, byte[] serverUserName)
internal static unsafe uint SNISecGenClientContext(SNIHandle pConnectionObject, ReadOnlySpan<byte> inBuff, Span<byte> outBuff, ref uint sendLength, string serverUserName)
{
fixed (byte* pin_serverUserName = &serverUserName[0])
fixed (byte* pInBuff = inBuff)
fixed (byte* pOutBuff = outBuff)
var serverNameLength = Encoding.Unicode.GetByteCount(serverUserName);
var serverNameArray = ArrayPool<byte>.Shared.Rent(serverNameLength);

try
{
Encoding.Unicode.GetBytes(serverUserName, 0, serverUserName.Length, serverNameArray, 0);

fixed (byte* pin_serverUserName = serverNameArray)
fixed (byte* pInBuff = inBuff)
fixed (byte* pOutBuff = outBuff)
{
bool local_fDone;
return SNISecGenClientContextWrapper(
pConnectionObject,
pInBuff,
(uint)inBuff.Length,
pOutBuff,
ref sendLength,
out local_fDone,
pin_serverUserName,
(uint)serverNameLength,
null,
null);
}
}
finally
{
bool local_fDone;
return SNISecGenClientContextWrapper(
pConnectionObject,
pInBuff,
(uint)inBuff.Length,
pOutBuff,
ref sendLength,
out local_fDone,
pin_serverUserName,
(uint)serverUserName.Length,
null,
null);
ArrayPool<byte>.Shared.Return(serverNameArray);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

using System;
using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Net;
Expand Down Expand Up @@ -32,25 +33,25 @@ internal class SNIProxy
/// </summary>
/// <param name="sspiClientContextStatus">SSPI client context status</param>
/// <param name="receivedBuff">Receive buffer</param>
/// <param name="serverName">Service Principal Name buffer</param>
/// <param name="serverNames">Service Principal Name buffer</param>
/// <returns>Memory for response</returns>
internal static IMemoryOwner<byte> GenSspiClientContext(SspiClientContextStatus sspiClientContextStatus, ReadOnlyMemory<byte> receivedBuff, byte[][] serverName)
internal static IMemoryOwner<byte> GenSspiClientContext(SspiClientContextStatus sspiClientContextStatus, ReadOnlyMemory<byte> receivedBuff, IReadOnlyList<string> serverNames)
{
// TODO: this should use ReadOnlyMemory all the way through
var array = ArrayPool<byte>.Shared.Rent(receivedBuff.Length);

try
{
receivedBuff.CopyTo(array);
return GenSspiClientContext(sspiClientContextStatus, array, receivedBuff.Length, serverName);
return GenSspiClientContext(sspiClientContextStatus, array, receivedBuff.Length, serverNames);
}
finally
{
ArrayPool<byte>.Shared.Return(array);
}
}

private static IMemoryOwner<byte> GenSspiClientContext(SspiClientContextStatus sspiClientContextStatus, byte[] receivedBuff, int receivedBuffLength, byte[][] serverName)
private static IMemoryOwner<byte> GenSspiClientContext(SspiClientContextStatus sspiClientContextStatus, byte[] receivedBuff, int receivedBuffLength, IReadOnlyList<string> serverSPNs)
{
SafeDeleteContext securityContext = sspiClientContextStatus.SecurityContext;
ContextFlagsPal contextFlags = sspiClientContextStatus.ContextFlags;
Expand All @@ -66,7 +67,7 @@ private static IMemoryOwner<byte> GenSspiClientContext(SspiClientContextStatus s
SecurityBuffer[] inSecurityBufferArray;
if (receivedBuff != null)
{
inSecurityBufferArray = new SecurityBuffer[] { new SecurityBuffer(receivedBuff, SecurityBufferType.SECBUFFER_TOKEN) };
inSecurityBufferArray = new SecurityBuffer[] { new SecurityBuffer(receivedBuff, 0, receivedBuffLength, SecurityBufferType.SECBUFFER_TOKEN) };
}
else
{
Expand All @@ -82,11 +83,6 @@ private static IMemoryOwner<byte> GenSspiClientContext(SspiClientContextStatus s
| ContextFlagsPal.Delegate
| ContextFlagsPal.MutualAuth;

string[] serverSPNs = new string[serverName.Length];
for (int i = 0; i < serverName.Length; i++)
{
serverSPNs[i] = Encoding.Unicode.GetString(serverName[i]);
}
SecurityStatusPal statusCode = NegotiateStreamPal.InitializeSecurityContext(
credentialsHandle,
ref securityContext,
Expand Down Expand Up @@ -162,7 +158,7 @@ internal static SNIHandle CreateConnectionHandle(
string fullServerName,
TimeoutTimer timeout,
out byte[] instanceName,
ref byte[][] spnBuffer,
ref string[] spnBuffer,
string serverSPN,
bool flushCache,
bool async,
Expand Down Expand Up @@ -226,12 +222,12 @@ internal static SNIHandle CreateConnectionHandle(
return sniHandle;
}

private static byte[][] GetSqlServerSPNs(DataSource dataSource, string serverSPN)
private static string[] GetSqlServerSPNs(DataSource dataSource, string serverSPN)
{
Debug.Assert(!string.IsNullOrWhiteSpace(dataSource.ServerName));
if (!string.IsNullOrWhiteSpace(serverSPN))
{
return new byte[1][] { Encoding.Unicode.GetBytes(serverSPN) };
return new[] { serverSPN };
}

string hostName = dataSource.ServerName;
Expand All @@ -249,7 +245,7 @@ private static byte[][] GetSqlServerSPNs(DataSource dataSource, string serverSPN
return GetSqlServerSPNs(hostName, postfix, dataSource._connectionProtocol);
}

private static byte[][] GetSqlServerSPNs(string hostNameOrAddress, string portOrInstanceName, DataSource.Protocol protocol)
private static string[] GetSqlServerSPNs(string hostNameOrAddress, string portOrInstanceName, DataSource.Protocol protocol)
{
Debug.Assert(!string.IsNullOrWhiteSpace(hostNameOrAddress));
IPHostEntry hostEntry = null;
Expand Down Expand Up @@ -280,12 +276,12 @@ private static byte[][] GetSqlServerSPNs(string hostNameOrAddress, string portOr
string serverSpnWithDefaultPort = serverSpn + $":{DefaultSqlServerPort}";
// Set both SPNs with and without Port as Port is optional for default instance
SqlClientEventSource.Log.TryAdvancedTraceEvent("SNIProxy.GetSqlServerSPN | Info | ServerSPNs {0} and {1}", serverSpn, serverSpnWithDefaultPort);
return new byte[][] { Encoding.Unicode.GetBytes(serverSpn), Encoding.Unicode.GetBytes(serverSpnWithDefaultPort) };
return new[] { serverSpn, serverSpnWithDefaultPort };
}
// else Named Pipes do not need to valid port

SqlClientEventSource.Log.TryAdvancedTraceEvent("SNIProxy.GetSqlServerSPN | Info | ServerSPN {0}", serverSpn);
return new byte[][] { Encoding.Unicode.GetBytes(serverSpn) };
return new[] { serverSpn };
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ internal static void Assert(string message)

private bool _is2022 = false;

private byte[][] _sniSpnBuffer = null;
private string[] _sniSpn = null;

// SqlStatistics
private SqlStatistics _statistics = null;
Expand Down Expand Up @@ -404,7 +404,7 @@ internal void Connect(
}
else
{
_sniSpnBuffer = null;
_sniSpn = null;
SqlClientEventSource.Log.TryTraceEvent("TdsParser.Connect | SEC | Connection Object Id {0}, Authentication Mode: {1}", _connHandler._objectID,
authType == SqlAuthenticationMethod.NotSpecified ? SqlAuthenticationMethod.SqlPassword.ToString() : authType.ToString());
}
Expand All @@ -416,7 +416,7 @@ internal void Connect(
SqlClientEventSource.Log.TryTraceEvent("<sc.TdsParser.Connect|SEC> Encryption will be disabled as target server is a SQL Local DB instance.");
}

_sniSpnBuffer = null;
_sniSpn = null;
_authenticationProvider = null;

// AD Integrated behaves like Windows integrated when connecting to a non-fedAuth server
Expand Down Expand Up @@ -455,7 +455,7 @@ internal void Connect(
serverInfo.ExtendedServerName,
timeout,
out instanceName,
ref _sniSpnBuffer,
ref _sniSpn,
false,
true,
fParallel,
Expand All @@ -468,7 +468,7 @@ internal void Connect(
hostNameInCertificate,
serverCertificateFilename);

_authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this);
_authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this, _sniSpn);

if (TdsEnums.SNI_SUCCESS != _physicalStateObj.Status)
{
Expand Down Expand Up @@ -554,7 +554,7 @@ internal void Connect(
_physicalStateObj.CreatePhysicalSNIHandle(
serverInfo.ExtendedServerName,
timeout, out instanceName,
ref _sniSpnBuffer,
ref _sniSpn,
true,
true,
fParallel,
Expand All @@ -567,15 +567,15 @@ internal void Connect(
hostNameInCertificate,
serverCertificateFilename);

_authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this);

if (TdsEnums.SNI_SUCCESS != _physicalStateObj.Status)
{
_physicalStateObj.AddError(ProcessSNIError(_physicalStateObj));
SqlClientEventSource.Log.TryTraceEvent("<sc.TdsParser.Connect|ERR|SEC> Login failure");
ThrowExceptionAndWarning(_physicalStateObj);
}

_authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this, _sniSpn);

uint retCode = _physicalStateObj.SniGetConnectionId(ref _connHandler._clientConnectionId);

Debug.Assert(retCode == TdsEnums.SNI_SUCCESS, "Unexpected failure state upon calling SniGetConnectionId");
Expand Down Expand Up @@ -12850,7 +12850,7 @@ internal string TraceString()
_fMARS ? bool.TrueString : bool.FalseString,
null == _sessionPool ? "(null)" : _sessionPool.TraceString(),
_is2005 ? bool.TrueString : bool.FalseString,
null == _sniSpnBuffer ? "(null)" : _sniSpnBuffer.Length.ToString((IFormatProvider)null),
null == _sniSpn ? "(null)" : _sniSpn.Length.ToString((IFormatProvider)null),
_physicalStateObj != null ? "(null)" : _physicalStateObj.ErrorCount.ToString((IFormatProvider)null),
_physicalStateObj != null ? "(null)" : _physicalStateObj.WarningCount.ToString((IFormatProvider)null),
_physicalStateObj != null ? "(null)" : _physicalStateObj.PreAttentionErrorCount.ToString((IFormatProvider)null),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ internal abstract void CreatePhysicalSNIHandle(
string serverName,
TimeoutTimer timeout,
out byte[] instanceName,
ref byte[][] spnBuffer,
ref string[] spn,
bool flushCache,
bool async,
bool fParallel,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ internal override void CreatePhysicalSNIHandle(
string serverName,
TimeoutTimer timeout,
out byte[] instanceName,
ref byte[][] spnBuffer,
ref string[] spn,
bool flushCache,
bool async,
bool parallel,
Expand All @@ -94,7 +94,7 @@ internal override void CreatePhysicalSNIHandle(
string hostNameInCertificate,
string serverCertificateFilename)
{
SNIHandle? sessionHandle = SNIProxy.CreateConnectionHandle(serverName, timeout, out instanceName, ref spnBuffer, serverSPN,
SNIHandle? sessionHandle = SNIProxy.CreateConnectionHandle(serverName, timeout, out instanceName, ref spn, serverSPN,
flushCache, async, parallel, isIntegratedSecurity, iPAddressPreference, cachedFQDN, ref pendingDNSInfo, tlsFirst,
hostNameInCertificate, serverCertificateFilename);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ internal override void CreatePhysicalSNIHandle(
string serverName,
TimeoutTimer timeout,
out byte[] instanceName,
ref byte[][] spnBuffer,
ref string[] spn,
bool flushCache,
bool async,
bool fParallel,
Expand All @@ -157,30 +157,27 @@ internal override void CreatePhysicalSNIHandle(
string serverCertificateFilename)
{
// We assume that the loadSSPILibrary has been called already. now allocate proper length of buffer
spnBuffer = new byte[1][];
if (isIntegratedSecurity)
{
// now allocate proper length of buffer
if (!string.IsNullOrEmpty(serverSPN))
{
// Native SNI requires the Unicode encoding and any other encoding like UTF8 breaks the code.
byte[] srvSPN = Encoding.Unicode.GetBytes(serverSPN);
Trace.Assert(srvSPN.Length <= SNINativeMethodWrapper.SniMaxComposedSpnLength, "Length of the provided SPN exceeded the buffer size.");
spnBuffer[0] = srvSPN;
SqlClientEventSource.Log.TryTraceEvent("<{0}.{1}|SEC> Server SPN `{2}` from the connection string is used.", nameof(TdsParserStateObjectNative), nameof(CreatePhysicalSNIHandle), serverSPN);
}
else
{
spnBuffer[0] = new byte[SNINativeMethodWrapper.SniMaxComposedSpnLength];
serverSPN = string.Empty;
}
}

SNINativeMethodWrapper.ConsumerInfo myInfo = CreateConsumerInfo(async);
SQLDNSInfo cachedDNSInfo;
bool ret = SQLFallbackDNSCache.Instance.GetDNSInfo(cachedFQDN, out cachedDNSInfo);

_sessionHandle = new SNIHandle(myInfo, serverName, spnBuffer[0], timeout.MillisecondsRemainingInt, out instanceName,
_sessionHandle = new SNIHandle(myInfo, serverName, ref serverSPN, timeout.MillisecondsRemainingInt, out instanceName,
flushCache, !async, fParallel, ipPreference, cachedDNSInfo, hostNameInCertificate);
spn = new[] { serverSPN.TrimEnd() };
}

protected override uint SNIPacketGetData(PacketHandle packet, byte[] _inBuff, ref uint dataSize)
Expand Down
Loading

0 comments on commit f40acdc

Please sign in to comment.