From 90b9e46b8bbcc86f4fef285c99309156a5e435e6 Mon Sep 17 00:00:00 2001 From: Taylor Southwick Date: Wed, 17 Jan 2024 10:25:54 -0800 Subject: [PATCH] Use ReadOnlyMemory in --- .../Interop/SNINativeMethodWrapper.Windows.cs | 9 ++++--- .../Microsoft/Data/SqlClient/SNI/SNIProxy.cs | 19 +++++++++++-- .../Interop/SNINativeManagedWrapperARM64.cs | 2 +- .../Interop/SNINativeManagedWrapperX64.cs | 2 +- .../Interop/SNINativeManagedWrapperX86.cs | 2 +- .../Data/Interop/SNINativeMethodWrapper.cs | 27 ++++++++++--------- .../SSPI/ManagedSSPIContextProvider.cs | 5 ++-- .../SSPI/NativeSSPIContextProvider.cs | 4 +-- .../SSPI/NegotiateSSPIContextProvider.cs | 4 +-- .../SqlClient/SSPI/SSPIContextProvider.cs | 10 +++---- .../src/Microsoft/Data/SqlClient/TdsParser.cs | 4 +-- 11 files changed, 53 insertions(+), 35 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Interop/SNINativeMethodWrapper.Windows.cs b/src/Microsoft.Data.SqlClient/netcore/src/Interop/SNINativeMethodWrapper.Windows.cs index c8591a8c11..6f75be4e70 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Interop/SNINativeMethodWrapper.Windows.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Interop/SNINativeMethodWrapper.Windows.cs @@ -315,7 +315,7 @@ private static extern uint SNIOpenWrapper( [DllImport(SNI, CallingConvention = CallingConvention.Cdecl)] private static extern unsafe uint SNISecGenClientContextWrapper( [In] SNIHandle pConn, - [In, Out] byte[] pIn, + [In, Out] byte* pIn, uint cbIn, [In, Out] byte[] pOut, [In] ref uint pcbOut, @@ -471,15 +471,16 @@ internal static unsafe void SNIPacketSetData(SNIPacket packet, byte[] data, int } } - internal static unsafe uint SNISecGenClientContext(SNIHandle pConnectionObject, byte[] inBuff, uint receivedLength, byte[] OutBuff, ref uint sendLength, byte[] serverUserName) + internal static unsafe uint SNISecGenClientContext(SNIHandle pConnectionObject, ReadOnlySpan inBuff, byte[] OutBuff, ref uint sendLength, byte[] serverUserName) { fixed (byte* pin_serverUserName = &serverUserName[0]) + fixed (byte* pInBuff = inBuff) { bool local_fDone; return SNISecGenClientContextWrapper( pConnectionObject, - inBuff, - receivedLength, + pInBuff, + (uint)inBuff.Length, OutBuff, ref sendLength, out local_fDone, diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs index 3df369a2f6..b0ddefee25 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using System; +using System.Buffers; using System.Diagnostics; using System.IO; using System.Net; @@ -34,7 +35,21 @@ internal class SNIProxy /// Send buffer /// Service Principal Name buffer /// SNI error code - internal static void GenSspiClientContext(SspiClientContextStatus sspiClientContextStatus, byte[] receivedBuff, ref byte[] sendBuff, byte[][] serverName) + internal static void GenSspiClientContext(SspiClientContextStatus sspiClientContextStatus, ReadOnlyMemory receivedBuff, ref byte[] sendBuff, byte[][] serverName) + { + // TODO: this should use ReadOnlyMemory all the way through + byte[] array = null; + + if (!receivedBuff.IsEmpty) + { + array = new byte[receivedBuff.Length]; + receivedBuff.CopyTo(array); + } + + GenSspiClientContext(sspiClientContextStatus, array, ref sendBuff, serverName); + } + + private static void GenSspiClientContext(SspiClientContextStatus sspiClientContextStatus, byte[] receivedBuff, ref byte[] sendBuff, byte[][] serverName) { SafeDeleteContext securityContext = sspiClientContextStatus.SecurityContext; ContextFlagsPal contextFlags = sspiClientContextStatus.ContextFlags; @@ -189,7 +204,7 @@ internal static SNIHandle CreateConnectionHandle( case DataSource.Protocol.TCP: sniHandle = CreateTcpHandle(details, timeout, parallel, ipPreference, cachedFQDN, ref pendingDNSInfo, tlsFirst, hostNameInCertificate, serverCertificateFilename); - break; + break; case DataSource.Protocol.NP: sniHandle = CreateNpHandle(details, timeout, parallel, tlsFirst); break; diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/Interop/SNINativeManagedWrapperARM64.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/Interop/SNINativeManagedWrapperARM64.cs index 50552a3fb3..6e9bda11cd 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/Interop/SNINativeManagedWrapperARM64.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/Interop/SNINativeManagedWrapperARM64.cs @@ -116,7 +116,7 @@ internal static extern uint SNIOpenWrapper( [DllImport(SNI, CallingConvention = CallingConvention.Cdecl, EntryPoint = "SNISecGenClientContextWrapper")] internal static extern unsafe uint SNISecGenClientContextWrapper( [In] SNIHandle pConn, - [In, Out] byte[] pIn, + [In, Out] byte* pIn, uint cbIn, [In, Out] byte[] pOut, [In] ref uint pcbOut, diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/Interop/SNINativeManagedWrapperX64.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/Interop/SNINativeManagedWrapperX64.cs index f4970e1cda..acb10c8c79 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/Interop/SNINativeManagedWrapperX64.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/Interop/SNINativeManagedWrapperX64.cs @@ -116,7 +116,7 @@ internal static extern uint SNIOpenWrapper( [DllImport(SNI, CallingConvention = CallingConvention.Cdecl, EntryPoint = "SNISecGenClientContextWrapper")] internal static extern unsafe uint SNISecGenClientContextWrapper( [In] SNIHandle pConn, - [In, Out] byte[] pIn, + [In, Out] byte* pIn, uint cbIn, [In, Out] byte[] pOut, [In] ref uint pcbOut, diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/Interop/SNINativeManagedWrapperX86.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/Interop/SNINativeManagedWrapperX86.cs index 6e1a0abf5f..c8bb7c0e93 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/Interop/SNINativeManagedWrapperX86.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/Interop/SNINativeManagedWrapperX86.cs @@ -116,7 +116,7 @@ internal static extern uint SNIOpenWrapper( [DllImport(SNI, CallingConvention = CallingConvention.Cdecl, EntryPoint = "SNISecGenClientContextWrapper")] internal static extern unsafe uint SNISecGenClientContextWrapper( [In] SNIHandle pConn, - [In, Out] byte[] pIn, + [In, Out] byte* pIn, uint cbIn, [In, Out] byte[] pOut, [In] ref uint pcbOut, diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/Interop/SNINativeMethodWrapper.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/Interop/SNINativeMethodWrapper.cs index b9af5849c4..898fcc6866 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/Interop/SNINativeMethodWrapper.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/Interop/SNINativeMethodWrapper.cs @@ -889,8 +889,7 @@ private static unsafe void SNIPacketSetData(SNIPacket pPacket, [In] byte* pbBuf, private static unsafe uint SNISecGenClientContextWrapper( [In] SNIHandle pConn, - [In, Out] byte[] pIn, - uint cbIn, + [In, Out] ReadOnlySpan pIn, [In, Out] byte[] pOut, [In] ref uint pcbOut, [MarshalAsAttribute(UnmanagedType.Bool)] out bool pfDone, @@ -899,16 +898,19 @@ private static unsafe uint SNISecGenClientContextWrapper( [MarshalAsAttribute(UnmanagedType.LPWStr)] string pwszUserName, [MarshalAsAttribute(UnmanagedType.LPWStr)] string pwszPassword) { - switch (s_architecture) + fixed (byte* pInPtr = pIn) { - case System.Runtime.InteropServices.Architecture.Arm64: - return SNINativeManagedWrapperARM64.SNISecGenClientContextWrapper(pConn, pIn, cbIn, pOut, ref pcbOut, out pfDone, szServerInfo, cbServerInfo, pwszUserName, pwszPassword); - case System.Runtime.InteropServices.Architecture.X64: - return SNINativeManagedWrapperX64.SNISecGenClientContextWrapper(pConn, pIn, cbIn, pOut, ref pcbOut, out pfDone, szServerInfo, cbServerInfo, pwszUserName, pwszPassword); - case System.Runtime.InteropServices.Architecture.X86: - return SNINativeManagedWrapperX86.SNISecGenClientContextWrapper(pConn, pIn, cbIn, pOut, ref pcbOut, out pfDone, szServerInfo, cbServerInfo, pwszUserName, pwszPassword); - default: - throw ADP.SNIPlatformNotSupported(s_architecture.ToString()); + switch (s_architecture) + { + case System.Runtime.InteropServices.Architecture.Arm64: + return SNINativeManagedWrapperARM64.SNISecGenClientContextWrapper(pConn, pInPtr, (uint)pIn.Length, pOut, ref pcbOut, out pfDone, szServerInfo, cbServerInfo, pwszUserName, pwszPassword); + case System.Runtime.InteropServices.Architecture.X64: + return SNINativeManagedWrapperX64.SNISecGenClientContextWrapper(pConn, pInPtr, (uint)pIn.Length, pOut, ref pcbOut, out pfDone, szServerInfo, cbServerInfo, pwszUserName, pwszPassword); + case System.Runtime.InteropServices.Architecture.X86: + return SNINativeManagedWrapperX86.SNISecGenClientContextWrapper(pConn, pInPtr, (uint)pIn.Length, pOut, ref pcbOut, out pfDone, szServerInfo, cbServerInfo, pwszUserName, pwszPassword); + default: + throw ADP.SNIPlatformNotSupported(s_architecture.ToString()); + } } } @@ -1378,7 +1380,7 @@ Int32[] passwordOffsets // Offset into data buffer where the password to be w } } - internal static unsafe uint SNISecGenClientContext(SNIHandle pConnectionObject, byte[] inBuff, uint receivedLength, byte[] OutBuff, ref uint sendLength, byte[] serverUserName) + internal static unsafe uint SNISecGenClientContext(SNIHandle pConnectionObject, ReadOnlySpan inBuff, byte[] OutBuff, ref uint sendLength, byte[] serverUserName) { fixed (byte* pin_serverUserName = &serverUserName[0]) { @@ -1386,7 +1388,6 @@ internal static unsafe uint SNISecGenClientContext(SNIHandle pConnectionObject, return SNISecGenClientContextWrapper( pConnectionObject, inBuff, - receivedLength, OutBuff, ref sendLength, out local_fDone, diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/ManagedSSPIContextProvider.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/ManagedSSPIContextProvider.cs index 25fb2cadae..8395371de7 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/ManagedSSPIContextProvider.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/ManagedSSPIContextProvider.cs @@ -1,5 +1,6 @@ #if !NETFRAMEWORK && !NET7_0_OR_GREATER +using System; using Microsoft.Data.SqlClient.SNI; #nullable enable @@ -10,11 +11,11 @@ internal sealed class ManagedSSPIContextProvider : SSPIContextProvider { private SspiClientContextStatus? _sspiClientContextStatus; - internal override void GenerateSspiClientContext(byte[] receivedBuff, uint receivedLength, ref byte[] sendBuff, ref uint sendLength, byte[][] _sniSpnBuffer) + internal override void GenerateSspiClientContext(ReadOnlyMemory received, ref byte[] sendBuff, ref uint sendLength, byte[][] _sniSpnBuffer) { _sspiClientContextStatus ??= new SspiClientContextStatus(); - SNIProxy.GenSspiClientContext(_sspiClientContextStatus, receivedBuff, ref sendBuff, _sniSpnBuffer); + SNIProxy.GenSspiClientContext(_sspiClientContextStatus, received, ref sendBuff, _sniSpnBuffer); SqlClientEventSource.Log.TryTraceEvent("TdsParserStateObjectManaged.GenerateSspiClientContext | Info | Session Id {0}", _physicalStateObj.SessionId); sendLength = (uint)(sendBuff != null ? sendBuff.Length : 0); } diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NativeSSPIContextProvider.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NativeSSPIContextProvider.cs index 718608d136..067682a617 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NativeSSPIContextProvider.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NativeSSPIContextProvider.cs @@ -50,7 +50,7 @@ private void LoadSSPILibrary() } } - internal override void GenerateSspiClientContext(byte[] receivedBuff, uint receivedLength, ref byte[] sendBuff, ref uint sendLength, byte[][] _sniSpnBuffer) + internal override void GenerateSspiClientContext(ReadOnlyMemory receivedBuff, ref byte[] sendBuff, ref uint sendLength, byte[][] _sniSpnBuffer) { #if NETFRAMEWORK SNIHandle handle = _physicalStateObj.Handle; @@ -58,7 +58,7 @@ internal override void GenerateSspiClientContext(byte[] receivedBuff, uint recei Debug.Assert(_physicalStateObj.SessionHandle.Type == SessionHandle.NativeHandleType); SNIHandle handle = _physicalStateObj.SessionHandle.NativeHandle; #endif - if (0 != SNINativeMethodWrapper.SNISecGenClientContext(handle, receivedBuff, receivedLength, sendBuff, ref sendLength, _sniSpnBuffer[0])) + if (0 != SNINativeMethodWrapper.SNISecGenClientContext(handle, receivedBuff.Span, sendBuff, ref sendLength, _sniSpnBuffer[0])) { throw new InvalidOperationException(SQLMessage.SSPIGenerateError()); } diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NegotiateSSPIContextProvider.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NegotiateSSPIContextProvider.cs index bc11cf29c9..2d73d28390 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NegotiateSSPIContextProvider.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NegotiateSSPIContextProvider.cs @@ -12,10 +12,10 @@ internal sealed class NegotiateSSPIContextProvider : SSPIContextProvider { private NegotiateAuthentication? _negotiateAuth = null; - internal override void GenerateSspiClientContext(byte[] receivedBuff, uint receivedLength, ref byte[] sendBuff, ref uint sendLength, byte[][] _sniSpnBuffer) + internal override void GenerateSspiClientContext(ReadOnlyMemory received, ref byte[] sendBuff, ref uint sendLength, byte[][] _sniSpnBuffer) { _negotiateAuth ??= new(new NegotiateAuthenticationClientOptions { Package = "Negotiate", TargetName = Encoding.Unicode.GetString(_sniSpnBuffer[0]) }); - sendBuff = _negotiateAuth.GetOutgoingBlob(receivedBuff, out NegotiateAuthenticationStatusCode statusCode)!; + sendBuff = _negotiateAuth.GetOutgoingBlob(received.Span, out NegotiateAuthenticationStatusCode statusCode)!; SqlClientEventSource.Log.TryTraceEvent("TdsParserStateObjectManaged.GenerateSspiClientContext | Info | Session Id {0}, StatusCode={1}", _physicalStateObj.SessionId, statusCode); if (statusCode is not NegotiateAuthenticationStatusCode.Completed and not NegotiateAuthenticationStatusCode.ContinueNeeded) { diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/SSPIContextProvider.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/SSPIContextProvider.cs index fa8c27e3ab..256818660e 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/SSPIContextProvider.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/SSPIContextProvider.cs @@ -27,16 +27,16 @@ private protected virtual void Initialize() { } - internal abstract void GenerateSspiClientContext(byte[] receivedBuff, uint receivedLength, ref byte[] sendBuff, ref uint sendLength, byte[][] _sniSpnBuffer); + internal abstract void GenerateSspiClientContext(ReadOnlyMemory input, ref byte[] sendBuff, ref uint sendLength, byte[][] _sniSpnBuffer); - internal void SSPIData(byte[] receivedBuff, UInt32 receivedLength, ref byte[] sendBuff, ref UInt32 sendLength, byte[] sniSpnBuffer) - => SSPIData(receivedBuff, receivedLength, ref sendBuff, ref sendLength, new[] { sniSpnBuffer }); + internal void SSPIData(ReadOnlyMemory receivedBuff, ref byte[] sendBuff, ref UInt32 sendLength, byte[] sniSpnBuffer) + => SSPIData(receivedBuff, ref sendBuff, ref sendLength, new[] { sniSpnBuffer }); - internal void SSPIData(byte[] receivedBuff, UInt32 receivedLength, ref byte[] sendBuff, ref UInt32 sendLength, byte[][] sniSpnBuffer) + internal void SSPIData(ReadOnlyMemory receivedBuff, ref byte[] sendBuff, ref UInt32 sendLength, byte[][] sniSpnBuffer) { try { - GenerateSspiClientContext(receivedBuff, receivedLength, ref sendBuff, ref sendLength, sniSpnBuffer); + GenerateSspiClientContext(receivedBuff, ref sendBuff, ref sendLength, sniSpnBuffer); } catch (Exception e) { diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs index 4366b4665a..8e0569b44b 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -31,7 +31,7 @@ internal void ProcessSSPI(int receivedLength) uint sendLength = _authenticationProvider.MaxSSPILength; // make call for SSPI data - _authenticationProvider.SSPIData(receivedBuff, (uint)receivedLength, ref sendBuff, ref sendLength, _sniSpnBuffer); + _authenticationProvider.SSPIData(receivedBuff.AsMemory(0, receivedLength), ref sendBuff, ref sendLength, _sniSpnBuffer); // DO NOT SEND LENGTH - TDS DOC INCORRECT! JUST SEND SSPI DATA! _physicalStateObj.WriteByteArray(sendBuff, (int)sendLength, 0); @@ -194,7 +194,7 @@ internal void TdsLogin( // byte[] buffer and 0 for the int length. Debug.Assert(SniContext.Snix_Login == _physicalStateObj.SniContext, $"Unexpected SniContext. Expecting Snix_Login, actual value is '{_physicalStateObj.SniContext}'"); _physicalStateObj.SniContext = SniContext.Snix_LoginSspi; - _authenticationProvider.SSPIData(Array.Empty(), 0, ref outSSPIBuff, ref outSSPILength, _sniSpnBuffer); + _authenticationProvider.SSPIData(ReadOnlyMemory.Empty, ref outSSPIBuff, ref outSSPILength, _sniSpnBuffer); if (outSSPILength > int.MaxValue) {