Skip to content

Commit

Permalink
Use full TLS record size for application data on Windows (#95595)
Browse files Browse the repository at this point in the history
* Replace pointers by spans and refs in SslStreamPal.Windows

* Correctly calculate the MaxDataSize

* Remove unwanted change

* Fixes after rebase

* Remove couple of unsafe usages
  • Loading branch information
rzikm authored Dec 8, 2023
1 parent e5c0cbc commit 6253efe
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 58 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -928,7 +928,7 @@ internal void ProcessHandshakeSuccess()

_headerSize = streamSizes.Header;
_trailerSize = streamSizes.Trailer;
_maxDataSize = checked(streamSizes.MaximumMessage - (_headerSize + _trailerSize));
_maxDataSize = streamSizes.MaximumMessage;
Debug.Assert(_maxDataSize > 0, "_maxDataSize > 0");

SslStreamPal.QueryContextConnectionInfo(_securityContext!, ref _connectionInfo);
Expand All @@ -942,18 +942,6 @@ internal void ProcessHandshakeSuccess()
#endif
}

/*++
Encrypt - Encrypts our bytes before we send them over the wire
PERF: make more efficient, this does an extra copy when the offset
is non-zero.
Input:
buffer - bytes for sending
offset -
size -
output - Encrypted bytes
--*/
internal ProtocolToken Encrypt(ReadOnlyMemory<byte> buffer)
{
if (NetEventSource.Log.IsEnabled()) NetEventSource.DumpBuffer(this, buffer.Span);
Expand Down Expand Up @@ -1337,7 +1325,7 @@ internal void EnsureAvailableSpace(int size)

var oldPayload = Payload;

Payload = RentBuffer? ArrayPool<byte>.Shared.Rent(Size + size) : new byte[Size + size];
Payload = RentBuffer ? ArrayPool<byte>.Shared.Rent(Size + size) : new byte[Size + size];
if (oldPayload != null)
{
oldPayload.AsSpan<byte>().CopyTo(Payload);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Collections.Generic;
using System.ComponentModel;
using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Security.Authentication;
using System.Security.Authentication.ExtendedProtection;
Expand Down Expand Up @@ -49,7 +50,8 @@ public static Exception GetException(SecurityStatusPal status)

private static byte[] InitSessionTokenBuffer()
{
var schannelSessionToken = new Interop.SChannel.SCHANNEL_SESSION_TOKEN() {
var schannelSessionToken = new Interop.SChannel.SCHANNEL_SESSION_TOKEN()
{
dwTokenType = Interop.SChannel.SCHANNEL_SESSION,
dwFlags = Interop.SChannel.SSL_SESSION_DISABLE_RECONNECTS,
};
Expand All @@ -61,7 +63,7 @@ public static void VerifyPackageInfo()
SSPIWrapper.GetVerifyPackageInfo(GlobalSSPI.SSPISecureChannel, SecurityPackage, true);
}

private static unsafe void SetAlpn(ref InputSecurityBuffers inputBuffers, List<SslApplicationProtocol> alpn, Span<byte> localBuffer)
private static void SetAlpn(ref InputSecurityBuffers inputBuffers, List<SslApplicationProtocol> alpn, Span<byte> localBuffer)
{
if (alpn.Count == 1 && alpn[0] == SslApplicationProtocol.Http11)
{
Expand All @@ -82,7 +84,7 @@ private static unsafe void SetAlpn(ref InputSecurityBuffers inputBuffers, List<S
else
{
int protocolLength = Interop.Sec_Application_Protocols.GetProtocolLength(alpn);
int bufferLength = sizeof(Interop.Sec_Application_Protocols) + protocolLength;
int bufferLength = Unsafe.SizeOf<Interop.Sec_Application_Protocols>() + protocolLength;

Span<byte> alpnBuffer = bufferLength <= localBuffer.Length ? localBuffer : new byte[bufferLength];
Interop.Sec_Application_Protocols.SetProtocols(alpnBuffer, alpn, protocolLength);
Expand All @@ -99,7 +101,7 @@ public static SecurityStatusPal SelectApplicationProtocol(
throw new PlatformNotSupportedException(nameof(SelectApplicationProtocol));
}

public static unsafe ProtocolToken AcceptSecurityContext(
public static ProtocolToken AcceptSecurityContext(
ref SafeFreeCredentials? credentialsHandle,
ref SafeDeleteSslContext? context,
ReadOnlySpan<byte> inputBuffer,
Expand Down Expand Up @@ -141,7 +143,7 @@ public static bool TryUpdateClintCertificate(
return false;
}

public static unsafe ProtocolToken InitializeSecurityContext(
public static ProtocolToken InitializeSecurityContext(
ref SafeFreeCredentials? credentialsHandle,
ref SafeDeleteSslContext? context,
string? targetName,
Expand Down Expand Up @@ -445,32 +447,32 @@ public static unsafe ProtocolToken EncryptMessage(SafeDeleteSslContext securityC
input.Span.CopyTo(token.AvailableSpan.Slice(headerSize, input.Length));

const int NumSecBuffers = 4; // header + data + trailer + empty
Interop.SspiCli.SecBuffer* unmanagedBuffer = stackalloc Interop.SspiCli.SecBuffer[NumSecBuffers];
Span<Interop.SspiCli.SecBuffer> unmanagedBuffers = stackalloc Interop.SspiCli.SecBuffer[NumSecBuffers];
Interop.SspiCli.SecBufferDesc sdcInOut = new Interop.SspiCli.SecBufferDesc(NumSecBuffers)
{
pBuffers = unmanagedBuffer
pBuffers = Unsafe.AsPointer(ref MemoryMarshal.GetReference(unmanagedBuffers))
};
fixed (byte* outputPtr = token.Payload)
{
Interop.SspiCli.SecBuffer* headerSecBuffer = &unmanagedBuffer[0];
headerSecBuffer->BufferType = SecurityBufferType.SECBUFFER_STREAM_HEADER;
headerSecBuffer->pvBuffer = (IntPtr)outputPtr;
headerSecBuffer->cbBuffer = headerSize;
ref Interop.SspiCli.SecBuffer headerSecBuffer = ref unmanagedBuffers[0];
headerSecBuffer.BufferType = SecurityBufferType.SECBUFFER_STREAM_HEADER;
headerSecBuffer.pvBuffer = (IntPtr)outputPtr;
headerSecBuffer.cbBuffer = headerSize;

Interop.SspiCli.SecBuffer* dataSecBuffer = &unmanagedBuffer[1];
dataSecBuffer->BufferType = SecurityBufferType.SECBUFFER_DATA;
dataSecBuffer->pvBuffer = (IntPtr)(outputPtr + headerSize);
dataSecBuffer->cbBuffer = input.Length;
ref Interop.SspiCli.SecBuffer dataSecBuffer = ref unmanagedBuffers[1];
dataSecBuffer.BufferType = SecurityBufferType.SECBUFFER_DATA;
dataSecBuffer.pvBuffer = (IntPtr)(outputPtr + headerSize);
dataSecBuffer.cbBuffer = input.Length;

Interop.SspiCli.SecBuffer* trailerSecBuffer = &unmanagedBuffer[2];
trailerSecBuffer->BufferType = SecurityBufferType.SECBUFFER_STREAM_TRAILER;
trailerSecBuffer->pvBuffer = (IntPtr)(outputPtr + headerSize + input.Length);
trailerSecBuffer->cbBuffer = trailerSize;
ref Interop.SspiCli.SecBuffer trailerSecBuffer = ref unmanagedBuffers[2];
trailerSecBuffer.BufferType = SecurityBufferType.SECBUFFER_STREAM_TRAILER;
trailerSecBuffer.pvBuffer = (IntPtr)(outputPtr + headerSize + input.Length);
trailerSecBuffer.cbBuffer = trailerSize;

Interop.SspiCli.SecBuffer* emptySecBuffer = &unmanagedBuffer[3];
emptySecBuffer->BufferType = SecurityBufferType.SECBUFFER_EMPTY;
emptySecBuffer->cbBuffer = 0;
emptySecBuffer->pvBuffer = IntPtr.Zero;
ref Interop.SspiCli.SecBuffer emptySecBuffer = ref unmanagedBuffers[3];
emptySecBuffer.BufferType = SecurityBufferType.SECBUFFER_EMPTY;
emptySecBuffer.cbBuffer = 0;
emptySecBuffer.pvBuffer = IntPtr.Zero;

int errorCode = GlobalSSPI.SSPISecureChannel.EncryptMessage(securityContext, ref sdcInOut, 0);

Expand All @@ -483,10 +485,10 @@ public static unsafe ProtocolToken EncryptMessage(SafeDeleteSslContext securityC
return token;
}

Debug.Assert(headerSecBuffer->cbBuffer >= 0 && dataSecBuffer->cbBuffer >= 0 && trailerSecBuffer->cbBuffer >= 0);
Debug.Assert(checked(headerSecBuffer->cbBuffer + dataSecBuffer->cbBuffer + trailerSecBuffer->cbBuffer) <= token.Payload!.Length);
Debug.Assert(headerSecBuffer.cbBuffer >= 0 && dataSecBuffer.cbBuffer >= 0 && trailerSecBuffer.cbBuffer >= 0);
Debug.Assert(checked(headerSecBuffer.cbBuffer + dataSecBuffer.cbBuffer + trailerSecBuffer.cbBuffer) <= token.Payload!.Length);

token.Size = checked(headerSecBuffer->cbBuffer + dataSecBuffer->cbBuffer + trailerSecBuffer->cbBuffer);
token.Size = checked(headerSecBuffer.cbBuffer + dataSecBuffer.cbBuffer + trailerSecBuffer.cbBuffer);
token.Status = new SecurityStatusPal(SecurityStatusPalErrorCode.OK);
}

Expand All @@ -496,25 +498,26 @@ public static unsafe ProtocolToken EncryptMessage(SafeDeleteSslContext securityC
public static unsafe SecurityStatusPal DecryptMessage(SafeDeleteSslContext? securityContext, Span<byte> buffer, out int offset, out int count)
{
const int NumSecBuffers = 4; // data + empty + empty + empty
fixed (byte* bufferPtr = buffer)

Span<Interop.SspiCli.SecBuffer> unmanagedBuffers = stackalloc Interop.SspiCli.SecBuffer[NumSecBuffers];
for (int i = 1; i < NumSecBuffers; i++)
{
Interop.SspiCli.SecBuffer* unmanagedBuffer = stackalloc Interop.SspiCli.SecBuffer[NumSecBuffers];
Interop.SspiCli.SecBuffer* dataBuffer = &unmanagedBuffer[0];
dataBuffer->BufferType = SecurityBufferType.SECBUFFER_DATA;
dataBuffer->pvBuffer = (IntPtr)bufferPtr;
dataBuffer->cbBuffer = buffer.Length;
ref Interop.SspiCli.SecBuffer emptyBuffer = ref unmanagedBuffers[i];
emptyBuffer.BufferType = SecurityBufferType.SECBUFFER_EMPTY;
emptyBuffer.pvBuffer = IntPtr.Zero;
emptyBuffer.cbBuffer = 0;
}

for (int i = 1; i < NumSecBuffers; i++)
{
Interop.SspiCli.SecBuffer* emptyBuffer = &unmanagedBuffer[i];
emptyBuffer->BufferType = SecurityBufferType.SECBUFFER_EMPTY;
emptyBuffer->pvBuffer = IntPtr.Zero;
emptyBuffer->cbBuffer = 0;
}
fixed (byte* bufferPtr = buffer)
{
ref Interop.SspiCli.SecBuffer dataBuffer = ref unmanagedBuffers[0];
dataBuffer.BufferType = SecurityBufferType.SECBUFFER_DATA;
dataBuffer.pvBuffer = (IntPtr)bufferPtr;
dataBuffer.cbBuffer = buffer.Length;

Interop.SspiCli.SecBufferDesc sdcInOut = new Interop.SspiCli.SecBufferDesc(NumSecBuffers)
{
pBuffers = unmanagedBuffer
pBuffers = Unsafe.AsPointer(ref MemoryMarshal.GetReference(unmanagedBuffers))
};
Interop.SECURITY_STATUS errorCode = (Interop.SECURITY_STATUS)GlobalSSPI.SSPISecureChannel.DecryptMessage(securityContext!, ref sdcInOut, out _);

Expand All @@ -525,12 +528,12 @@ public static unsafe SecurityStatusPal DecryptMessage(SafeDeleteSslContext? secu
for (int i = 0; i < NumSecBuffers; i++)
{
// Successfully decoded data and placed it at the following position in the buffer,
if ((errorCode == Interop.SECURITY_STATUS.OK && unmanagedBuffer[i].BufferType == SecurityBufferType.SECBUFFER_DATA)
if ((errorCode == Interop.SECURITY_STATUS.OK && unmanagedBuffers[i].BufferType == SecurityBufferType.SECBUFFER_DATA)
// or we failed to decode the data, here is the encoded data.
|| (errorCode != Interop.SECURITY_STATUS.OK && unmanagedBuffer[i].BufferType == SecurityBufferType.SECBUFFER_EXTRA))
|| (errorCode != Interop.SECURITY_STATUS.OK && unmanagedBuffers[i].BufferType == SecurityBufferType.SECBUFFER_EXTRA))
{
offset = (int)((byte*)unmanagedBuffer[i].pvBuffer - bufferPtr);
count = unmanagedBuffer[i].cbBuffer;
offset = (int)((byte*)unmanagedBuffers[i].pvBuffer - bufferPtr);
count = unmanagedBuffers[i].cbBuffer;

// output is ignored on Windows. We always decrypt in place and we set outputOffset to indicate where the data start.
Debug.Assert(offset >= 0 && count >= 0, $"Expected offset and count greater than 0, got {offset} and {count}");
Expand Down

0 comments on commit 6253efe

Please sign in to comment.