Skip to content

Commit

Permalink
Use ReadOnlyMemory in
Browse files Browse the repository at this point in the history
  • Loading branch information
twsouthwick committed Apr 8, 2024
1 parent 1bf025a commit 90b9e46
Show file tree
Hide file tree
Showing 11 changed files with 53 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ private static extern uint SNIOpenWrapper(
[DllImport(SNI, CallingConvention = CallingConvention.Cdecl)]
private static extern unsafe uint SNISecGenClientContextWrapper(
[In] SNIHandle pConn,
[In, Out] byte[] pIn,
[In, Out] byte* pIn,
uint cbIn,
[In, Out] byte[] pOut,
[In] ref uint pcbOut,
Expand Down Expand Up @@ -471,15 +471,16 @@ internal static unsafe void SNIPacketSetData(SNIPacket packet, byte[] data, int
}
}

internal static unsafe uint SNISecGenClientContext(SNIHandle pConnectionObject, byte[] inBuff, uint receivedLength, byte[] OutBuff, ref uint sendLength, byte[] serverUserName)
internal static unsafe uint SNISecGenClientContext(SNIHandle pConnectionObject, ReadOnlySpan<byte> inBuff, byte[] OutBuff, ref uint sendLength, byte[] serverUserName)
{
fixed (byte* pin_serverUserName = &serverUserName[0])
fixed (byte* pInBuff = inBuff)
{
bool local_fDone;
return SNISecGenClientContextWrapper(
pConnectionObject,
inBuff,
receivedLength,
pInBuff,
(uint)inBuff.Length,
OutBuff,
ref sendLength,
out local_fDone,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Buffers;
using System.Diagnostics;
using System.IO;
using System.Net;
Expand Down Expand Up @@ -34,7 +35,21 @@ internal class SNIProxy
/// <param name="sendBuff">Send buffer</param>
/// <param name="serverName">Service Principal Name buffer</param>
/// <returns>SNI error code</returns>
internal static void GenSspiClientContext(SspiClientContextStatus sspiClientContextStatus, byte[] receivedBuff, ref byte[] sendBuff, byte[][] serverName)
internal static void GenSspiClientContext(SspiClientContextStatus sspiClientContextStatus, ReadOnlyMemory<byte> receivedBuff, ref byte[] sendBuff, byte[][] serverName)
{
// TODO: this should use ReadOnlyMemory all the way through
byte[] array = null;

if (!receivedBuff.IsEmpty)
{
array = new byte[receivedBuff.Length];
receivedBuff.CopyTo(array);
}

GenSspiClientContext(sspiClientContextStatus, array, ref sendBuff, serverName);
}

private static void GenSspiClientContext(SspiClientContextStatus sspiClientContextStatus, byte[] receivedBuff, ref byte[] sendBuff, byte[][] serverName)
{
SafeDeleteContext securityContext = sspiClientContextStatus.SecurityContext;
ContextFlagsPal contextFlags = sspiClientContextStatus.ContextFlags;
Expand Down Expand Up @@ -189,7 +204,7 @@ internal static SNIHandle CreateConnectionHandle(
case DataSource.Protocol.TCP:
sniHandle = CreateTcpHandle(details, timeout, parallel, ipPreference, cachedFQDN, ref pendingDNSInfo,
tlsFirst, hostNameInCertificate, serverCertificateFilename);
break;
break;
case DataSource.Protocol.NP:
sniHandle = CreateNpHandle(details, timeout, parallel, tlsFirst);
break;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ internal static extern uint SNIOpenWrapper(
[DllImport(SNI, CallingConvention = CallingConvention.Cdecl, EntryPoint = "SNISecGenClientContextWrapper")]
internal static extern unsafe uint SNISecGenClientContextWrapper(
[In] SNIHandle pConn,
[In, Out] byte[] pIn,
[In, Out] byte* pIn,
uint cbIn,
[In, Out] byte[] pOut,
[In] ref uint pcbOut,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ internal static extern uint SNIOpenWrapper(
[DllImport(SNI, CallingConvention = CallingConvention.Cdecl, EntryPoint = "SNISecGenClientContextWrapper")]
internal static extern unsafe uint SNISecGenClientContextWrapper(
[In] SNIHandle pConn,
[In, Out] byte[] pIn,
[In, Out] byte* pIn,
uint cbIn,
[In, Out] byte[] pOut,
[In] ref uint pcbOut,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ internal static extern uint SNIOpenWrapper(
[DllImport(SNI, CallingConvention = CallingConvention.Cdecl, EntryPoint = "SNISecGenClientContextWrapper")]
internal static extern unsafe uint SNISecGenClientContextWrapper(
[In] SNIHandle pConn,
[In, Out] byte[] pIn,
[In, Out] byte* pIn,
uint cbIn,
[In, Out] byte[] pOut,
[In] ref uint pcbOut,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -889,8 +889,7 @@ private static unsafe void SNIPacketSetData(SNIPacket pPacket, [In] byte* pbBuf,

private static unsafe uint SNISecGenClientContextWrapper(
[In] SNIHandle pConn,
[In, Out] byte[] pIn,
uint cbIn,
[In, Out] ReadOnlySpan<byte> pIn,
[In, Out] byte[] pOut,
[In] ref uint pcbOut,
[MarshalAsAttribute(UnmanagedType.Bool)] out bool pfDone,
Expand All @@ -899,16 +898,19 @@ private static unsafe uint SNISecGenClientContextWrapper(
[MarshalAsAttribute(UnmanagedType.LPWStr)] string pwszUserName,
[MarshalAsAttribute(UnmanagedType.LPWStr)] string pwszPassword)
{
switch (s_architecture)
fixed (byte* pInPtr = pIn)
{
case System.Runtime.InteropServices.Architecture.Arm64:
return SNINativeManagedWrapperARM64.SNISecGenClientContextWrapper(pConn, pIn, cbIn, pOut, ref pcbOut, out pfDone, szServerInfo, cbServerInfo, pwszUserName, pwszPassword);
case System.Runtime.InteropServices.Architecture.X64:
return SNINativeManagedWrapperX64.SNISecGenClientContextWrapper(pConn, pIn, cbIn, pOut, ref pcbOut, out pfDone, szServerInfo, cbServerInfo, pwszUserName, pwszPassword);
case System.Runtime.InteropServices.Architecture.X86:
return SNINativeManagedWrapperX86.SNISecGenClientContextWrapper(pConn, pIn, cbIn, pOut, ref pcbOut, out pfDone, szServerInfo, cbServerInfo, pwszUserName, pwszPassword);
default:
throw ADP.SNIPlatformNotSupported(s_architecture.ToString());
switch (s_architecture)
{
case System.Runtime.InteropServices.Architecture.Arm64:
return SNINativeManagedWrapperARM64.SNISecGenClientContextWrapper(pConn, pInPtr, (uint)pIn.Length, pOut, ref pcbOut, out pfDone, szServerInfo, cbServerInfo, pwszUserName, pwszPassword);
case System.Runtime.InteropServices.Architecture.X64:
return SNINativeManagedWrapperX64.SNISecGenClientContextWrapper(pConn, pInPtr, (uint)pIn.Length, pOut, ref pcbOut, out pfDone, szServerInfo, cbServerInfo, pwszUserName, pwszPassword);
case System.Runtime.InteropServices.Architecture.X86:
return SNINativeManagedWrapperX86.SNISecGenClientContextWrapper(pConn, pInPtr, (uint)pIn.Length, pOut, ref pcbOut, out pfDone, szServerInfo, cbServerInfo, pwszUserName, pwszPassword);
default:
throw ADP.SNIPlatformNotSupported(s_architecture.ToString());
}
}
}

Expand Down Expand Up @@ -1378,15 +1380,14 @@ Int32[] passwordOffsets // Offset into data buffer where the password to be w
}
}

internal static unsafe uint SNISecGenClientContext(SNIHandle pConnectionObject, byte[] inBuff, uint receivedLength, byte[] OutBuff, ref uint sendLength, byte[] serverUserName)
internal static unsafe uint SNISecGenClientContext(SNIHandle pConnectionObject, ReadOnlySpan<byte> inBuff, byte[] OutBuff, ref uint sendLength, byte[] serverUserName)
{
fixed (byte* pin_serverUserName = &serverUserName[0])
{
bool local_fDone;
return SNISecGenClientContextWrapper(
pConnectionObject,
inBuff,
receivedLength,
OutBuff,
ref sendLength,
out local_fDone,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#if !NETFRAMEWORK && !NET7_0_OR_GREATER

using System;
using Microsoft.Data.SqlClient.SNI;

#nullable enable
Expand All @@ -10,11 +11,11 @@ internal sealed class ManagedSSPIContextProvider : SSPIContextProvider
{
private SspiClientContextStatus? _sspiClientContextStatus;

internal override void GenerateSspiClientContext(byte[] receivedBuff, uint receivedLength, ref byte[] sendBuff, ref uint sendLength, byte[][] _sniSpnBuffer)
internal override void GenerateSspiClientContext(ReadOnlyMemory<byte> received, ref byte[] sendBuff, ref uint sendLength, byte[][] _sniSpnBuffer)
{
_sspiClientContextStatus ??= new SspiClientContextStatus();

SNIProxy.GenSspiClientContext(_sspiClientContextStatus, receivedBuff, ref sendBuff, _sniSpnBuffer);
SNIProxy.GenSspiClientContext(_sspiClientContextStatus, received, ref sendBuff, _sniSpnBuffer);
SqlClientEventSource.Log.TryTraceEvent("TdsParserStateObjectManaged.GenerateSspiClientContext | Info | Session Id {0}", _physicalStateObj.SessionId);
sendLength = (uint)(sendBuff != null ? sendBuff.Length : 0);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,15 @@ private void LoadSSPILibrary()
}
}

internal override void GenerateSspiClientContext(byte[] receivedBuff, uint receivedLength, ref byte[] sendBuff, ref uint sendLength, byte[][] _sniSpnBuffer)
internal override void GenerateSspiClientContext(ReadOnlyMemory<byte> receivedBuff, ref byte[] sendBuff, ref uint sendLength, byte[][] _sniSpnBuffer)
{
#if NETFRAMEWORK
SNIHandle handle = _physicalStateObj.Handle;
#else
Debug.Assert(_physicalStateObj.SessionHandle.Type == SessionHandle.NativeHandleType);
SNIHandle handle = _physicalStateObj.SessionHandle.NativeHandle;
#endif
if (0 != SNINativeMethodWrapper.SNISecGenClientContext(handle, receivedBuff, receivedLength, sendBuff, ref sendLength, _sniSpnBuffer[0]))
if (0 != SNINativeMethodWrapper.SNISecGenClientContext(handle, receivedBuff.Span, sendBuff, ref sendLength, _sniSpnBuffer[0]))
{
throw new InvalidOperationException(SQLMessage.SSPIGenerateError());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ internal sealed class NegotiateSSPIContextProvider : SSPIContextProvider
{
private NegotiateAuthentication? _negotiateAuth = null;

internal override void GenerateSspiClientContext(byte[] receivedBuff, uint receivedLength, ref byte[] sendBuff, ref uint sendLength, byte[][] _sniSpnBuffer)
internal override void GenerateSspiClientContext(ReadOnlyMemory<byte> received, ref byte[] sendBuff, ref uint sendLength, byte[][] _sniSpnBuffer)
{
_negotiateAuth ??= new(new NegotiateAuthenticationClientOptions { Package = "Negotiate", TargetName = Encoding.Unicode.GetString(_sniSpnBuffer[0]) });
sendBuff = _negotiateAuth.GetOutgoingBlob(receivedBuff, out NegotiateAuthenticationStatusCode statusCode)!;
sendBuff = _negotiateAuth.GetOutgoingBlob(received.Span, out NegotiateAuthenticationStatusCode statusCode)!;
SqlClientEventSource.Log.TryTraceEvent("TdsParserStateObjectManaged.GenerateSspiClientContext | Info | Session Id {0}, StatusCode={1}", _physicalStateObj.SessionId, statusCode);
if (statusCode is not NegotiateAuthenticationStatusCode.Completed and not NegotiateAuthenticationStatusCode.ContinueNeeded)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,16 @@ private protected virtual void Initialize()
{
}

internal abstract void GenerateSspiClientContext(byte[] receivedBuff, uint receivedLength, ref byte[] sendBuff, ref uint sendLength, byte[][] _sniSpnBuffer);
internal abstract void GenerateSspiClientContext(ReadOnlyMemory<byte> input, ref byte[] sendBuff, ref uint sendLength, byte[][] _sniSpnBuffer);

internal void SSPIData(byte[] receivedBuff, UInt32 receivedLength, ref byte[] sendBuff, ref UInt32 sendLength, byte[] sniSpnBuffer)
=> SSPIData(receivedBuff, receivedLength, ref sendBuff, ref sendLength, new[] { sniSpnBuffer });
internal void SSPIData(ReadOnlyMemory<byte> receivedBuff, ref byte[] sendBuff, ref UInt32 sendLength, byte[] sniSpnBuffer)
=> SSPIData(receivedBuff, ref sendBuff, ref sendLength, new[] { sniSpnBuffer });

internal void SSPIData(byte[] receivedBuff, UInt32 receivedLength, ref byte[] sendBuff, ref UInt32 sendLength, byte[][] sniSpnBuffer)
internal void SSPIData(ReadOnlyMemory<byte> receivedBuff, ref byte[] sendBuff, ref UInt32 sendLength, byte[][] sniSpnBuffer)
{
try
{
GenerateSspiClientContext(receivedBuff, receivedLength, ref sendBuff, ref sendLength, sniSpnBuffer);
GenerateSspiClientContext(receivedBuff, ref sendBuff, ref sendLength, sniSpnBuffer);
}
catch (Exception e)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ internal void ProcessSSPI(int receivedLength)
uint sendLength = _authenticationProvider.MaxSSPILength;

// make call for SSPI data
_authenticationProvider.SSPIData(receivedBuff, (uint)receivedLength, ref sendBuff, ref sendLength, _sniSpnBuffer);
_authenticationProvider.SSPIData(receivedBuff.AsMemory(0, receivedLength), ref sendBuff, ref sendLength, _sniSpnBuffer);

// DO NOT SEND LENGTH - TDS DOC INCORRECT! JUST SEND SSPI DATA!
_physicalStateObj.WriteByteArray(sendBuff, (int)sendLength, 0);
Expand Down Expand Up @@ -194,7 +194,7 @@ internal void TdsLogin(
// byte[] buffer and 0 for the int length.
Debug.Assert(SniContext.Snix_Login == _physicalStateObj.SniContext, $"Unexpected SniContext. Expecting Snix_Login, actual value is '{_physicalStateObj.SniContext}'");
_physicalStateObj.SniContext = SniContext.Snix_LoginSspi;
_authenticationProvider.SSPIData(Array.Empty<byte>(), 0, ref outSSPIBuff, ref outSSPILength, _sniSpnBuffer);
_authenticationProvider.SSPIData(ReadOnlyMemory<byte>.Empty, ref outSSPIBuff, ref outSSPILength, _sniSpnBuffer);

if (outSSPILength > int.MaxValue)
{
Expand Down

0 comments on commit 90b9e46

Please sign in to comment.