Skip to content

Commit

Permalink
only slice SocketAddress on success operation (dotnet#90284)
Browse files Browse the repository at this point in the history
* only slice SocketAddress on success operation

* null

* completed

* feedback

* feedback
  • Loading branch information
wfurt authored Aug 11, 2023
1 parent 2b4b6b2 commit caaed61
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -475,9 +475,12 @@ protected override bool DoTryComplete(SocketAsyncContext context)
}
else
{
bool result = SocketPal.TryCompleteReceiveFrom(context._socket, Buffer.Span, null, Flags, SocketAddress.Span, out int socketAddressLen, out BytesTransferred, out ReceivedFlags, out ErrorCode);
SocketAddress = SocketAddress.Slice(0, socketAddressLen);
return result;
bool completed = SocketPal.TryCompleteReceiveFrom(context._socket, Buffer.Span, null, Flags, SocketAddress.Span, out int socketAddressLen, out BytesTransferred, out ReceivedFlags, out ErrorCode);
if (completed && ErrorCode == SocketError.Success)
{
SocketAddress = SocketAddress.Slice(0, socketAddressLen);
}
return completed;
}
}
}
Expand Down Expand Up @@ -508,7 +511,7 @@ public BufferListReceiveOperation(SocketAsyncContext context) : base(context) {
protected override bool DoTryComplete(SocketAsyncContext context)
{
bool completed = SocketPal.TryCompleteReceiveFrom(context._socket, default(Span<byte>), Buffers, Flags, SocketAddress.Span, out int socketAddressLen, out BytesTransferred, out ReceivedFlags, out ErrorCode);
if (ErrorCode == SocketError.Success)
if (completed && ErrorCode == SocketError.Success)
{
SocketAddress = SocketAddress.Slice(0, socketAddressLen);
}
Expand Down Expand Up @@ -542,7 +545,7 @@ public BufferPtrReceiveOperation(SocketAsyncContext context) : base(context) { }
protected override bool DoTryComplete(SocketAsyncContext context)
{
bool completed = SocketPal.TryCompleteReceiveFrom(context._socket, new Span<byte>(BufferPtr, Length), null, Flags, SocketAddress.Span, out int socketAddressLen, out BytesTransferred, out ReceivedFlags, out ErrorCode);
if (ErrorCode == SocketError.Success)
if (completed && ErrorCode == SocketError.Success)
{
SocketAddress = SocketAddress.Slice(0, socketAddressLen);
}
Expand All @@ -569,7 +572,7 @@ public ReceiveMessageFromOperation(SocketAsyncContext context) : base(context) {
protected override bool DoTryComplete(SocketAsyncContext context)
{
bool completed = SocketPal.TryCompleteReceiveMessageFrom(context._socket, Buffer.Span, Buffers, Flags, SocketAddress, out int socketAddressLen, IsIPv4, IsIPv6, out BytesTransferred, out ReceivedFlags, out IPPacketInformation, out ErrorCode);
if (ErrorCode == SocketError.Success)
if (completed && ErrorCode == SocketError.Success)
{
SocketAddress = SocketAddress.Slice(0, socketAddressLen);
}
Expand Down Expand Up @@ -599,7 +602,7 @@ public BufferPtrReceiveMessageFromOperation(SocketAsyncContext context) : base(c
protected override bool DoTryComplete(SocketAsyncContext context)
{
bool completed = SocketPal.TryCompleteReceiveMessageFrom(context._socket, new Span<byte>(BufferPtr, Length), null, Flags, SocketAddress!, out int socketAddressLen, IsIPv4, IsIPv6, out BytesTransferred, out ReceivedFlags, out IPPacketInformation, out ErrorCode);
if (ErrorCode == SocketError.Success)
if (completed && ErrorCode == SocketError.Success)
{
SocketAddress = SocketAddress.Slice(0, socketAddressLen);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -368,10 +368,7 @@ private static SocketError FinishOperationConnect()

private void UpdateReceivedSocketAddress(SocketAddress socketAddress)
{
if (_socketAddressSize > 0)
{
socketAddress.Size = _socketAddressSize;
}
socketAddress.Size = _socketAddressSize;
}

partial void FinishOperationReceiveMessageFrom();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,8 @@ private static unsafe int SysReceive(SafeSocketHandle socket, SocketFlags flags,
Count = (UIntPtr)buffer.Length
};

Debug.Assert(socketAddress.Length != 0 || sockAddr == null);

var messageHeader = new Interop.Sys.MessageHeader {
SocketAddress = sockAddr,
SocketAddressLen = socketAddress.Length,
Expand Down Expand Up @@ -468,7 +470,6 @@ private static unsafe int SysReceive(SafeSocketHandle socket, SocketFlags flags,
private static unsafe int SysReceiveMessageFrom(SafeSocketHandle socket, SocketFlags flags, Span<byte> buffer, Span<byte> socketAddress, out int socketAddressLen, bool isIPv4, bool isIPv6, out SocketFlags receivedFlags, out IPPacketInformation ipPacketInformation, out Interop.Error errno)
{
Debug.Assert(socket.IsSocket);
Debug.Assert(socketAddress != null, "Expected non-null socketAddress");

int cmsgBufferLen = Interop.Sys.GetControlMessageBufferSize(Convert.ToInt32(isIPv4), Convert.ToInt32(isIPv6));
byte* cmsgBuffer = stackalloc byte[cmsgBufferLen];
Expand All @@ -484,6 +485,8 @@ private static unsafe int SysReceiveMessageFrom(SafeSocketHandle socket, SocketF
Count = (UIntPtr)buffer.Length
};

Debug.Assert(socketAddress.Length != 0 || rawSocketAddress == null);

messageHeader = new Interop.Sys.MessageHeader {
SocketAddress = rawSocketAddress,
SocketAddressLen = socketAddress.Length,
Expand Down Expand Up @@ -1234,7 +1237,7 @@ public static SocketError Receive(SafeSocketHandle handle, IList<ArraySegment<by
}
else
{
if (!TryCompleteReceiveFrom(handle, buffers, socketFlags, null, out int _, out bytesTransferred, out _, out errorCode))
if (!TryCompleteReceiveFrom(handle, buffers, socketFlags, Span<byte>.Empty, out int _, out bytesTransferred, out _, out errorCode))
{
errorCode = SocketError.WouldBlock;
}
Expand Down

0 comments on commit caaed61

Please sign in to comment.